OpenDAS / tilelang — Commit 29051439 (unverified)
Authored Dec 12, 2025 by Lei Wang; committed via GitHub on Dec 12, 2025
Parent: e84b24bc

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

This page shows 20 of the commit's 426 changed files: 279 additions and 398 deletions (+279 −398).
Changed files shown on this page:

maint/host_checks/08_device_id_mismatch.py (+2 −2)
maint/host_checks/09_null_data_pointer.py (+1 −0)
maint/host_checks/10_scalar_type_mismatch.py (+2 −2)
maint/host_checks/common.py (+4 −13)
maint/precision/compare_ops.py (+27 −43)
maint/scripts/ci_performance.py (+18 −25)
maint/scripts/performance.py (+22 −18)
pyproject.toml (+10 −5)
requirements-lint.txt (+0 −1)
testing/conftest.py (+2 −5)
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py (+19 −35)
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py (+52 −82)
testing/python/amd/test_tilelang_test_amd.py (+6 −9)
testing/python/analysis/test_tilelang_fragment_loop_checker.py (+12 −24)
testing/python/analysis/test_tilelang_nested_loop_checker.py (+56 −64)
testing/python/autotune/test_tilelang_autotune.py (+17 −13)
testing/python/autotune/test_tilelang_autotune_with_inputs.py (+11 −27)
testing/python/cache/test_tilelang_cache_matmul.py (+4 −3)
testing/python/carver/test_tilelang_carver_cuda_driver_properties.py (+5 −11)
testing/python/carver/test_tilelang_carver_generate_hints.py (+9 −16)
maint/host_checks/08_device_id_mismatch.py

"""Reproduce: device_id mismatch (requires >=2 CUDA devices).
"""
import torch
from common import build_matmul_kernel
...
maint/host_checks/09_null_data_pointer.py

@@ -7,6 +7,7 @@ or a host-side non-NULL pointer check.
Note: Constructing a true DLTensor with NULL data in PyTorch is not typical; this script
demonstrates passing None, which still reproduces the intended class of failure.
"""
import torch
from common import build_matmul_kernel
...
maint/host_checks/10_scalar_type_mismatch.py

"""Reproduce: scalar parameter type mismatch (int/bool).
"""
from common import build_scalar_check_kernel
...
maint/host_checks/common.py

@@ -3,20 +3,12 @@ import tilelang.language as T
import torch


def make_matmul_prim(M, N, K, block_M=128, block_N=128, block_K=32, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
...
@@ -42,7 +34,6 @@ def build_matmul_kernel(M=1024, N=1024, K=1024, target="cuda"):
def build_scalar_check_kernel(target="cuda"):
    @T.prim_func
    def scalar_check(x: T.int32, flag: T.bool()):
        T.evaluate(0)
...
maint/precision/compare_ops.py

@@ -37,7 +37,7 @@ OP_NAMES: Dict[int, str] = {
    6: "sqrt",
    7: "tanh",
    8: "rsqrt",
    9: "inv_sqrt",
}

# Block sizes for kernels
...
@@ -49,8 +49,7 @@ TILELANG_THREADS = 128
def parse_arguments() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Precision comparison tool for various CUDA implementations")
    parser.add_argument("--n", type=int, default=1000000, help="Number of elements to test")
    parser.add_argument("--low", type=float, default=-4.0, help="Lower bound for random values")
    parser.add_argument("--high", type=float, default=4.0, help="Upper bound for random values")
...
@@ -67,7 +66,7 @@ def initialize_cuda() -> torch.nn.Module:
    return load(
        name="cuda_ops",
        sources=["cuda_ops.cu"],
        extra_cuda_cflags=[],  # No fast_math flags
    )
...
@@ -149,8 +148,7 @@ def triton_unary_kernel(x_ptr, out_ptr, n_elements, op_id: tl.constexpr, BLOCK_S
@triton.jit
def triton_libdevice_unary_kernel(x_ptr, out_ptr, n_elements, op_id: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    """LibDevice Triton kernel for unary operations."""
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
...
@@ -188,13 +186,10 @@ def make_tilelang_unary_kernel(M: int, N: int, op_id: int, use_fastmath: bool =
    @T.prim_func
    def tilelang_unary_kernel(
        A: T.Tensor((M, N), "float32"),
        B: T.Tensor((M, N), "float32"),
    ):
        with T.Kernel(T.ceildiv(N, TILELANG_BLOCK_N), T.ceildiv(M, TILELANG_BLOCK_M), threads=TILELANG_THREADS) as (bx, by):
            for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
                row = by * TILELANG_BLOCK_M + i
                col = bx * TILELANG_BLOCK_N + j
...
@@ -229,14 +224,11 @@ def make_tilelang_binary_kernel(M: int, N: int):
    @T.prim_func
    def tilelang_binary_kernel(
        A: T.Tensor((M, N), "float32"),
        B: T.Tensor((M, N), "float32"),
        C: T.Tensor((M, N), "float32"),
    ):
        with T.Kernel(T.ceildiv(N, TILELANG_BLOCK_N), T.ceildiv(M, TILELANG_BLOCK_M), threads=TILELANG_THREADS) as (bx, by):
            for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
                row = by * TILELANG_BLOCK_M + i
                col = bx * TILELANG_BLOCK_N + j
...
@@ -247,10 +239,7 @@ def make_tilelang_binary_kernel(M: int, N: int):
    return tilelang_binary_kernel


def tilelang_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None, use_fastmath: bool = False) -> torch.Tensor:
    """TileLang operation interface."""
    assert x.is_cuda
...
@@ -272,7 +261,8 @@ def tilelang_op(x: torch.Tensor,
            target="cuda",
            pass_configs={
                tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: use_fastmath,
            },
        )
        out = kernel(x, y)
    else:
        # Unary operation
        kernel_func = make_tilelang_unary_kernel(M, N, op_id, use_fastmath)
...
@@ -282,7 +272,8 @@ def tilelang_op(x: torch.Tensor,
            target="cuda",
            pass_configs={
                tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: use_fastmath,
            },
        )
        out = kernel(x)
    # Restore original shape
...
@@ -293,7 +284,7 @@ def triton_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) ->
    """Standard Triton operation interface."""
    assert x.is_cuda
    out = torch.empty_like(x)
    grid = lambda meta: ((x.numel() + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],)
    if op_id == 0:
        # Division - binary operation
        assert y is not None, "Division operation requires second operand"
...
@@ -304,13 +295,11 @@ def triton_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) ->
    return out


def triton_libdevice_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) -> torch.Tensor:
    """LibDevice Triton operation interface."""
    assert x.is_cuda
    out = torch.empty_like(x)
    grid = lambda meta: ((x.numel() + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],)
    if op_id == 0:
        # Division - binary operation
        assert y is not None, "Division operation requires second operand"
...
@@ -321,9 +310,7 @@ def triton_libdevice_op(x: torch.Tensor,
    return out


def get_pytorch_reference(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Get PyTorch reference implementation for the given operation."""
    if op_id == 0:
        assert y is not None, "Division requires second operand"
...
@@ -362,8 +349,10 @@ def summarize_error(tag: str, output: Optional[torch.Tensor], reference: torch.T
    abs_err = (output_double - reference_double).abs()
    rel_err = abs_err / (reference_double.abs().clamp_min(1e-30))
    print(
        f"{tag:<32} max abs: {abs_err.max():.3e}, mean abs: {abs_err.mean():.3e}, "
        f"max rel: {rel_err.max():.3e}, mean rel: {rel_err.mean():.3e}"
    )


# Precision comparison function
...
@@ -407,9 +396,7 @@ def compare(op_id: int, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> No
            results[name] = None

    # Print comparison header
    print(f"{'Implementation':<32}{'Max Abs Error':<19}{'Mean Abs Error':<20}{'Max Rel Error':<19}{'Mean Rel Error'}")
    print("-" * 90)

    # Compare all implementations against double precision reference
...
@@ -427,8 +414,7 @@ def compare(op_id: int, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> No
        summarize_error(tag, output, ref_double)


def generate_test_data(op_id: int, n: int, device: torch.device, low: float, high: float) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Generate appropriate test data for each operation."""
    if op_id == 0:
        # Division
        x = torch.empty(n, device=device).uniform_(low, high)
...
@@ -450,9 +436,7 @@ def generate_test_data(op_id: int, n: int, device: torch.device, low: float,
def main() -> None:
    """Main execution function."""
    print("Precision comparison between CUDA Precise/Fast, Triton, Triton LibDevice, PyTorch, and TileLang")
    print("=" * 90)

    for op_id in range(len(OP_NAMES)):
...
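As a reading aid for the error metrics that summarize_error() in this file prints, the following is a minimal standalone sketch, assuming a made-up output/reference tensor pair (the values and the "example" tag are illustrative, not from the repository):

    # Minimal sketch (assumed inputs): recompute the abs/rel error metrics
    # printed by summarize_error() above on a tiny fabricated tensor pair.
    import torch

    output = torch.tensor([1.0001, 1.9998, 3.0002])  # hypothetical kernel output
    reference = torch.tensor([1.0, 2.0, 3.0])        # hypothetical high-precision reference

    output_double = output.double()
    reference_double = reference.double()
    abs_err = (output_double - reference_double).abs()
    rel_err = abs_err / reference_double.abs().clamp_min(1e-30)
    print(
        f"{'example':<32} max abs: {abs_err.max():.3e}, mean abs: {abs_err.mean():.3e}, "
        f"max rel: {rel_err.max():.3e}, mean rel: {rel_err.mean():.3e}"
    )

The clamp_min(1e-30) guard mirrors the script's own way of avoiding division by zero when the reference value is exactly zero.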
maint/scripts/ci_performance.py

@@ -10,39 +10,32 @@ env["TILELANG_CLEAR_CACHE"] = "1"
def parse_output(output):
    data = {}
    for line in output.split("\n"):
        line = line.strip()
        if line.startswith("Latency:"):
            match = re.search(r"Latency: ([\d.]+)", line)
            data["latency"] = match.group(1) if match else "N/A"
        elif line.startswith("TFlops:"):
            match = re.search(r"TFlops: ([\d.]+)", line)
            data["best_tflops"] = match.group(1) if match else "N/A"
        elif line.startswith("Config:"):
            data["config"] = line.split("Config:")[-1]
        elif line.startswith("Reference TFlops:"):
            match = re.search(r"Reference TFlops: ([\d.]+)", line)
            data["ref_tflops"] = match.group(1) if match else "N/A"
    return data


output_v1 = subprocess.run(["./tl/bin/python", "./maint/scripts/performance.py"], capture_output=True, text=True, env=env).stdout
data_v1 = parse_output(output_v1)

output_v2 = subprocess.run(["./tll/bin/python", "./maint/scripts/performance.py"], capture_output=True, text=True, env=env).stdout
data_v2 = parse_output(output_v2)

table = [
    ["original", data_v1["latency"], data_v1["best_tflops"], data_v1["ref_tflops"], data_v1["config"]],
    ["current", data_v2["latency"], data_v2["best_tflops"], data_v2["ref_tflops"], data_v2["config"]],
]

headers = ["version", "Best Latency (s)", "Best TFlops", "Reference TFlops", "Best Config"]
...
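To make the parsing step above concrete, here is an illustrative sketch that feeds a fabricated performance.py output through the same regex logic as parse_output(); the sample text and field values are assumptions, not captured CI output:

    # Illustrative only: fabricated benchmark output run through the same
    # line-prefix + regex extraction used by parse_output() above.
    import re

    sample = """Latency: 0.123
    TFlops: 312.5
    Reference TFlops: 290.1"""

    data = {}
    for line in sample.split("\n"):
        line = line.strip()
        if line.startswith("Latency:"):
            m = re.search(r"Latency: ([\d.]+)", line)
            data["latency"] = m.group(1) if m else "N/A"
        elif line.startswith("TFlops:"):
            m = re.search(r"TFlops: ([\d.]+)", line)
            data["best_tflops"] = m.group(1) if m else "N/A"
        elif line.startswith("Reference TFlops:"):
            m = re.search(r"Reference TFlops: ([\d.]+)", line)
            data["ref_tflops"] = m.group(1) if m else "N/A"

    print(data)  # {'latency': '0.123', 'best_tflops': '312.5', 'ref_tflops': '290.1'}

Note that the startswith() checks keep "Reference TFlops:" from being swallowed by the plain "TFlops:" branch, which is why the elif ordering in the script matters.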
maint/scripts/performance.py

@@ -8,19 +8,20 @@ def ref_program(A, B):
def get_configs():
    configs = [
        {
            "block_M": 128,
            "block_N": 128,
            "block_K": 64,
            "num_stages": 2,
            "thread_num": 256,
            "enable_rasteration": True,  # keep param name for backward-compat
        }
    ]
    return configs


def run(M, N, K):
    def kernel(
        block_M=None,
        block_N=None,
...
@@ -34,12 +35,11 @@ def run(M, N, K):
        @T.prim_func
        def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
                A_shared = T.alloc_shared((block_M, block_K), dtype)
                B_shared = T.alloc_shared((block_N, block_K), dtype)
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...
@@ -60,12 +60,16 @@ def run(M, N, K):
        return main

    autotuner = (
        AutoTuner.from_kernel(kernel=kernel, configs=get_configs())
        .set_compile_args(
            out_idx=[-1],
            target="auto",
        )
        .set_profile_args(
            ref_prog=ref_program,
        )
    )
    return autotuner.run(warmup=3, rep=20)
...
pyproject.toml

@@ -122,10 +122,7 @@ tilelang = "tilelang"
"tilelang/3rdparty/composable_kernel/include" = "3rdparty/composable_kernel/include"
"tilelang/3rdparty/composable_kernel/library" = "3rdparty/composable_kernel/library"

-[tool.yapf]
-based_on_style = "yapf"
-column_limit = 100
-indent_width = 4

[tool.codespell]
ignore-words = "docs/spelling_wordlist.txt"
...
@@ -138,7 +135,7 @@ skip = [
[tool.ruff]
target-version = "py39"
-line-length = 100
+line-length = 140
output-format = "full"
exclude = [
...
@@ -146,6 +143,14 @@ exclude = [
    "examples/deepseek_v32/inference",
]

+[tool.ruff.format]
+quote-style = "double"
+indent-style = "space"
+skip-magic-trailing-comma = false
+line-ending = "auto"
+docstring-code-format = false
+docstring-code-line-length = "dynamic"

[tool.ruff.lint.per-file-ignores]
# Do not upgrade type hint in testing and examples.
# See https://github.com/tile-ai/tilelang/issues/1079 for more information.
...
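With the [tool.yapf] section gone and the [tool.ruff.format] section added above, formatting is now driven entirely by ruff. The exact lint-job invocation is not part of this diff; the following is an assumed sketch of how a CI step could verify both formatting and lint rules against this configuration:

    # Assumed workflow sketch (not taken from this commit's CI scripts):
    # run ruff's formatter in check-only mode and the linter, using the
    # [tool.ruff] / [tool.ruff.format] settings in pyproject.toml.
    import subprocess
    import sys

    fmt = subprocess.run(["ruff", "format", "--check", "."], capture_output=True, text=True)
    lint = subprocess.run(["ruff", "check", "."], capture_output=True, text=True)
    print(fmt.stdout or fmt.stderr)
    print(lint.stdout or lint.stderr)
    sys.exit(fmt.returncode or lint.returncode)

Locally, dropping the --check flag ("ruff format .") rewrites files in place, which is the replacement for the previous yapf invocation.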
requirements-lint.txt

@@ -4,4 +4,3 @@ clang-format==21.1.2
clang-tidy==21.1.1
codespell[toml]==2.4.1
ruff==0.14.3
-yapf==0.43.0
testing/conftest.py

@@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
        "warnings",
        "error",
    }
    if sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) == 0:
        terminalreporter.write_sep(
            "!",
            (f"Error: No tests were collected. {dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"),
        )
        pytest.exit("No tests were collected.", returncode=5)
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py

@@ -4,7 +4,8 @@ from tilelang import tvm as tvm
import tilelang.language as T
from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mfma_macro_generator import (
    MatrixCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func

tilelang.testing.set_random_seed(0)
...
@@ -22,7 +23,6 @@ def tl_matmul(
    b_transposed=True,
    k_pack=1,
):
    micro_size_x = micro_size_y = micro_size_k = 16

    if in_dtype in {"float8_e4m3fnuz", "int8"}:
...
@@ -78,12 +78,11 @@ def tl_matmul(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
...
@@ -91,10 +90,12 @@ def tl_matmul(
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout(
                {
                    A_shared: make_swizzle_layout(A_shared),
                    B_shared: make_swizzle_layout(B_shared),
                }
            )

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)
...
@@ -102,7 +103,6 @@ def tl_matmul(
            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=0):
                # Load A into shared memory
                if a_transposed:
                    T.copy(A[ko * block_K, by * block_M], A_shared)
...
@@ -116,7 +116,6 @@ def tl_matmul(
                    T.copy(B[ko * block_K, bx * block_N], B_shared)

                for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):
                    # Load A into fragment
                    mfma_emitter.ldmatrix_a(
                        A_local,
...
@@ -160,17 +159,8 @@ def tl_matmul(
    return main


def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32", a_transposed=False, b_transposed=True, k_pack=1):
    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack)
    print(matmul)
    kernel = tilelang.compile(matmul)
    src_code = kernel.get_kernel_source()
...
@@ -201,16 +191,13 @@ def assert_tl_matmul_correctness(M,
    if a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    elif a_transposed and not b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.Tto(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
    elif not a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    else:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
...
@@ -228,16 +215,13 @@ def test_assert_tl_matmul():
    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2)

    assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3fnuz", "float16")
    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32")
    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", k_pack=2)
    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False)
    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False, k_pack=2)


if __name__ == "__main__":
...
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py

@@ -23,7 +23,6 @@ def tl_matmul(
    b_preshuffle=False,
    b_g2l_load=False,
):
    micro_size_x = micro_size_y = micro_size_k = 16

    if in_dtype in {"float8_e4m3fnuz", "int8"}:
...
@@ -53,18 +52,21 @@ def tl_matmul(
    A_shape = (K, M) if a_transposed else (M, K)
    if b_preshuffle:
        B_shape = (
            (N // micro_size_y, K // pack_size_k, micro_size_y, pack_size_k)
            if b_transposed
            else (K // pack_size_k, N // micro_size_y, pack_size_k, micro_size_y)
        )
    else:
        B_shape = (N, K) if b_transposed else (K, N)

    A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
    if b_preshuffle:
        B_shared_shape = (
            (block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k)
            if b_transposed
            else (block_K // pack_size_k, block_N // micro_size_y, pack_size_k, micro_size_y)
        )
    else:
        B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
...
@@ -94,21 +96,22 @@ def tl_matmul(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout(
                {
                    A_shared: make_swizzle_layout(A_shared),
                }
            )

            num_ko = K // block_K
            num_ki = block_K // (k_pack * micro_size_k)
...
@@ -119,7 +122,6 @@ def tl_matmul(
            T.clear(C_local)

            for ko in T.Pipelined(num_ko, num_stages=0):
                # Load A into shared memory
                if a_transposed:
                    T.copy(A[ko * block_K, by * block_M], A_shared)
...
@@ -129,20 +131,13 @@ def tl_matmul(
                # Load B into shared memory
                if b_g2l_load is False:
                    if b_transposed:
                        for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k):
                            B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j, ko * block_K // pack_size_k + k, jj, kk]
                    else:
                        for k, j, kk, jj in T.Parallel(block_K // pack_size_k, block_N // micro_size_y, pack_size_k, micro_size_y):
                            B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k, bx * block_N // micro_size_y + j, kk, jj]

                for ki in T.serial(0, num_ki):
                    # Load A S2L
                    mfma_emitter.ldmatrix_a(
                        A_local,
...
@@ -176,10 +171,10 @@ def tl_matmul(
def shuffle_weight(
    x: torch.Tensor,
    layout=(16, 32),
    k_pack=1,
    is_transpose=False,
) -> torch.Tensor:
    IN, IK = layout
    BK = IK * k_pack
...
@@ -194,19 +189,20 @@ def shuffle_weight(
    return x.contiguous()


def assert_tl_matmul_correctness(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype="float32",
    a_transposed=False,
    b_transposed=True,
    k_pack=1,
    b_preshuffle=False,
    b_g2l_load=False,
):
    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack, b_preshuffle, b_g2l_load)
    print(matmul)
    kernel = tilelang.compile(matmul)
    src_code = kernel.get_kernel_source()
...
@@ -244,16 +240,13 @@ def assert_tl_matmul_correctness(M,
    if a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    elif a_transposed and not b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.Tto(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
    elif not a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    else:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
...
@@ -266,40 +259,17 @@ def assert_tl_matmul_correctness(M,
@tilelang.testing.requires_rocm
def test_assert_tl_matmul():
    assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2, b_preshuffle=True)

    assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_transposed=False, b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_transposed=False, b_preshuffle=True)


if __name__ == "__main__":
...
testing/python/amd/test_tilelang_test_amd.py

@@ -27,8 +27,7 @@ def matmul(
    vec_size = 4 * k_pack

    @T.prim_func
    def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
...
@@ -111,8 +110,7 @@ def test_gemm_bf16f32f32_nt():
    run_gemm(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
    run_gemm(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
    run_gemm(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)
    run_gemm(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32, k_pack=2)


@tilelang.testing.requires_rocm
...
@@ -121,8 +119,7 @@ def test_gemm_bf16bf16f32():
    run_gemm(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
    run_gemm(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
    run_gemm(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
    run_gemm(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32, k_pack=2)


def matmul_rs(
...
@@ -149,9 +146,9 @@ def matmul_rs(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...
testing/python/analysis/test_tilelang_fragment_loop_checker.py

@@ -5,14 +5,12 @@ import pytest
@tilelang.jit
def simple_invalid_loop(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
    A = T.dynamic("A")

    @T.prim_func
    def main(
        data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)
...
@@ -28,14 +26,12 @@ def simple_invalid_loop(dtype: str = "bfloat16",
@tilelang.jit
def nested_invalid_loop(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
    A = T.dynamic("A")

    @T.prim_func
    def main(
        data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)
...
@@ -52,14 +48,12 @@ def nested_invalid_loop(dtype: str = "bfloat16",
@tilelang.jit
def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
    A = T.dynamic("A")

    @T.prim_func
    def main(
        data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)
...
@@ -75,14 +69,12 @@ def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16",
@tilelang.jit
def valid_loop_not_use_loop_var(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
    A = T.dynamic("A")

    @T.prim_func
    def main(
        data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)
...
@@ -99,14 +91,12 @@ def valid_loop_not_use_loop_var(dtype: str = "bfloat16",
@tilelang.jit
def valid_loop_not_frag(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
    A = T.dynamic("A")

    @T.prim_func
    def main(
        data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_shared = T.alloc_shared([128], accum_dtype)
...
@@ -122,14 +112,12 @@ def valid_loop_not_frag(dtype: str = "bfloat16",
@tilelang.jit
def valid_loop_serial(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
    A = T.dynamic("A")

    @T.prim_func
    def main(
        data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_shared = T.alloc_shared([128], accum_dtype)
...
testing/python/analysis/test_tilelang_nested_loop_checker.py

@@ -30,11 +30,10 @@ Rule:
@tilelang.jit(out_idx=[1])
def nested_continuous_parallels(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
...
@@ -46,29 +45,26 @@ def nested_continuous_parallels(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_triple_continuous_parallels(length=256, block1=8, block2=2, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block1 // block2):
                for j in T.Parallel(block1):
                    for k in T.Parallel(block2):
                        B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0

    return main


@tilelang.jit(out_idx=[1])
def nested_noncontinuous_parallels(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
...
@@ -103,8 +99,9 @@ is OK.
"""


def matmul_nested_pipelines(
    M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, threads, order, stage, extra_pipeline_repeats
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
...
@@ -114,9 +111,9 @@ def matmul_nested_pipelines(M, N, K, block_M, block_N, block_K, trans_A, trans_B
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...
@@ -180,7 +177,8 @@ def run_gemm_nested_pipelines(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = kernel.get_profiler()

    def ref_program(A, B):
...
@@ -193,8 +191,8 @@ def run_gemm_nested_pipelines(
        if in_dtype == "float32":
            # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
            # float32 automatically, -0x1000 meas
            A = (A.view(torch.int32) - 0x1000).view(torch.float32)
            B = (B.view(torch.int32) - 0x1000).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C
...
@@ -218,11 +216,10 @@ is OK.
@tilelang.jit(out_idx=[1])
def nested_continuous_serials(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.serial(length // block):
...
@@ -234,11 +231,10 @@ def nested_continuous_serials(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_noncontinuous_serials(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.serial(length // block):
...
@@ -277,11 +273,10 @@ Rule:
@tilelang.jit(out_idx=[1])
def nested_continuous_sp(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.serial(length // block):
...
@@ -293,11 +288,10 @@ def nested_continuous_sp(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_continuous_ps(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
...
@@ -309,36 +303,32 @@ def nested_continuous_ps(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_continuous_psp(length=256, block1=8, block2=2, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block1 // block2):
                for j in T.serial(block1):
                    for k in T.Parallel(block2):
                        B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0

    return main


@tilelang.jit(out_idx=[1])
def nested_continuous_sps(length=256, block1=8, block2=2, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.serial(length // block1 // block2):
                for j in T.Parallel(block1):
                    for k in T.serial(block2):
                        B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0

    return main
...
@@ -399,9 +389,9 @@ def matmul_nested_pipa(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...
@@ -444,9 +434,9 @@ def matmul_nested_papipa(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...
@@ -505,7 +495,8 @@ def run_gemm_mixed_pp(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = kernel.get_profiler()

    def ref_program(A, B):
...
@@ -514,8 +505,8 @@ def run_gemm_mixed_pp(
        if in_dtype == "float32":
            # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
            # float32 automatically, -0x1000 meas
            A = (A.view(torch.int32) - 0x1000).view(torch.float32)
            B = (B.view(torch.int32) - 0x1000).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C
...
@@ -543,7 +534,8 @@ def run_gemm_mixed_pp(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )


def test_mixed_pp():
...
@@ -576,9 +568,9 @@ def matmul_with_parallel(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...
@@ -637,7 +629,8 @@ def run_gemm_tiled_op_with_parallel(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = kernel.get_profiler()

    def ref_program(A, B):
...
@@ -646,8 +639,8 @@ def run_gemm_tiled_op_with_parallel(
        if in_dtype == "float32":
            # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
            # float32 automatically, -0x1000 meas
            A = (A.view(torch.int32) - 0x1000).view(torch.float32)
            B = (B.view(torch.int32) - 0x1000).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C
...
@@ -675,16 +668,16 @@ def run_gemm_tiled_op_with_parallel(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )


@tilelang.jit(out_idx=[1])
def tir_op_with_parallel(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
...
@@ -696,11 +689,10 @@ def tir_op_with_parallel(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def customize_op_with_parallel(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
...
testing/python/autotune/test_tilelang_autotune.py

@@ -48,6 +48,7 @@ def get_configs(M, N, K, with_roller=False):
        from tilelang.carver.template import MatmulTemplate
        from tilelang.carver.arch import CUDA
        from tilelang.carver.roller.rasterization import NoRasterization

        arch = CUDA("cuda")
        topk = 20
...
@@ -84,7 +85,6 @@ def get_configs(M, N, K, with_roller=False):
        for config in configs:
            print(config)
    else:
        block_M = [64]
        block_N = [64]
        block_K = [32]
...
@@ -100,7 +100,8 @@ def get_configs(M, N, K, with_roller=False):
                    num_stages,
                    thread_num,
                    enable_rasterization,
                )
            )

        configs = [
            {
...
@@ -110,7 +111,8 @@ def get_configs(M, N, K, with_roller=False):
                "num_stages": c[3],
                "thread_num": c[4],
                "enable_rasteration": c[5],  # keep param name for backward-compat
            }
            for c in _configs
        ]
    return configs
...
@@ -190,9 +192,9 @@ def matmul(M, N, K, with_roller):
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        """
        The compiled TVM function for block-level matrix multiplication.
...
@@ -206,9 +208,7 @@ def matmul(M, N, K, with_roller):
        """
        # Bind x-dimension to block index in N,
        # y-dimension to block index in M.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            # Allocate shared memory for A sub-block of shape (block_M, block_K)
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            # Allocate shared memory for B sub-block of shape (block_N, block_K)
...
@@ -247,12 +247,16 @@ def matmul(M, N, K, with_roller):
        return main

    autotuner = (
        AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller))
        .set_compile_args(
            out_idx=[-1],
            target="auto",
        )
        .set_profile_args(
            ref_prog=ref_program,
        )
    )
    return autotuner.run(warmup=3, rep=20)
...
testing/python/autotune/test_tilelang_autotune_with_inputs.py

@@ -30,38 +30,23 @@ def ref_program(A, B):
def get_configs():
    iter_params = dict(block_M=[64], block_N=[64], block_K=[32], num_stages=[0, 1], thread_num=[128], enable_rasterization=[False])
    return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]


@tilelang.autotune(
    configs=get_configs(),
)
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M=128, block_N=128, block_K=32, num_stages=0, thread_num=128, enable_rasterization=False):
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        """
        The compiled TVM function for block-level matrix multiplication.
...
@@ -76,7 +61,6 @@ def matmul(M,
        # Bind x-dimension to block index in N,
        # y-dimension to block index in M.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            # Allocate shared memory for A sub-block of shape (block_M, block_K)
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            # Allocate shared memory for B sub-block of shape (block_N, block_K)
...
testing/python/cache/test_tilelang_cache_matmul.py

@@ -28,9 +28,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
...
@@ -63,6 +63,7 @@ def run_cache_matmul():
        Reference PyTorch matrix multiplication for comparison.
        """
        import torch

        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.half)  # Assuming dtype="float16" in matmul
        return C
...
testing/python/carver/test_tilelang_carver_cuda_driver_properties.py

@@ -29,9 +29,7 @@ class _cudaDeviceAttrNames:
def test_driver_get_device_properties():
    prop = get_cuda_device_properties()
    assert prop is not None, "Failed to get CUDA device properties"
    assert isinstance(prop, torch.cuda._CudaDeviceProperties), "Returned object is not of type _CudaDeviceProperties"


def test_device_get_device_name():
...
@@ -48,8 +46,7 @@ def test_device_get_shared_memory_per_block():
def test_device_get_persisting_l2_cache_size():
    tl_cache_size = get_persisting_l2_cache_max_size()
    driver_cache_size = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize)
    assert tl_cache_size == driver_cache_size, "Persisting L2 cache size values do not match"
...
@@ -61,17 +58,14 @@ def test_device_get_num_sms():
def test_device_get_registers_per_block():
    tl_regs_per_block = get_registers_per_block()
    driver_regs_per_block = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxRegistersPerBlock)
    assert tl_regs_per_block == driver_regs_per_block, "Registers per block values do not match"


def test_device_get_max_dynamic_shared_size_bytes():
    tl_dynamic_smem = get_max_dynamic_shared_size_bytes()
    driver_dynamic_smem = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor)
    assert tl_dynamic_smem == driver_dynamic_smem, "Max dynamic shared size bytes values do not match"


if __name__ == "__main__":
...
testing/python/carver/test_tilelang_carver_generate_hints.py

@@ -9,16 +9,13 @@ def run_general_matmul_emit_configs(M, N, K, topk: int = 20):
    arch = auto_infer_current_arch()

    def gemm(M, N, K):
        A = te.placeholder((M, K), name="A", dtype="float16")
        B = te.placeholder((N, K), name="B", dtype="float16")
        # Describe the matrix multiplication in TE
        k = te.reduce_axis((0, K), name="k")
        C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype("float16") * B[j, k].astype("float16"), axis=[k]), name="C")
        return A, B, C
...
@@ -29,8 +26,7 @@ def run_general_matmul_emit_configs(M, N, K, topk: int = 20):
    tensorized_func, tags = carver.utils.get_tensorized_func_and_tags(func, arch.target)
    print(tags)
    policy = carver.TensorCorePolicy.from_prim_func(func=tensorized_func, arch=arch, tags=tags, name="matmul_0")
    hints = policy.emit_config(topk=topk)
...
@@ -59,16 +55,13 @@ def run_general_matmul_matmul_emit_configs(M, N, K, topk: int = 20):
    arch = auto_infer_current_arch()

    def gemm(M, N, K):
        A = te.placeholder((M, K), name="A", dtype="float16")
        B = te.placeholder((N, K), name="B", dtype="float16")
        # Describe the matrix multiplication in TE
        k = te.reduce_axis((0, K), name="k")
        C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype("float16") * B[j, k].astype("float16"), axis=[k]), name="C")
        return A, B, C
...