OpenDAS / tilelang / Commits / 29051439

Unverified commit 29051439, authored Dec 12, 2025 by Lei Wang, committed by GitHub on Dec 12, 2025.

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

Parent: e84b24bc

Changes: 426+ files in the commit; this page shows 20 changed files with 279 additions and 398 deletions (+279 / -398).
Changed files:

maint/host_checks/08_device_id_mismatch.py                               +2   -2
maint/host_checks/09_null_data_pointer.py                                +1   -0
maint/host_checks/10_scalar_type_mismatch.py                             +2   -2
maint/host_checks/common.py                                              +4   -13
maint/precision/compare_ops.py                                           +27  -43
maint/scripts/ci_performance.py                                          +18  -25
maint/scripts/performance.py                                             +22  -18
pyproject.toml                                                           +10  -5
requirements-lint.txt                                                    +0   -1
testing/conftest.py                                                      +2   -5
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py                  +19  -35
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py                 +52  -82
testing/python/amd/test_tilelang_test_amd.py                             +6   -9
testing/python/analysis/test_tilelang_fragment_loop_checker.py           +12  -24
testing/python/analysis/test_tilelang_nested_loop_checker.py             +56  -64
testing/python/autotune/test_tilelang_autotune.py                        +17  -13
testing/python/autotune/test_tilelang_autotune_with_inputs.py            +11  -27
testing/python/cache/test_tilelang_cache_matmul.py                       +4   -3
testing/python/carver/test_tilelang_carver_cuda_driver_properties.py     +5   -11
testing/python/carver/test_tilelang_carver_generate_hints.py             +9   -16
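Most of the diff below (and of the 426+ touched files) is mechanical re-formatting of this kind; the following is a hypothetical before/after sketch, not taken verbatim from any file in the commit, assuming the line-length = 140 and quote-style = "double" settings that this commit adds to pyproject.toml:

    # Before (yapf, column_limit = 100): single-quoted strings, arguments wrapped one per line
    parser.add_argument('--low',
                        type=float,
                        default=-4.0,
                        help='Lower bound for random values')

    # After (ruff format, line-length = 140): quotes normalized to double, call collapsed onto one line
    parser.add_argument("--low", type=float, default=-4.0, help="Lower bound for random values")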
maint/host_checks/08_device_id_mismatch.py

"""Reproduce: device_id mismatch (requires >=2 CUDA devices).
"""
import torch
from common import build_matmul_kernel
...
maint/host_checks/09_null_data_pointer.py

...
@@ -7,6 +7,7 @@ or a host-side non-NULL pointer check.
Note: Constructing a true DLTensor with NULL data in PyTorch is not typical; this script
demonstrates passing None, which still reproduces the intended class of failure.
"""
import torch
from common import build_matmul_kernel
...
maint/host_checks/10_scalar_type_mismatch.py

"""Reproduce: scalar parameter type mismatch (int/bool).
"""
from common import build_scalar_check_kernel
...
maint/host_checks/common.py

...
@@ -3,20 +3,12 @@ import tilelang.language as T
import torch


def make_matmul_prim(M, N, K, block_M=128, block_N=128, block_K=32, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
...
@@ -42,7 +34,6 @@ def build_matmul_kernel(M=1024, N=1024, K=1024, target="cuda"):


def build_scalar_check_kernel(target="cuda"):
    @T.prim_func
    def scalar_check(x: T.int32, flag: T.bool()):
        T.evaluate(0)
...
maint/precision/compare_ops.py

...
@@ -37,7 +37,7 @@ OP_NAMES: Dict[int, str] = {
    6: "sqrt",
    7: "tanh",
    8: "rsqrt",
-   9: "inv_sqrt"
+   9: "inv_sqrt",
}

# Block sizes for kernels
...
@@ -49,8 +49,7 @@ TILELANG_THREADS = 128
def parse_arguments() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Precision comparison tool for various CUDA implementations")
    parser.add_argument("--n", type=int, default=1000000, help="Number of elements to test")
    parser.add_argument("--low", type=float, default=-4.0, help="Lower bound for random values")
    parser.add_argument("--high", type=float, default=4.0, help="Upper bound for random values")
...
@@ -67,7 +66,7 @@ def initialize_cuda() -> torch.nn.Module:
    return load(
        name="cuda_ops",
        sources=["cuda_ops.cu"],
-       extra_cuda_cflags=[]  # No fast_math flags
+       extra_cuda_cflags=[],  # No fast_math flags
    )
...
@@ -149,8 +148,7 @@ def triton_unary_kernel(x_ptr, out_ptr, n_elements, op_id: tl.constexpr, BLOCK_S
@triton.jit
def triton_libdevice_unary_kernel(x_ptr, out_ptr, n_elements, op_id: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    """LibDevice Triton kernel for unary operations."""
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
...
@@ -188,13 +186,10 @@ def make_tilelang_unary_kernel(M: int, N: int, op_id: int, use_fastmath: bool =
    @T.prim_func
    def tilelang_unary_kernel(
        A: T.Tensor((M, N), "float32"),
        B: T.Tensor((M, N), "float32"),
    ):
        with T.Kernel(T.ceildiv(N, TILELANG_BLOCK_N), T.ceildiv(M, TILELANG_BLOCK_M), threads=TILELANG_THREADS) as (bx, by):
            for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
                row = by * TILELANG_BLOCK_M + i
                col = bx * TILELANG_BLOCK_N + j
...
@@ -229,14 +224,11 @@ def make_tilelang_binary_kernel(M: int, N: int):
    @T.prim_func
    def tilelang_binary_kernel(
        A: T.Tensor((M, N), "float32"),
        B: T.Tensor((M, N), "float32"),
        C: T.Tensor((M, N), "float32"),
    ):
        with T.Kernel(T.ceildiv(N, TILELANG_BLOCK_N), T.ceildiv(M, TILELANG_BLOCK_M), threads=TILELANG_THREADS) as (bx, by):
            for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
                row = by * TILELANG_BLOCK_M + i
                col = bx * TILELANG_BLOCK_N + j
...
@@ -247,10 +239,7 @@ def make_tilelang_binary_kernel(M: int, N: int):
    return tilelang_binary_kernel


def tilelang_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None, use_fastmath: bool = False) -> torch.Tensor:
    """TileLang operation interface."""
    assert x.is_cuda
...
@@ -272,7 +261,8 @@ def tilelang_op(x: torch.Tensor,
            target="cuda",
            pass_configs={
                tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: use_fastmath,
-           })
+           },
+       )
        out = kernel(x, y)
    else:  # Unary operation
        kernel_func = make_tilelang_unary_kernel(M, N, op_id, use_fastmath)
...
@@ -282,7 +272,8 @@ def tilelang_op(x: torch.Tensor,
            target="cuda",
            pass_configs={
                tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: use_fastmath,
-           })
+           },
+       )
        out = kernel(x)
    # Restore original shape
...
@@ -293,7 +284,7 @@ def triton_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) ->
    """Standard Triton operation interface."""
    assert x.is_cuda
    out = torch.empty_like(x)
-   grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
+   grid = lambda meta: ((x.numel() + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],)
    if op_id == 0:  # Division - binary operation
        assert y is not None, "Division operation requires second operand"
...
@@ -304,13 +295,11 @@ def triton_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) ->
    return out


def triton_libdevice_op(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) -> torch.Tensor:
    """LibDevice Triton operation interface."""
    assert x.is_cuda
    out = torch.empty_like(x)
-   grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
+   grid = lambda meta: ((x.numel() + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],)
    if op_id == 0:  # Division - binary operation
        assert y is not None, "Division operation requires second operand"
...
@@ -321,9 +310,7 @@ def triton_libdevice_op(x: torch.Tensor,
    return out


def get_pytorch_reference(x: torch.Tensor, op_id: int, y: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Get PyTorch reference implementation for the given operation."""
    if op_id == 0:
        assert y is not None, "Division requires second operand"
...
@@ -362,8 +349,10 @@ def summarize_error(tag: str, output: Optional[torch.Tensor], reference: torch.T
    abs_err = (output_double - reference_double).abs()
    rel_err = abs_err / (reference_double.abs().clamp_min(1e-30))
    print(
        f"{tag:<32} max abs: {abs_err.max():.3e}, mean abs: {abs_err.mean():.3e}, "
        f"max rel: {rel_err.max():.3e}, mean rel: {rel_err.mean():.3e}"
    )


# Precision comparison function
...
@@ -407,9 +396,7 @@ def compare(op_id: int, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> No
            results[name] = None

    # Print comparison header
    print(f"{'Implementation':<32}{'Max Abs Error':<19}{'Mean Abs Error':<20}{'Max Rel Error':<19}{'Mean Rel Error'}")
    print("-" * 90)

    # Compare all implementations against double precision reference
...
@@ -427,8 +414,7 @@ def compare(op_id: int, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> No
        summarize_error(tag, output, ref_double)


def generate_test_data(op_id: int, n: int, device: torch.device, low: float, high: float) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Generate appropriate test data for each operation."""
    if op_id == 0:  # Division
        x = torch.empty(n, device=device).uniform_(low, high)
...
@@ -450,9 +436,7 @@ def generate_test_data(op_id: int, n: int, device: torch.device, low: float,
def main() -> None:
    """Main execution function."""
    print("Precision comparison between CUDA Precise/Fast, Triton, Triton LibDevice, PyTorch, and TileLang")
    print("=" * 90)

    for op_id in range(len(OP_NAMES)):
...
maint/scripts/ci_performance.py

...
@@ -10,39 +10,32 @@ env["TILELANG_CLEAR_CACHE"] = "1"
def parse_output(output):
    data = {}
-   for line in output.split('\n'):
+   for line in output.split("\n"):
        line = line.strip()
-       if line.startswith('Latency: '):
-           match = re.search(r'Latency: ([\d.]+)', line)
-           data['latency'] = match.group(1) if match else 'N/A'
-       elif line.startswith('TFlops: '):
-           match = re.search(r'TFlops: ([\d.]+)', line)
-           data['best_tflops'] = match.group(1) if match else 'N/A'
-       elif line.startswith('Config: '):
-           data['config'] = line.split('Config: ')[-1]
-       elif line.startswith('Reference TFlops: '):
-           match = re.search(r'Reference TFlops: ([\d.]+)', line)
-           data['ref_tflops'] = match.group(1) if match else 'N/A'
+       if line.startswith("Latency: "):
+           match = re.search(r"Latency: ([\d.]+)", line)
+           data["latency"] = match.group(1) if match else "N/A"
+       elif line.startswith("TFlops: "):
+           match = re.search(r"TFlops: ([\d.]+)", line)
+           data["best_tflops"] = match.group(1) if match else "N/A"
+       elif line.startswith("Config: "):
+           data["config"] = line.split("Config: ")[-1]
+       elif line.startswith("Reference TFlops: "):
+           match = re.search(r"Reference TFlops: ([\d.]+)", line)
+           data["ref_tflops"] = match.group(1) if match else "N/A"
    return data


-output_v1 = subprocess.run(['./tl/bin/python', './maint/scripts/performance.py'], capture_output=True, text=True, env=env).stdout
+output_v1 = subprocess.run(["./tl/bin/python", "./maint/scripts/performance.py"], capture_output=True, text=True, env=env).stdout
data_v1 = parse_output(output_v1)

-output_v2 = subprocess.run(['./tll/bin/python', './maint/scripts/performance.py'], capture_output=True, text=True, env=env).stdout
+output_v2 = subprocess.run(["./tll/bin/python", "./maint/scripts/performance.py"], capture_output=True, text=True, env=env).stdout
data_v2 = parse_output(output_v2)

-table = [[
-    "original", data_v1['latency'], data_v1['best_tflops'], data_v1['ref_tflops'], data_v1['config']
-], [
-    "current", data_v2['latency'], data_v2['best_tflops'], data_v2['ref_tflops'], data_v2['config']
-]]
+table = [
+    ["original", data_v1["latency"], data_v1["best_tflops"], data_v1["ref_tflops"], data_v1["config"]],
+    ["current", data_v2["latency"], data_v2["best_tflops"], data_v2["ref_tflops"], data_v2["config"]],
+]

headers = ["version", "Best Latency (s)", "Best TFlops", "Reference TFlops", "Best Config"]
...
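As a side note on what parse_output extracts, here is a minimal sketch with a hypothetical input string (the real input is the stdout of maint/scripts/performance.py); it is runnable if pasted below the parse_output definition above:

    sample = """
    Latency: 0.123
    TFlops: 312.5
    Config: {'block_M': 128, 'block_N': 128}
    Reference TFlops: 290.1
    """
    # Each line is stripped and matched by prefix, so the indentation above is harmless.
    print(parse_output(sample))
    # {'latency': '0.123', 'best_tflops': '312.5', 'config': "{'block_M': 128, 'block_N': 128}", 'ref_tflops': '290.1'}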
maint/scripts/performance.py

...
@@ -8,19 +8,20 @@ def ref_program(A, B):
def get_configs():
-   configs = [{
-       "block_M": 128,
-       "block_N": 128,
-       "block_K": 64,
-       "num_stages": 2,
-       "thread_num": 256,
-       "enable_rasteration": True,  # keep param name for backward-compat
-   }]
+   configs = [
+       {
+           "block_M": 128,
+           "block_N": 128,
+           "block_K": 64,
+           "num_stages": 2,
+           "thread_num": 256,
+           "enable_rasteration": True,  # keep param name for backward-compat
+       }
+   ]
    return configs


def run(M, N, K):
    def kernel(
        block_M=None,
        block_N=None,
...
@@ -34,12 +35,11 @@ def run(M, N, K):
        @T.prim_func
        def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
                A_shared = T.alloc_shared((block_M, block_K), dtype)
                B_shared = T.alloc_shared((block_N, block_K), dtype)
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...
@@ -60,12 +60,16 @@ def run(M, N, K):
        return main

-   autotuner = AutoTuner.from_kernel(
-       kernel=kernel, configs=get_configs()).set_compile_args(
-           out_idx=[-1],
-           target="auto",
-       ).set_profile_args(
-           ref_prog=ref_program,)
+   autotuner = (
+       AutoTuner.from_kernel(kernel=kernel, configs=get_configs())
+       .set_compile_args(
+           out_idx=[-1],
+           target="auto",
+       )
+       .set_profile_args(
+           ref_prog=ref_program,
+       )
+   )
    return autotuner.run(warmup=3, rep=20)
...
pyproject.toml

...
@@ -122,10 +122,7 @@ tilelang = "tilelang"
"tilelang/3rdparty/composable_kernel/include" = "3rdparty/composable_kernel/include"
"tilelang/3rdparty/composable_kernel/library" = "3rdparty/composable_kernel/library"

-[tool.yapf]
-based_on_style = "yapf"
-column_limit = 100
-indent_width = 4
-
[tool.codespell]
ignore-words = "docs/spelling_wordlist.txt"
...
@@ -138,7 +135,7 @@ skip = [
[tool.ruff]
target-version = "py39"
-line-length = 100
+line-length = 140
output-format = "full"
exclude = [
...
@@ -146,6 +143,14 @@ exclude = [
    "examples/deepseek_v32/inference",
]

+[tool.ruff.format]
+quote-style = "double"
+indent-style = "space"
+skip-magic-trailing-comma = false
+line-ending = "auto"
+docstring-code-format = false
+docstring-code-line-length = "dynamic"
+
[tool.ruff.lint.per-file-ignores]
# Do not upgrade type hint in testing and examples.
# See https://github.com/tile-ai/tilelang/issues/1079 for more information.
...
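One setting worth calling out is skip-magic-trailing-comma = false: it makes ruff format treat a trailing comma in the source as "magic", keeping the surrounding brackets exploded one element per line even when the code would fit under line-length = 140. A minimal sketch of the effect, using a hypothetical call that mirrors the pass_configs pattern appearing in the diffs above:

    # The trailing commas after the dict entry and after "}," keep this call exploded:
    kernel = tilelang.compile(
        func,
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
        },
    )

    # Without the trailing commas, ruff format would collapse the same call onto one line:
    kernel = tilelang.compile(func, target="cuda", pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})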
requirements-lint.txt

...
@@ -4,4 +4,3 @@ clang-format==21.1.2
clang-tidy==21.1.1
codespell[toml]==2.4.1
ruff==0.14.3
-yapf==0.43.0
testing/conftest.py

...
@@ -33,12 +33,9 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
        "warnings",
        "error",
    }
-   if (sum(
-           len(terminalreporter.stats.get(k, []))
-           for k in known_types.difference({"skipped", "deselected"})) == 0):
+   if sum(len(terminalreporter.stats.get(k, [])) for k in known_types.difference({"skipped", "deselected"})) == 0:
        terminalreporter.write_sep(
            "!",
-           (f"Error: No tests were collected. "
-            f"{dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"),
+           (f"Error: No tests were collected. {dict(sorted((k, len(v)) for k, v in terminalreporter.stats.items()))}"),
        )
        pytest.exit("No tests were collected.", returncode=5)
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py

...
@@ -4,7 +4,8 @@ from tilelang import tvm as tvm
import tilelang.language as T
from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mfma_macro_generator import (
-   MatrixCoreIntrinEmitter,)
+   MatrixCoreIntrinEmitter,
+)
from tilelang.transform import simplify_prim_func

tilelang.testing.set_random_seed(0)
...
@@ -22,7 +23,6 @@ def tl_matmul(
    b_transposed=True,
    k_pack=1,
):
    micro_size_x = micro_size_y = micro_size_k = 16

    if in_dtype in {"float8_e4m3fnuz", "int8"}:
...
@@ -78,12 +78,11 @@ def tl_matmul(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
...
@@ -91,10 +90,12 @@ def tl_matmul(
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

-           T.annotate_layout({
-               A_shared: make_swizzle_layout(A_shared),
-               B_shared: make_swizzle_layout(B_shared),
-           })
+           T.annotate_layout(
+               {
+                   A_shared: make_swizzle_layout(A_shared),
+                   B_shared: make_swizzle_layout(B_shared),
+               }
+           )

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)
...
@@ -102,7 +103,6 @@ def tl_matmul(
            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=0):
                # Load A into shared memory
                if a_transposed:
                    T.copy(A[ko * block_K, by * block_M], A_shared)
...
@@ -116,7 +116,6 @@ def tl_matmul(
                    T.copy(B[ko * block_K, bx * block_N], B_shared)

                for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):
                    # Load A into fragment
                    mfma_emitter.ldmatrix_a(
                        A_local,
...
@@ -160,17 +159,8 @@ def tl_matmul(
    return main


def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32", a_transposed=False, b_transposed=True, k_pack=1):
    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack)
    print(matmul)
    kernel = tilelang.compile(matmul)
    src_code = kernel.get_kernel_source()
...
@@ -201,16 +191,13 @@ def assert_tl_matmul_correctness(M,
    if a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    elif a_transposed and not b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.Tto(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
    elif not a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    else:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
...
@@ -228,16 +215,13 @@ def test_assert_tl_matmul():
    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2)
    assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3fnuz", "float16")
    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32")
    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", k_pack=2)
    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False)
    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False, k_pack=2)


if __name__ == "__main__":
...
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py

...
@@ -23,7 +23,6 @@ def tl_matmul(
    b_preshuffle=False,
    b_g2l_load=False,
):
    micro_size_x = micro_size_y = micro_size_k = 16

    if in_dtype in {"float8_e4m3fnuz", "int8"}:
...
@@ -53,18 +52,21 @@ def tl_matmul(
    A_shape = (K, M) if a_transposed else (M, K)
    if b_preshuffle:
-       B_shape = (N // micro_size_y, K // pack_size_k, micro_size_y,
-                  pack_size_k) if b_transposed else (K // pack_size_k, N // micro_size_y,
-                                                     pack_size_k, micro_size_y)
+       B_shape = (
+           (N // micro_size_y, K // pack_size_k, micro_size_y, pack_size_k)
+           if b_transposed
+           else (K // pack_size_k, N // micro_size_y, pack_size_k, micro_size_y)
+       )
    else:
        B_shape = (N, K) if b_transposed else (K, N)

    A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
    if b_preshuffle:
-       B_shared_shape = (block_N // micro_size_y, block_K // pack_size_k, micro_size_y,
-                         pack_size_k) if b_transposed else (block_K // pack_size_k,
-                                                            block_N // micro_size_y, pack_size_k,
-                                                            micro_size_y)
+       B_shared_shape = (
+           (block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k)
+           if b_transposed
+           else (block_K // pack_size_k, block_N // micro_size_y, pack_size_k, micro_size_y)
+       )
    else:
        B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
...
@@ -94,21 +96,22 @@ def tl_matmul(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

-           T.annotate_layout({
-               A_shared: make_swizzle_layout(A_shared),
-           })
+           T.annotate_layout(
+               {
+                   A_shared: make_swizzle_layout(A_shared),
+               }
+           )

            num_ko = K // block_K
            num_ki = block_K // (k_pack * micro_size_k)
...
@@ -119,7 +122,6 @@ def tl_matmul(
            T.clear(C_local)

            for ko in T.Pipelined(num_ko, num_stages=0):
                # Load A into shared memory
                if a_transposed:
                    T.copy(A[ko * block_K, by * block_M], A_shared)
...
@@ -129,20 +131,13 @@ def tl_matmul(
                # Load B into shared memory
                if b_g2l_load is False:
                    if b_transposed:
-                       for j, k, jj, kk in T.Parallel(block_N // micro_size_y,
-                                                      block_K // pack_size_k, micro_size_y, pack_size_k):
-                           B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j,
-                                                      ko * block_K // pack_size_k + k, jj, kk]
+                       for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // pack_size_k, micro_size_y, pack_size_k):
+                           B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j, ko * block_K // pack_size_k + k, jj, kk]
                    else:
-                       for k, j, kk, jj in T.Parallel(block_K // pack_size_k,
-                                                      block_N // micro_size_y, pack_size_k, micro_size_y):
-                           B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k,
-                                                      bx * block_N // micro_size_y + j, kk, jj]
+                       for k, j, kk, jj in T.Parallel(block_K // pack_size_k, block_N // micro_size_y, pack_size_k, micro_size_y):
+                           B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k, bx * block_N // micro_size_y + j, kk, jj]

                for ki in T.serial(0, num_ki):
                    # Load A S2L
                    mfma_emitter.ldmatrix_a(
                        A_local,
...
@@ -176,10 +171,10 @@ def tl_matmul(
def shuffle_weight(
    x: torch.Tensor,
    layout=(16, 32),
    k_pack=1,
    is_transpose=False,
) -> torch.Tensor:
    IN, IK = layout
    BK = IK * k_pack
...
@@ -194,19 +189,20 @@ def shuffle_weight(
    return x.contiguous()


def assert_tl_matmul_correctness(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype="float32",
    a_transposed=False,
    b_transposed=True,
    k_pack=1,
    b_preshuffle=False,
    b_g2l_load=False,
):
    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack, b_preshuffle, b_g2l_load)
    print(matmul)
    kernel = tilelang.compile(matmul)
    src_code = kernel.get_kernel_source()
...
@@ -244,16 +240,13 @@ def assert_tl_matmul_correctness(M,
    if a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    elif a_transposed and not b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.Tto(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
    elif not a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    else:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
...
@@ -266,40 +259,17 @@ def assert_tl_matmul_correctness(M,
@tilelang.testing.requires_rocm
def test_assert_tl_matmul():
    assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2, b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_transposed=False, b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_preshuffle=True)
    assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_transposed=False, b_preshuffle=True)


if __name__ == "__main__":
...
testing/python/amd/test_tilelang_test_amd.py

...
@@ -27,8 +27,7 @@ def matmul(
    vec_size = 4 * k_pack

    @T.prim_func
-   def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor(
-           (M, N), out_dtype)):
+   def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor((M, N), out_dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
...
@@ -111,8 +110,7 @@ def test_gemm_bf16f32f32_nt():
    run_gemm(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
    run_gemm(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
    run_gemm(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)
    run_gemm(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32, k_pack=2)


@tilelang.testing.requires_rocm
...
@@ -121,8 +119,7 @@ def test_gemm_bf16bf16f32():
    run_gemm(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
    run_gemm(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
    run_gemm(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
    run_gemm(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32, k_pack=2)


def matmul_rs(
...
@@ -149,9 +146,9 @@ def matmul_rs(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...
testing/python/analysis/test_tilelang_fragment_loop_checker.py

...
@@ -5,14 +5,12 @@ import pytest
@tilelang.jit
-def simple_invalid_loop(dtype: str = "bfloat16",
-                        accum_dtype: str = "float32", num_threads: int = 128):
+def simple_invalid_loop(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
    A = T.dynamic("A")

    @T.prim_func
    def main(
        data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)
...
@@ -28,14 +26,12 @@ def simple_invalid_loop(dtype: str = "bfloat16",
@tilelang.jit
-def nested_invalid_loop(dtype: str = "bfloat16",
-                        accum_dtype: str = "float32", num_threads: int = 128):
+def nested_invalid_loop(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
    A = T.dynamic("A")

    @T.prim_func
    def main(
        data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)
...
@@ -52,14 +48,12 @@ def nested_invalid_loop(dtype: str = "bfloat16",
@tilelang.jit
-def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16",
-                                       accum_dtype: str = "float32", num_threads: int = 128):
+def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
    A = T.dynamic("A")

    @T.prim_func
    def main(
        data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)
...
@@ -75,14 +69,12 @@ def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16",
@tilelang.jit
-def valid_loop_not_use_loop_var(dtype: str = "bfloat16",
-                                accum_dtype: str = "float32", num_threads: int = 128):
+def valid_loop_not_use_loop_var(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
    A = T.dynamic("A")

    @T.prim_func
    def main(
        data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)
...
@@ -99,14 +91,12 @@ def valid_loop_not_use_loop_var(dtype: str = "bfloat16",
@tilelang.jit
-def valid_loop_not_frag(dtype: str = "bfloat16",
-                        accum_dtype: str = "float32", num_threads: int = 128):
+def valid_loop_not_frag(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
    A = T.dynamic("A")

    @T.prim_func
    def main(
        data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_shared = T.alloc_shared([128], accum_dtype)
...
@@ -122,14 +112,12 @@ def valid_loop_not_frag(dtype: str = "bfloat16",
@tilelang.jit
-def valid_loop_serial(dtype: str = "bfloat16",
-                      accum_dtype: str = "float32", num_threads: int = 128):
+def valid_loop_serial(dtype: str = "bfloat16", accum_dtype: str = "float32", num_threads: int = 128):
    A = T.dynamic("A")

    @T.prim_func
    def main(
        data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_shared = T.alloc_shared([128], accum_dtype)
...
testing/python/analysis/test_tilelang_nested_loop_checker.py  View file @ 29051439
...
@@ -30,11 +30,10 @@ Rule:
@tilelang.jit(out_idx=[1])
def nested_continuous_parallels(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
...
@@ -46,29 +45,26 @@ def nested_continuous_parallels(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_triple_continuous_parallels(length=256, block1=8, block2=2, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block1 // block2):
                for j in T.Parallel(block1):
                    for k in T.Parallel(block2):
                        B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0
    return main


@tilelang.jit(out_idx=[1])
def nested_noncontinuous_parallels(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
...
@@ -103,8 +99,9 @@ is OK.
"""
-def matmul_nested_pipelines(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype,
-                            out_dtype, accum_dtype, threads, order, stage, extra_pipeline_repeats):
+def matmul_nested_pipelines(
+    M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, threads, order, stage, extra_pipeline_repeats
+):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
...
@@ -114,9 +111,9 @@ def matmul_nested_pipelines(M, N, K, block_M, block_N, block_K, trans_A, trans_B
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...
@@ -180,7 +177,8 @@ def run_gemm_nested_pipelines(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    profiler = kernel.get_profiler()

    def ref_program(A, B):
...
@@ -193,8 +191,8 @@ def run_gemm_nested_pipelines(
        if in_dtype == "float32":
            # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
            # float32 automatically, -0x1000 meas
-            A = ((A.view(torch.int32) - 0x1000)).view(torch.float32)
-            B = ((B.view(torch.int32) - 0x1000)).view(torch.float32)
+            A = (A.view(torch.int32) - 0x1000).view(torch.float32)
+            B = (B.view(torch.int32) - 0x1000).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C
...
@@ -218,11 +216,10 @@ is OK.
@tilelang.jit(out_idx=[1])
def nested_continuous_serials(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.serial(length // block):
...
@@ -234,11 +231,10 @@ def nested_continuous_serials(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_noncontinuous_serials(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.serial(length // block):
...
@@ -277,11 +273,10 @@ Rule:
@tilelang.jit(out_idx=[1])
def nested_continuous_sp(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.serial(length // block):
...
@@ -293,11 +288,10 @@ def nested_continuous_sp(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_continuous_ps(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
...
@@ -309,36 +303,32 @@ def nested_continuous_ps(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_continuous_psp(length=256, block1=8, block2=2, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block1 // block2):
                for j in T.serial(block1):
                    for k in T.Parallel(block2):
                        B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0
    return main


@tilelang.jit(out_idx=[1])
def nested_continuous_sps(length=256, block1=8, block2=2, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.serial(length // block1 // block2):
                for j in T.Parallel(block1):
                    for k in T.serial(block2):
                        B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0
    return main
...
@@ -399,9 +389,9 @@ def matmul_nested_pipa(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...
@@ -444,9 +434,9 @@ def matmul_nested_papipa(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...
@@ -505,7 +495,8 @@ def run_gemm_mixed_pp(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    profiler = kernel.get_profiler()

    def ref_program(A, B):
...
@@ -514,8 +505,8 @@ def run_gemm_mixed_pp(
        if in_dtype == "float32":
            # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
            # float32 automatically, -0x1000 meas
-            A = ((A.view(torch.int32) - 0x1000)).view(torch.float32)
-            B = ((B.view(torch.int32) - 0x1000)).view(torch.float32)
+            A = (A.view(torch.int32) - 0x1000).view(torch.float32)
+            B = (B.view(torch.int32) - 0x1000).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C
...
@@ -543,7 +534,8 @@ def run_gemm_mixed_pp(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )


def test_mixed_pp():
...
@@ -576,9 +568,9 @@ def matmul_with_parallel(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...
@@ -637,7 +629,8 @@ def run_gemm_tiled_op_with_parallel(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    profiler = kernel.get_profiler()

    def ref_program(A, B):
...
@@ -646,8 +639,8 @@ def run_gemm_tiled_op_with_parallel(
        if in_dtype == "float32":
            # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
            # float32 automatically, -0x1000 meas
-            A = ((A.view(torch.int32) - 0x1000)).view(torch.float32)
-            B = ((B.view(torch.int32) - 0x1000)).view(torch.float32)
+            A = (A.view(torch.int32) - 0x1000).view(torch.float32)
+            B = (B.view(torch.int32) - 0x1000).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C
...
@@ -675,16 +668,16 @@ def run_gemm_tiled_op_with_parallel(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )


@tilelang.jit(out_idx=[1])
def tir_op_with_parallel(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
...
@@ -696,11 +689,10 @@ def tir_op_with_parallel(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def customize_op_with_parallel(length=256, block=16, dtype="float32"):
    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
...
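The recurring change in this file is purely mechanical: where yapf hugged the call's closing "})" against the last dict entry, ruff format (which follows black's magic-trailing-comma convention) keeps one entry per line and dedents the closing parenthesis onto its own line. A minimal, self-contained sketch of the two layouts; the configure helper and the dict keys below are stand-ins for illustration, not tilelang APIs:

def configure(**kwargs):
    # Stand-in helper so the snippet runs on its own; not part of tilelang.
    return kwargs

# yapf-era layout: the closing "})" hugs the last dictionary entry.
old_style = configure(
    pass_configs={
        "disable_tma_lower": True,
        "disable_warp_specialized": True,
    })

# ruff format: the trailing comma keeps one entry per line and moves the
# call's closing parenthesis onto its own line.
new_style = configure(
    pass_configs={
        "disable_tma_lower": True,
        "disable_warp_specialized": True,
    },
)

assert old_style == new_style  # the value is identical; only layout changed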
testing/python/autotune/test_tilelang_autotune.py  View file @ 29051439
...
@@ -48,6 +48,7 @@ def get_configs(M, N, K, with_roller=False):
        from tilelang.carver.template import MatmulTemplate
        from tilelang.carver.arch import CUDA
        from tilelang.carver.roller.rasterization import NoRasterization

        arch = CUDA("cuda")
        topk = 20
...
@@ -84,7 +85,6 @@ def get_configs(M, N, K, with_roller=False):
        for config in configs:
            print(config)
    else:
        block_M = [64]
        block_N = [64]
        block_K = [32]
...
@@ -100,7 +100,8 @@ def get_configs(M, N, K, with_roller=False):
                    num_stages,
                    thread_num,
                    enable_rasterization,
-                ))
+                )
+            )
        configs = [
            {
...
@@ -110,7 +111,8 @@ def get_configs(M, N, K, with_roller=False):
                "num_stages": c[3],
                "thread_num": c[4],
                "enable_rasteration": c[5],  # keep param name for backward-compat
-            } for c in _configs
+            }
+            for c in _configs
        ]
        return configs
...
@@ -190,9 +192,9 @@ def matmul(M, N, K, with_roller):
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        """
        The compiled TVM function for block-level matrix multiplication.
...
@@ -206,9 +208,7 @@ def matmul(M, N, K, with_roller):
        """
        # Bind x-dimension to block index in N,
        # y-dimension to block index in M.
        with T.Kernel(
            T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            # Allocate shared memory for A sub-block of shape (block_M, block_K)
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            # Allocate shared memory for B sub-block of shape (block_N, block_K)
...
@@ -247,12 +247,16 @@ def matmul(M, N, K, with_roller):
        return main

-    autotuner = AutoTuner.from_kernel(
-        kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args(
-            out_idx=[-1],
-            target="auto",
-        ).set_profile_args(
-            ref_prog=ref_program,)
+    autotuner = (
+        AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller))
+        .set_compile_args(
+            out_idx=[-1],
+            target="auto",
+        )
+        .set_profile_args(
+            ref_prog=ref_program,
+        )
+    )
    return autotuner.run(warmup=3, rep=20)
...
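The AutoTuner hunk above shows how ruff format lays out a fluent call chain: instead of letting continuation lines hang off the previous call, the whole chain is wrapped in parentheses so every .method() starts at the same column. A rough sketch with a stand-in builder class, not the real AutoTuner API:

class Builder:
    # Minimal stand-in with chainable setters, for illustration only.
    def __init__(self):
        self.args = {}

    def set_compile_args(self, **kwargs):
        self.args.update(kwargs)
        return self

    def set_profile_args(self, **kwargs):
        self.args.update(kwargs)
        return self

# Old yapf layout: each chained call continues on the previous line and the
# indentation drifts to the right.
tuner_old = Builder().set_compile_args(
    out_idx=[-1], target="auto").set_profile_args(
        ref_prog=None)

# ruff format: the chain is parenthesized and each call sits on its own line.
tuner_new = (
    Builder()
    .set_compile_args(
        out_idx=[-1],
        target="auto",
    )
    .set_profile_args(
        ref_prog=None,
    )
)

assert tuner_old.args == tuner_new.args  # same configuration, different layout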
testing/python/autotune/test_tilelang_autotune_with_inputs.py  View file @ 29051439
...
@@ -30,38 +30,23 @@ def ref_program(A, B):
def get_configs():
    iter_params = dict(
        block_M=[64],
        block_N=[64],
        block_K=[32],
        num_stages=[0, 1],
        thread_num=[128],
        enable_rasterization=[False])
    return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]


@tilelang.autotune(
    configs=get_configs(),
)
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M=128, block_N=128, block_K=32, num_stages=0, thread_num=128, enable_rasterization=False):
    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        """
        The compiled TVM function for block-level matrix multiplication.
...
@@ -76,7 +61,6 @@ def matmul(M,
        # Bind x-dimension to block index in N,
        # y-dimension to block index in M.
        with T.Kernel(
            T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            # Allocate shared memory for A sub-block of shape (block_M, block_K)
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            # Allocate shared memory for B sub-block of shape (block_N, block_K)
...
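The main visual difference in this file is the function signature: the old hunk header still reads "def matmul(M,", i.e. yapf broke the overlong parameter list into one parameter per line, while ruff format packs the parameters back together as long as they fit the configured line length. A small sketch; both functions below are hypothetical stand-ins, not the test's real kernels:

# yapf-style signature: one parameter per line, aligned under the paren.
def matmul_yapf_style(M,
                      N,
                      K,
                      block_M=128,
                      block_N=128,
                      block_K=32):
    return (M, N, K, block_M, block_N, block_K)


# ruff-format-style signature: parameters packed onto one line when they fit.
def matmul_ruff_style(M, N, K, block_M=128, block_N=128, block_K=32):
    return (M, N, K, block_M, block_N, block_K)


assert matmul_yapf_style(1, 2, 3) == matmul_ruff_style(1, 2, 3)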
testing/python/cache/test_tilelang_cache_matmul.py  View file @ 29051439
...
@@ -28,9 +28,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
...
@@ -63,6 +63,7 @@ def run_cache_matmul():
        Reference PyTorch matrix multiplication for comparison.
        """
        import torch
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.half)  # Assuming dtype="float16" in matmul
        return C
...
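Independently of the formatting change, the reference program in this test shows the usual comparison pattern in these suites: compute the reference result in float32 with torch, then cast to the kernel's output dtype (float16 here) before checking. A condensed sketch of that pattern; it requires torch, and the shapes are arbitrary placeholders:

import torch


def ref_matmul_fp16(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Accumulate in float32 for accuracy, then cast to the kernel's dtype.
    C = torch.matmul(A.to(torch.float), B.to(torch.float))
    return C.to(torch.half)


if torch.cuda.is_available():
    a = torch.randn(64, 32, device="cuda", dtype=torch.half)
    b = torch.randn(32, 64, device="cuda", dtype=torch.half)
    ref = ref_matmul_fp16(a, b)  # reference to compare a kernel output against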
testing/python/carver/test_tilelang_carver_cuda_driver_properties.py  View file @ 29051439
...
@@ -29,9 +29,7 @@ class _cudaDeviceAttrNames:
def test_driver_get_device_properties():
    prop = get_cuda_device_properties()
    assert prop is not None, "Failed to get CUDA device properties"
-    assert isinstance(
-        prop, torch.cuda._CudaDeviceProperties), "Returned object is not of type _CudaDeviceProperties"
+    assert isinstance(prop, torch.cuda._CudaDeviceProperties), (
+        "Returned object is not of type _CudaDeviceProperties"
+    )


def test_device_get_device_name():
...
@@ -48,8 +46,7 @@ def test_device_get_shared_memory_per_block():
def test_device_get_persisting_l2_cache_size():
    tl_cache_size = get_persisting_l2_cache_max_size()
    driver_cache_size = get_device_attribute(
        _cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize)
    assert tl_cache_size == driver_cache_size, "Persisting L2 cache size values do not match"
...
@@ -61,17 +58,14 @@ def test_device_get_num_sms():
def test_device_get_registers_per_block():
    tl_regs_per_block = get_registers_per_block()
    driver_regs_per_block = get_device_attribute(
        _cudaDeviceAttrNames.cudaDevAttrMaxRegistersPerBlock)
    assert tl_regs_per_block == driver_regs_per_block, "Registers per block values do not match"


def test_device_get_max_dynamic_shared_size_bytes():
    tl_dynamic_smem = get_max_dynamic_shared_size_bytes()
    driver_dynamic_smem = get_device_attribute(
        _cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor)
-    assert tl_dynamic_smem == driver_dynamic_smem, "Max dynamic shared size bytes values do not match"
+    assert tl_dynamic_smem == driver_dynamic_smem, (
+        "Max dynamic shared size bytes values do not match"
+    )


if __name__ == "__main__":
...
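For the long asserts above, ruff format parenthesizes the failure message so it can move to its own line instead of splitting the call expression itself. A minimal illustration; the values are placeholders, not real device queries:

tl_dynamic_smem = 101376      # placeholder value, not queried from a GPU
driver_dynamic_smem = 101376  # placeholder value, not queried from a GPU

# ruff-format style: the message is wrapped in parentheses and indented,
# keeping the comparison itself on a single readable line.
assert tl_dynamic_smem == driver_dynamic_smem, (
    "Max dynamic shared size bytes values do not match"
)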
testing/python/carver/test_tilelang_carver_generate_hints.py  View file @ 29051439
...
@@ -9,16 +9,13 @@ def run_general_matmul_emit_configs(M, N, K, topk: int = 20):
    arch = auto_infer_current_arch()

    def gemm(M, N, K):
-        A = te.placeholder((M, K), name='A', dtype='float16')
-        B = te.placeholder((N, K), name='B', dtype='float16')
+        A = te.placeholder((M, K), name="A", dtype="float16")
+        B = te.placeholder((N, K), name="B", dtype="float16")
        # Describe the matrix multiplication in TE
-        k = te.reduce_axis((0, K), name='k')
-        C = te.compute(
-            (M, N),
-            lambda i, j: te.sum(A[i, k].astype('float16') * B[j, k].astype('float16'), axis=[k]),
-            name='C')
+        k = te.reduce_axis((0, K), name="k")
+        C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype("float16") * B[j, k].astype("float16"), axis=[k]), name="C")
        return A, B, C
...
@@ -29,8 +26,7 @@ def run_general_matmul_emit_configs(M, N, K, topk: int = 20):
    tensorized_func, tags = carver.utils.get_tensorized_func_and_tags(func, arch.target)
    print(tags)
    policy = carver.TensorCorePolicy.from_prim_func(
        func=tensorized_func, arch=arch, tags=tags, name="matmul_0")
    hints = policy.emit_config(topk=topk)
...
@@ -59,16 +55,13 @@ def run_general_matmul_matmul_emit_configs(M, N, K, topk: int = 20):
    arch = auto_infer_current_arch()

    def gemm(M, N, K):
-        A = te.placeholder((M, K), name='A', dtype='float16')
-        B = te.placeholder((N, K), name='B', dtype='float16')
+        A = te.placeholder((M, K), name="A", dtype="float16")
+        B = te.placeholder((N, K), name="B", dtype="float16")
        # Describe the matrix multiplication in TE
-        k = te.reduce_axis((0, K), name='k')
-        C = te.compute(
-            (M, N),
-            lambda i, j: te.sum(A[i, k].astype('float16') * B[j, k].astype('float16'), axis=[k]),
-            name='C')
+        k = te.reduce_axis((0, K), name="k")
+        C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype("float16") * B[j, k].astype("float16"), axis=[k]), name="C")
        return A, B, C
...
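This last file also shows ruff format's quote normalization: string literals written with single quotes ('A', 'float16') come out with double quotes, with no change in value. To reproduce this kind of reformatting locally, the usual workflow is pip install ruff, then ruff format . to rewrite files and ruff format --check . in CI; the options the project actually enforces live in its own configuration, so treat these commands as generic examples rather than the repository's pinned setup. A trivial Python illustration of the quote change:

# Quote normalization only changes the source text, not the runtime value.
dtype_old = 'float16'   # yapf-era single quotes
dtype_new = "float16"   # ruff format rewrites string literals to double quotes

assert dtype_old == dtype_new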