Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
467
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
146 additions
and
228 deletions
+146
-228
tilelang/primitives/gemm/gemm_mma.py
tilelang/primitives/gemm/gemm_mma.py
+15
-11
tilelang/profiler/__init__.py
tilelang/profiler/__init__.py
+6
-9
tilelang/profiler/bench.py
tilelang/profiler/bench.py
+5
-5
tilelang/quantize/lop3.py
tilelang/quantize/lop3.py
+2
-5
tilelang/quantize/mxfp.py
tilelang/quantize/mxfp.py
+3
-7
tilelang/quantize/utils.py
tilelang/quantize/utils.py
+6
-3
tilelang/testing/__init__.py
tilelang/testing/__init__.py
+9
-10
tilelang/tileop/gemm/__init__.py
tilelang/tileop/gemm/__init__.py
+1
-2
tilelang/tileop/gemm/gemm_mfma.py
tilelang/tileop/gemm/gemm_mfma.py
+6
-12
tilelang/tileop/gemm/gemm_mma.py
tilelang/tileop/gemm/gemm_mma.py
+6
-11
tilelang/tileop/gemm/gemm_mma_sm70.py
tilelang/tileop/gemm/gemm_mma_sm70.py
+6
-11
tilelang/tileop/gemm/gemm_tcgen05.py
tilelang/tileop/gemm/gemm_tcgen05.py
+10
-20
tilelang/tileop/gemm/gemm_wgmma.py
tilelang/tileop/gemm/gemm_wgmma.py
+12
-25
tilelang/tileop/gemm_sp/__init__.py
tilelang/tileop/gemm_sp/__init__.py
+3
-3
tilelang/tileop/gemm_sp/gemm_sp_mma.py
tilelang/tileop/gemm_sp/gemm_sp_mma.py
+4
-10
tilelang/tools/Analyzer.py
tilelang/tools/Analyzer.py
+8
-8
tilelang/tools/plot_layout.py
tilelang/tools/plot_layout.py
+28
-52
tilelang/transform/__init__.py
tilelang/transform/__init__.py
+10
-20
tilelang/transform/add_bufstore_wrapper.py
tilelang/transform/add_bufstore_wrapper.py
+5
-4
tilelang/transform/pass_config.py
tilelang/transform/pass_config.py
+1
-0
No files found.
tilelang/primitives/gemm/gemm_mma.py
View file @
29051439
...
...
@@ -31,7 +31,6 @@ class GemmPrimitiveMMA(GemmBaseParams):
C
:
tir
.
Buffer
,
mma_emitter
:
TensorCoreIntrinEmitter
,
)
->
tir
.
PrimExpr
:
in_dtype
=
self
.
in_dtype
warp_cols
=
mma_emitter
.
warp_cols
local_size_b
=
mma_emitter
.
local_size_b
...
...
@@ -53,21 +52,24 @@ class GemmPrimitiveMMA(GemmBaseParams):
if
a_is_fragment
:
# Annotate layout for A_local if it is a fragment.
T
.
annotate_layout
({
T
.
annotate_layout
(
{
A_local
:
mma_emitter
.
make_mma_load_layout
(
A_local
,
"A"
),
})
}
)
if
c_is_fragment
:
# Annotate layout for C_local if it is a fragment.
T
.
annotate_layout
({
T
.
annotate_layout
(
{
C_local
:
mma_emitter
.
make_mma_store_layout
(
C_local
),
})
}
)
# Make default swizzle layout for shared memory
# T.annotate_layout({
# B_shared: make_mma_swizzle_layout(B_shared),
# })
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load B into fragment
mma_emitter
.
ldmatrix_b
(
B_local
,
...
...
@@ -146,9 +148,11 @@ class GemmPrimitiveMMA(GemmBaseParams):
if
c_is_fragment
:
# Annotate layout for C_local if it is a fragment.
T
.
annotate_layout
({
T
.
annotate_layout
(
{
C_local
:
mma_emitter
.
make_mma_store_layout
(
C_local
),
})
}
)
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load A into fragment
...
...
tilelang/profiler/__init__.py
View file @
29051439
"""The profiler and convert to torch utils"""
from
__future__
import
annotations
from
typing
import
Callable
,
Any
,
Literal
from
functools
import
partial
...
...
@@ -45,8 +46,7 @@ class Profiler:
result_idx
=
[]
elif
isinstance
(
result_idx
,
int
):
if
result_idx
>
len
(
params
)
or
result_idx
<
-
len
(
params
):
raise
ValueError
(
f
"result_idx should be an integer between
{
-
len
(
params
)
}
and
{
len
(
params
)
-
1
}
"
)
raise
ValueError
(
f
"result_idx should be an integer between
{
-
len
(
params
)
}
and
{
len
(
params
)
-
1
}
"
)
if
result_idx
<
0
:
result_idx
=
len
(
params
)
+
result_idx
result_idx
=
[
result_idx
]
...
...
@@ -113,8 +113,7 @@ class Profiler:
ref_tensors
=
ins
+
ref_outs
lib_tensors
=
ins
+
lib_outs
assert
len
(
lib_tensors
)
==
len
(
ref_tensors
),
"len(lib_tensors) not equals to len(ref_tensors) !"
assert
len
(
lib_tensors
)
==
len
(
ref_tensors
),
"len(lib_tensors) not equals to len(ref_tensors) !"
# torch.set_printoptions(edgeitems=torch.inf)
for
lhs
,
rhs
in
zip
(
lib_tensors
,
ref_tensors
):
# close_mask = torch.isclose(lhs, rhs, rtol=rtol, atol=atol)
...
...
@@ -252,10 +251,9 @@ class Profiler:
)
elif
profiler
==
"tvm"
:
assert
func
is
not
None
,
"func should not be None"
assert
isinstance
(
func
,
tvm
.
runtime
.
Module
),
f
"func should be a TVM module, but got
{
type
(
func
)
}
"
assert
isinstance
(
func
,
tvm
.
runtime
.
Module
),
f
"func should be a TVM module, but got
{
type
(
func
)
}
"
ins
=
(
self
.
_get_inputs
(
with_output
=
True
)
if
input_tensors
is
None
else
input_tensors
)
ins
=
self
.
_get_inputs
(
with_output
=
True
)
if
input_tensors
is
None
else
input_tensors
target
=
"cuda"
with
suppress
(
Exception
):
...
...
@@ -264,8 +262,7 @@ class Profiler:
assert
target
in
[
"cuda"
,
"hip"
],
f
"Unknown target:
{
target
}
"
device
=
tvm
.
cuda
(
0
)
if
target
==
"cuda"
else
tvm
.
rocm
(
0
)
time_evaluator
=
self
.
mod
.
time_evaluator
(
self
.
mod
.
entry_name
,
device
,
number
=
rep
,
repeat
=
n_repeat
)
time_evaluator
=
self
.
mod
.
time_evaluator
(
self
.
mod
.
entry_name
,
device
,
number
=
rep
,
repeat
=
n_repeat
)
# Transform Latency to ms
return
time_evaluator
(
*
ins
).
mean
*
1e3
else
:
...
...
tilelang/profiler/bench.py
View file @
29051439
"""Profiler and benchmarking utilities for PyTorch functions."""
from
__future__
import
annotations
import
os
...
...
@@ -16,8 +17,8 @@ class suppress_stdout_stderr:
def
__enter__
(
self
):
# Open null device files
self
.
outnull_file
=
open
(
os
.
devnull
,
'w'
)
self
.
errnull_file
=
open
(
os
.
devnull
,
'w'
)
self
.
outnull_file
=
open
(
os
.
devnull
,
"w"
)
self
.
errnull_file
=
open
(
os
.
devnull
,
"w"
)
# Save original file descriptors
self
.
old_stdout_fileno_undup
=
sys
.
stdout
.
fileno
()
...
...
@@ -56,7 +57,7 @@ class suppress_stdout_stderr:
IS_CUDA
=
torch
.
cuda
.
is_available
()
device
=
'
cuda:0
'
if
IS_CUDA
else
'
mps:0
'
device
=
"
cuda:0
"
if
IS_CUDA
else
"
mps:0
"
Event
=
torch
.
cuda
.
Event
if
IS_CUDA
else
torch
.
mps
.
Event
...
...
@@ -93,8 +94,7 @@ def do_bench(
Returns:
Runtime in milliseconds (float) or list of quantile values if quantiles specified
"""
assert
return_mode
in
[
"min"
,
"max"
,
"mean"
,
"median"
],
\
f
"Invalid return_mode:
{
return_mode
}
"
assert
return_mode
in
[
"min"
,
"max"
,
"mean"
,
"median"
],
f
"Invalid return_mode:
{
return_mode
}
"
# Initial function call and synchronization
fn
()
...
...
tilelang/quantize/lop3.py
View file @
29051439
...
...
@@ -1130,16 +1130,13 @@ def get_lop3_intrin_group(
Dict[str, str]
A dictionary mapping the names of the intrinsics to their corresponding implementations.
"""
assert
out_dtype
in
[
"float16"
,
"int8"
,
"int4"
],
(
f
"Invalid out_dtype:
{
out_dtype
}
. Expected 'float16' or 'int8' or 'int4' ."
)
assert
out_dtype
in
[
"float16"
,
"int8"
,
"int4"
],
f
"Invalid out_dtype:
{
out_dtype
}
. Expected 'float16' or 'int8' or 'int4' ."
dtype_mapping
=
{
"float16"
:
"f16"
,
"int4"
:
"i4"
,
"int8"
:
"i8"
,
"int32"
:
"i32"
}
target_dtype
=
dtype_mapping
[
out_dtype
]
if
source_format
not
in
[
"int"
,
"uint"
]:
raise
ValueError
(
f
"Invalid source_format. Expected 'int' or 'uint', but got
{
source_format
}
."
)
raise
ValueError
(
f
"Invalid source_format. Expected 'int' or 'uint', but got
{
source_format
}
."
)
if
with_zeros
and
source_format
==
"int"
:
raise
ValueError
(
f
"Zeros are not supported for signed integers, but got
{
source_format
}
"
)
...
...
tilelang/quantize/mxfp.py
View file @
29051439
...
...
@@ -80,13 +80,9 @@ def get_mxfp_intrin_group(
AssertionError: if out_dtype, source_format, or storage_dtype are not supported.
KeyError: if the constructed key does not match any available C source implementation.
"""
assert
out_dtype
in
[
"float16"
,
"bfloat16"
],
f
"Invalid out_dtype:
{
out_dtype
}
. Expected 'float16' or 'bfloat16'."
assert
source_format
in
[
"int"
,
"uint"
],
f
"Invalid source_format:
{
source_format
}
. Expected 'int' or 'uint'."
assert
storage_dtype
in
[
"int32"
,
"int8"
,
"uint8"
],
f
"Invalid storage_dtype:
{
storage_dtype
}
. Expected 'int32' or 'int8' or 'uint8'."
assert
out_dtype
in
[
"float16"
,
"bfloat16"
],
f
"Invalid out_dtype:
{
out_dtype
}
. Expected 'float16' or 'bfloat16'."
assert
source_format
in
[
"int"
,
"uint"
],
f
"Invalid source_format:
{
source_format
}
. Expected 'int' or 'uint'."
assert
storage_dtype
in
[
"int32"
,
"int8"
,
"uint8"
],
f
"Invalid storage_dtype:
{
storage_dtype
}
. Expected 'int32' or 'int8' or 'uint8'."
dtype_map
=
{
"float16"
:
"f16"
,
"bfloat16"
:
"bf16"
}
key
=
f
"fp
{
source_bit
}
_to_
{
dtype_map
[
out_dtype
]
}
"
...
...
tilelang/quantize/utils.py
View file @
29051439
def
gen_quant4
(
k
,
n
,
groupsize
=-
1
):
import
torch
import
torch.nn
as
nn
maxq
=
2
**
4
w
=
torch
.
randn
((
k
,
n
),
dtype
=
torch
.
half
,
device
=
"cpu"
)
...
...
@@ -48,6 +49,7 @@ def gen_quant4(k, n, groupsize=-1):
def
general_compress
(
lowprecision_weight
,
source_bits
=
4
,
storage_dtype
=
None
):
import
torch
if
storage_dtype
is
None
:
storage_dtype
=
torch
.
int8
elems_per_byte
=
8
//
source_bits
...
...
@@ -56,11 +58,11 @@ def general_compress(lowprecision_weight, source_bits=4, storage_dtype=None):
int8_weight
=
torch
.
zeros
(
(
*
lowprecision_weight
.
shape
[:
-
1
],
lowprecision_weight
.
shape
[
-
1
]
//
elems_per_byte
),
dtype
=
torch
.
int8
,
device
=
lowprecision_weight
.
device
)
device
=
lowprecision_weight
.
device
,
)
for
j
in
range
(
lowprecision_weight
.
shape
[
-
1
]
//
elems_per_byte
):
for
k
in
range
(
elems_per_byte
):
int8_weight
[...,
j
]
|=
(
lowprecision_weight
[...,
j
*
elems_per_byte
+
k
]
<<
(
source_bits
*
k
)).
to
(
torch
.
int8
)
int8_weight
[...,
j
]
|=
(
lowprecision_weight
[...,
j
*
elems_per_byte
+
k
]
<<
(
source_bits
*
k
)).
to
(
torch
.
int8
)
return
int8_weight
.
to
(
storage_dtype
)
...
...
@@ -82,6 +84,7 @@ def interleave_weight(qweight, nbits=4, target_dtype="float16"):
interleave_weight(qweight, 4, "float16")
"""
import
torch
assert
target_dtype
in
[
"float16"
,
"int8"
]
# reinterpret the data type of qweight to int32
qweight
=
qweight
.
view
(
torch
.
int32
)
...
...
tilelang/testing/__init__.py
View file @
29051439
...
...
@@ -5,20 +5,19 @@ import random
import
torch
import
numpy
as
np
from
tilelang.contrib
import
nvcc
from
tvm.testing.utils
import
(
requires_cuda
,
requires_package
,
requires_llvm
,
requires_metal
,
requires_rocm
,
_compose
)
from
tvm.testing.utils
import
requires_cuda
,
requires_package
,
requires_llvm
,
requires_metal
,
requires_rocm
,
_compose
from
tilelang.utils.tensor
import
torch_assert_close
as
torch_assert_close
__all__
=
[
'
requires_package
'
,
'
requires_cuda
'
,
'
requires_metal
'
,
'
requires_rocm
'
,
'
requires_llvm
'
,
'
main
'
,
'
requires_cuda_compute_version
'
,
]
+
[
f
'
requires_cuda_compute_version_
{
op
}
'
for
op
in
(
'
ge
'
,
'
gt
'
,
'
le
'
,
'
lt
'
,
'
eq
'
)]
"
requires_package
"
,
"
requires_cuda
"
,
"
requires_metal
"
,
"
requires_rocm
"
,
"
requires_llvm
"
,
"
main
"
,
"
requires_cuda_compute_version
"
,
]
+
[
f
"
requires_cuda_compute_version_
{
op
}
"
for
op
in
(
"
ge
"
,
"
gt
"
,
"
le
"
,
"
lt
"
,
"
eq
"
)]
# pytest.main() wrapper to allow running single test file
...
...
tilelang/tileop/gemm/__init__.py
View file @
29051439
...
...
@@ -23,8 +23,7 @@ def gemm_py_infer_layout(gemm_py: GemmMMA, target: Target, thread_bounds: Range)
@
tvm_ffi
.
register_global_func
(
"tl.gemm_py.lower"
)
def
gemm_py_lower
(
gemm_py
:
GemmMMA
,
layout_map
,
target
:
Target
,
thread_bounds
:
Range
,
thread_var
:
tir
.
Var
):
def
gemm_py_lower
(
gemm_py
:
GemmMMA
,
layout_map
,
target
:
Target
,
thread_bounds
:
Range
,
thread_var
:
tir
.
Var
):
thread_nums
=
thread_bounds
.
extent
stmt
=
gemm_py
.
lower
(
layout_map
,
target
,
thread_nums
,
thread_var
)
return
stmt
...
...
tilelang/tileop/gemm/gemm_mfma.py
View file @
29051439
from
.gemm_base
import
GemmBase
from
tilelang.layout
import
make_swizzled_layout
from
tilelang.intrinsics.mfma_macro_generator
import
(
MatrixCoreIntrinEmitter
,)
MatrixCoreIntrinEmitter
,
)
from
tilelang.utils.language
import
is_shared
,
is_fragment
,
is_full_region
from
tilelang
import
tvm
as
tvm
from
tvm.target
import
Target
...
...
@@ -11,10 +12,8 @@ from tilelang.transform.simplify import _Simplify
class
GemmMFMA
(
GemmBase
):
def
infer_layout
(
self
,
target
:
Target
,
thread_nums
:
int
):
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
False
)
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
False
)
warp_row_tiles
=
int
(
self
.
M
//
m_warp
)
warp_col_tiles
=
int
(
self
.
N
//
n_warp
)
mfma_emitter
=
MatrixCoreIntrinEmitter
(
...
...
@@ -56,12 +55,10 @@ class GemmMFMA(GemmBase):
self
.
C
:
mfma_emitter
.
make_mfma_store_layout
(
self
.
C
),
}
else
:
raise
ValueError
(
f
"Unsupported gemm combination, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
raise
ValueError
(
f
"Unsupported gemm combination, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
def
lower
(
self
,
layout_map
:
dict
,
target
:
Target
,
thread_nums
:
int
,
thread_var
:
tir
.
Var
):
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
False
)
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
False
)
warp_row_tiles
=
int
(
self
.
M
//
m_warp
)
warp_col_tiles
=
int
(
self
.
N
//
n_warp
)
mfma_emitter
=
MatrixCoreIntrinEmitter
(
...
...
@@ -153,7 +150,6 @@ class GemmMFMA(GemmBase):
T
.
clear
(
C_buf
)
for
ki
in
T
.
serial
(
0
,
(
block_K
//
(
micro_size_k
*
self
.
k_pack
))):
# Load A into fragment
mfma_emitter
.
ldmatrix_a
(
A_local
,
...
...
@@ -183,7 +179,6 @@ class GemmMFMA(GemmBase):
if
clear_accum
:
T
.
clear
(
C_buf
)
for
ki
in
T
.
serial
(
0
,
(
block_K
//
(
micro_size_k
*
self
.
k_pack
))):
# Load B into fragment
mfma_emitter
.
ldmatrix_b
(
B_local
,
...
...
@@ -217,8 +212,7 @@ class GemmMFMA(GemmBase):
# Must inline let statements to simplify the analysis
return
_Simplify
(
_gemm_rsr
,
inline_let
=
True
)
else
:
raise
ValueError
(
f
"Unsupported gemm combination, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
raise
ValueError
(
f
"Unsupported gemm combination, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
def
is_gemm_ss
(
self
)
->
bool
:
return
is_shared
(
self
.
A
)
and
is_shared
(
self
.
B
)
...
...
tilelang/tileop/gemm/gemm_mma.py
View file @
29051439
from
.gemm_base
import
GemmBase
from
tilelang.layout
import
make_swizzled_layout
from
tilelang.intrinsics.mma_macro_generator
import
(
TensorCoreIntrinEmitter
,)
TensorCoreIntrinEmitter
,
)
from
tilelang.utils.language
import
is_shared
,
is_fragment
,
is_full_region
from
tilelang
import
tvm
as
tvm
from
tvm.target
import
Target
...
...
@@ -11,10 +12,8 @@ from tilelang.transform.simplify import _Simplify
class
GemmMMA
(
GemmBase
):
def
infer_layout
(
self
,
target
:
Target
,
thread_nums
:
int
):
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
False
)
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
False
)
warp_row_tiles
=
int
(
self
.
M
//
m_warp
)
warp_col_tiles
=
int
(
self
.
N
//
n_warp
)
mma_emitter
=
TensorCoreIntrinEmitter
(
...
...
@@ -54,12 +53,10 @@ class GemmMMA(GemmBase):
self
.
C
:
mma_emitter
.
make_mma_store_layout
(
self
.
C
),
}
else
:
raise
ValueError
(
f
"Unsupported gemm combination, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
raise
ValueError
(
f
"Unsupported gemm combination, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
def
lower
(
self
,
layout_map
:
dict
,
target
:
Target
,
thread_nums
:
int
,
thread_var
:
tir
.
Var
):
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
False
)
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
False
)
warp_row_tiles
=
int
(
self
.
M
//
m_warp
)
warp_col_tiles
=
int
(
self
.
N
//
n_warp
)
mma_emitter
=
TensorCoreIntrinEmitter
(
...
...
@@ -177,7 +174,6 @@ class GemmMMA(GemmBase):
if
clear_accum
:
T
.
clear
(
C_buf
)
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load B into fragment
mma_emitter
.
ldmatrix_b
(
B_local
,
...
...
@@ -211,8 +207,7 @@ class GemmMMA(GemmBase):
# Must inline let statements to simplify the analysis
return
_Simplify
(
_gemm_rrr
,
inline_let
=
True
)
else
:
raise
ValueError
(
f
"Unsupported gemm combination, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
raise
ValueError
(
f
"Unsupported gemm combination, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
def
is_gemm_ss
(
self
)
->
bool
:
return
is_shared
(
self
.
A
)
and
is_shared
(
self
.
B
)
...
...
tilelang/tileop/gemm/gemm_mma_sm70.py
View file @
29051439
...
...
@@ -2,7 +2,8 @@
from
.gemm_base
import
GemmBase
from
tilelang.layout
import
make_volta_swizzled_layout
from
tilelang.intrinsics.mma_sm70_macro_generator
import
(
TensorCoreIntrinEmitter
,)
TensorCoreIntrinEmitter
,
)
from
tilelang.utils.language
import
is_shared
,
is_fragment
,
is_full_region
from
tilelang
import
tvm
as
tvm
from
tvm.target
import
Target
...
...
@@ -12,10 +13,8 @@ from tilelang.transform.simplify import _Simplify
class
GemmMMASm70
(
GemmBase
):
def
infer_layout
(
self
,
target
:
Target
,
thread_nums
:
int
):
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
False
)
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
False
)
warp_row_tiles
=
int
(
self
.
M
//
m_warp
)
warp_col_tiles
=
int
(
self
.
N
//
n_warp
)
mma_emitter
=
TensorCoreIntrinEmitter
(
...
...
@@ -45,12 +44,10 @@ class GemmMMASm70(GemmBase):
self
.
C
:
mma_emitter
.
make_mma_store_layout
(
self
.
C
),
}
else
:
raise
ValueError
(
f
"Unsupported gemm combination, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
raise
ValueError
(
f
"Unsupported gemm combination, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
def
lower
(
self
,
layout_map
:
dict
,
target
:
Target
,
thread_nums
:
int
,
thread_var
:
tir
.
Var
):
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
False
)
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
False
)
warp_row_tiles
=
int
(
self
.
M
//
m_warp
)
warp_col_tiles
=
int
(
self
.
N
//
n_warp
)
mma_emitter
=
TensorCoreIntrinEmitter
(
...
...
@@ -140,7 +137,6 @@ class GemmMMASm70(GemmBase):
T
.
clear
(
C_buf
)
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load B into fragment
mma_emitter
.
ldmatrix_b
(
B_local
,
...
...
@@ -155,8 +151,7 @@ class GemmMMASm70(GemmBase):
# Must inline let statements to simplify the analysis
return
_Simplify
(
_gemm_rsr
,
inline_let
=
True
)
else
:
raise
ValueError
(
f
"Unsupported gemm combination, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
raise
ValueError
(
f
"Unsupported gemm combination, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
def
is_gemm_ss
(
self
)
->
bool
:
return
is_shared
(
self
.
A
)
and
is_shared
(
self
.
B
)
...
...
tilelang/tileop/gemm/gemm_tcgen05.py
View file @
29051439
from
.gemm_base
import
GemmBase
from
tilelang.layout
import
make_tcgen05mma_swizzled_layout
from
tilelang.intrinsics.tcgen05_macro_generator
import
(
TensorCoreIntrinEmitter
,)
TensorCoreIntrinEmitter
,
)
from
tilelang
import
language
as
T
from
tilelang.transform.simplify
import
_Simplify
from
tvm
import
tir
...
...
@@ -18,10 +19,8 @@ _FLOAT8_DTYPES = {
class
GemmTCGEN5
(
GemmBase
):
def
infer_layout
(
self
,
target
:
Target
,
thread_nums
:
int
):
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
True
)
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
True
)
warp_row_tiles
=
int
(
self
.
M
//
m_warp
)
warp_col_tiles
=
int
(
self
.
N
//
n_warp
)
mma_emitter
=
TensorCoreIntrinEmitter
(
...
...
@@ -40,27 +39,20 @@ class GemmTCGEN5(GemmBase):
b_is_k_major
=
self
.
trans_B
if
self
.
is_gemm_ss
():
a_continuity
=
self
.
M
if
a_is_k_major
else
4
*
self
.
K
//
m_warp
b_continuity
=
self
.
K
if
b_is_k_major
else
self
.
N
//
n_warp
return
{
# WGMMA does not support padding
self
.
A
:
make_tcgen05mma_swizzled_layout
(
self
.
A
,
continuity
=
a_continuity
,
k_major
=
a_is_k_major
),
self
.
B
:
make_tcgen05mma_swizzled_layout
(
self
.
B
,
continuity
=
b_continuity
,
k_major
=
b_is_k_major
),
self
.
C
:
mma_emitter
.
make_mma_store_layout
(
self
.
C
),
self
.
A
:
make_tcgen05mma_swizzled_layout
(
self
.
A
,
continuity
=
a_continuity
,
k_major
=
a_is_k_major
),
self
.
B
:
make_tcgen05mma_swizzled_layout
(
self
.
B
,
continuity
=
b_continuity
,
k_major
=
b_is_k_major
),
self
.
C
:
mma_emitter
.
make_mma_store_layout
(
self
.
C
),
}
# No special swizzle requirement; rely on existing layout.
return
{}
def
lower
(
self
,
layout_map
:
dict
,
target
:
Target
,
thread_nums
:
int
,
thread_var
:
tir
.
Var
):
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
True
)
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
True
)
warp_row_tiles
=
int
(
self
.
M
//
m_warp
)
warp_col_tiles
=
int
(
self
.
N
//
n_warp
)
mma_emitter
=
TensorCoreIntrinEmitter
(
...
...
@@ -82,11 +74,9 @@ class GemmTCGEN5(GemmBase):
mma_emitter
.
_assign_b_shared_layout
(
layout_map
[
self
.
B
])
if
not
self
.
is_gemm_ss
():
raise
ValueError
(
f
"TCGEN5MMA currently only supports gemm_ss, got "
f
"A scope
{
self
.
A
.
scope
()
}
, B scope
{
self
.
B
.
scope
()
}
"
)
raise
ValueError
(
f
"TCGEN5MMA currently only supports gemm_ss, got A scope
{
self
.
A
.
scope
()
}
, B scope
{
self
.
B
.
scope
()
}
"
)
atom_m
,
atom_n
,
atom_k
,
enable_ws
,
enable_2cta
=
mma_emitter
.
get_tcgen5_mma_meta
(
self
.
M
,
self
.
N
,
self
.
K
)
atom_m
,
atom_n
,
atom_k
,
enable_ws
,
enable_2cta
=
mma_emitter
.
get_tcgen5_mma_meta
(
self
.
M
,
self
.
N
,
self
.
K
)
if
self
.
A
.
scope
()
not
in
{
"shared"
,
"shared.dyn"
,
"shared.tmem"
}:
raise
ValueError
(
f
"Unsupported A scope for TCGEN5MMA:
{
self
.
A
.
scope
()
}
"
)
...
...
@@ -108,7 +98,7 @@ class GemmTCGEN5(GemmBase):
raise
ValueError
(
"TCGEN5MMA expects 2D coordinates for C buffer access"
)
accum_dtype
=
str
(
self
.
C
.
dtype
)
if
accum_dtype
not
in
[
"float32"
,
'
float16
'
]:
if
accum_dtype
not
in
[
"float32"
,
"
float16
"
]:
raise
ValueError
(
f
"Unsupported accumulator dtype for TCGEN5MMA:
{
accum_dtype
}
"
)
A_shared
=
self
.
ARegion
...
...
tilelang/tileop/gemm/gemm_wgmma.py
View file @
29051439
from
.gemm_base
import
GemmBase
from
tilelang.layout
import
make_wgmma_swizzled_layout
from
tilelang.intrinsics.wgmma_macro_generator
import
(
TensorCoreIntrinEmitter
,)
TensorCoreIntrinEmitter
,
)
from
tilelang.utils.language
import
is_shared
,
is_fragment
from
tilelang
import
tvm
as
tvm
from
tvm.target
import
Target
...
...
@@ -11,10 +12,8 @@ from tilelang.transform.simplify import _Simplify
class
GemmWGMMA
(
GemmBase
):
def
infer_layout
(
self
,
target
:
Target
,
thread_nums
:
int
):
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
True
)
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
True
)
warp_row_tiles
=
int
(
self
.
M
//
m_warp
)
warp_col_tiles
=
int
(
self
.
N
//
n_warp
)
mma_emitter
=
TensorCoreIntrinEmitter
(
...
...
@@ -38,33 +37,22 @@ class GemmWGMMA(GemmBase):
return
{
# WGMMA does not support padding
self
.
A
:
make_wgmma_swizzled_layout
(
self
.
A
,
continuity
=
a_continuity
,
k_major
=
a_is_k_major
),
self
.
B
:
make_wgmma_swizzled_layout
(
self
.
B
,
continuity
=
b_continuity
,
k_major
=
b_is_k_major
),
self
.
C
:
mma_emitter
.
make_mma_store_layout
(
self
.
C
),
self
.
A
:
make_wgmma_swizzled_layout
(
self
.
A
,
continuity
=
a_continuity
,
k_major
=
a_is_k_major
),
self
.
B
:
make_wgmma_swizzled_layout
(
self
.
B
,
continuity
=
b_continuity
,
k_major
=
b_is_k_major
),
self
.
C
:
mma_emitter
.
make_mma_store_layout
(
self
.
C
),
}
elif
self
.
is_gemm_rs
():
b_continuity
=
self
.
N
if
b_is_k_major
else
4
*
self
.
K
//
n_warp
return
{
self
.
A
:
mma_emitter
.
make_mma_load_layout
(
self
.
A
,
matrix
=
"A"
),
self
.
B
:
make_wgmma_swizzled_layout
(
self
.
B
,
continuity
=
b_continuity
,
k_major
=
b_is_k_major
),
self
.
C
:
mma_emitter
.
make_mma_store_layout
(
self
.
C
),
self
.
A
:
mma_emitter
.
make_mma_load_layout
(
self
.
A
,
matrix
=
"A"
),
self
.
B
:
make_wgmma_swizzled_layout
(
self
.
B
,
continuity
=
b_continuity
,
k_major
=
b_is_k_major
),
self
.
C
:
mma_emitter
.
make_mma_store_layout
(
self
.
C
),
}
else
:
raise
ValueError
(
f
"Unsupported gemm combination for wgmma, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
raise
ValueError
(
f
"Unsupported gemm combination for wgmma, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
def
lower
(
self
,
layout_map
:
dict
,
target
:
Target
,
thread_nums
:
int
,
thread_var
:
tir
.
Var
):
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
True
)
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
True
)
warp_row_tiles
=
int
(
self
.
M
//
m_warp
)
warp_col_tiles
=
int
(
self
.
N
//
n_warp
)
...
...
@@ -133,8 +121,7 @@ class GemmWGMMA(GemmBase):
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return
_Simplify
(
_gemm_rsr
,
inline_let
=
True
)
raise
ValueError
(
f
"Unsupported gemm combination for wgmma, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
raise
ValueError
(
f
"Unsupported gemm combination for wgmma, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
def
is_gemm_ss
(
self
)
->
bool
:
return
is_shared
(
self
.
A
)
and
is_shared
(
self
.
B
)
...
...
tilelang/tileop/gemm_sp/__init__.py
View file @
29051439
from
tilelang
import
tvm
as
tvm
from
tvm
import
tir
from
tilelang.utils.target
import
(
target_is_cuda
,)
target_is_cuda
,
)
from
tvm.target
import
Target
from
tvm.ir.base
import
Node
from
tvm.ir
import
Range
...
...
@@ -18,8 +19,7 @@ def gemm_sp_py_infer_layout(gemm_sp_py: GemmSPMMA, target: Target, thread_bounds
@
tvm_ffi
.
register_global_func
(
"tl.gemm_sp_py.lower"
)
def
gemm_sp_py_lower
(
gemm_sp_py
:
GemmSPMMA
,
target
:
Target
,
thread_bounds
:
Range
,
thread_var
:
tir
.
Var
):
def
gemm_sp_py_lower
(
gemm_sp_py
:
GemmSPMMA
,
target
:
Target
,
thread_bounds
:
Range
,
thread_var
:
tir
.
Var
):
thread_nums
=
thread_bounds
.
extent
stmt
=
gemm_sp_py
.
lower
(
target
,
thread_nums
,
thread_var
)
return
stmt
...
...
tilelang/tileop/gemm_sp/gemm_sp_mma.py
View file @
29051439
...
...
@@ -10,10 +10,8 @@ from tilelang.transform.simplify import _Simplify
class
GemmSPMMA
(
GemmSPBase
):
def
infer_layout
(
self
,
target
:
Target
,
thread_nums
:
int
):
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
False
)
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
False
)
warp_row_tiles
=
int
(
self
.
M
//
m_warp
)
warp_col_tiles
=
int
(
self
.
N
//
n_warp
)
mma_emitter
=
SparseTensorCoreIntrinEmitter
(
...
...
@@ -55,12 +53,10 @@ class GemmSPMMA(GemmSPBase):
self
.
C
:
mma_emitter
.
make_mma_store_layout
(
self
.
C
),
}
else
:
raise
ValueError
(
f
"Unsupported gemm combination, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
raise
ValueError
(
f
"Unsupported gemm combination, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
def
lower
(
self
,
target
:
Target
,
thread_nums
:
int
,
thread_var
:
tir
.
Var
):
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
False
)
m_warp
,
n_warp
=
self
.
policy
.
compute_warp_partition
(
self
.
M
,
self
.
N
,
thread_nums
,
target
,
False
)
warp_row_tiles
=
int
(
self
.
M
//
m_warp
)
warp_col_tiles
=
int
(
self
.
N
//
n_warp
)
mma_emitter
=
SparseTensorCoreIntrinEmitter
(
...
...
@@ -146,7 +142,6 @@ class GemmSPMMA(GemmSPBase):
E_local
=
T
.
alloc_local
((
warp_rows
*
local_size_e
),
self
.
e_dtype
)
for
ki
in
T
.
serial
(
0
,
(
self
.
K
//
micro_size_k
)):
# Load A into fragment
mma_emitter
.
ldmatrix_a
(
A_local
,
...
...
@@ -231,8 +226,7 @@ class GemmSPMMA(GemmSPBase):
# Must inline let statements to simplify the analysis
return
_Simplify
(
_gemm_rrr
,
inline_let
=
True
)
else
:
raise
ValueError
(
f
"Unsupported gemm combination, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
raise
ValueError
(
f
"Unsupported gemm combination, A:
{
self
.
A
.
scope
()
}
, B:
{
self
.
B
.
scope
()
}
"
)
def
is_gemm_ss
(
self
)
->
bool
:
return
is_shared
(
self
.
A
)
and
is_shared
(
self
.
B
)
...
...
tilelang/tools/Analyzer.py
View file @
29051439
...
...
@@ -4,6 +4,7 @@ from dataclasses import dataclass
from
tilelang
import
tvm
from
tvm.tir.stmt_functor
import
ir_transform
import
logging
# Configuration for different hardware architectures.
# Each entry contains: (cores per SM, default clock (GHz), FLOPs per cycle, max SM count)
ARCH_CONFIGS
=
{
"80"
:
(
128
,
1.41
,
2
,
108
),
"86"
:
(
128
,
1.70
,
2
,
84
),
"89"
:
(
128
,
2.52
,
2
,
128
)}
...
...
@@ -23,6 +24,7 @@ class AnalysisResult:
tflops: Achieved TFLOPS (trillions of FLOPs per second).
bandwidth_GBps: Achieved memory bandwidth in GB/s.
"""
total_flops
:
int
total_global_bytes
:
int
estimated_time
:
float
...
...
@@ -81,7 +83,7 @@ class Analyzer:
# Account for loop and block dimensions
loop_product
=
1
for
extent
in
self
.
loop_stack
:
loop_product
*=
extent
.
value
if
hasattr
(
extent
,
'
value
'
)
else
extent
loop_product
*=
extent
.
value
if
hasattr
(
extent
,
"
value
"
)
else
extent
total_blocks
=
self
.
block_counts
[
"blockIdx.x"
]
*
self
.
block_counts
[
"blockIdx.y"
]
total_bytes
=
bytes_transferred
*
loop_product
*
total_blocks
self
.
total_global_bytes
+=
total_bytes
...
...
@@ -100,7 +102,7 @@ class Analyzer:
# Account for loop and block dimensions
loop_product
=
1
for
extent
in
self
.
loop_stack
:
loop_product
*=
extent
.
value
if
hasattr
(
extent
,
'
value
'
)
else
extent
loop_product
*=
extent
.
value
if
hasattr
(
extent
,
"
value
"
)
else
extent
total_blocks
=
self
.
block_counts
[
"blockIdx.x"
]
*
self
.
block_counts
[
"blockIdx.y"
]
self
.
total_flops
+=
flops_per_call
*
loop_product
*
total_blocks
...
...
@@ -127,8 +129,7 @@ class Analyzer:
iter_var
=
stmt
.
node
thread_tag
=
iter_var
.
thread_tag
if
thread_tag
in
self
.
block_counts
:
extent
=
stmt
.
value
.
value
if
hasattr
(
stmt
.
value
,
'value'
)
else
stmt
.
value
extent
=
stmt
.
value
.
value
if
hasattr
(
stmt
.
value
,
"value"
)
else
stmt
.
value
self
.
block_counts
[
thread_tag
]
=
extent
elif
isinstance
(
stmt
,
tvm
.
tir
.
For
):
# Push loop extent onto the stack
...
...
@@ -178,9 +179,7 @@ class Analyzer:
"""
arch_key
=
device
.
compute_capability
[:
2
]
if
arch_key
not
in
ARCH_CONFIGS
:
logger
.
info
(
f
"Unsupported compute capability:
{
device
.
compute_capability
}
, theoretical peak tflops will be None"
)
logger
.
info
(
f
"Unsupported compute capability:
{
device
.
compute_capability
}
, theoretical peak tflops will be None"
)
return
None
cores_per_sm
,
default_clock
,
flops_per_cycle
,
compute_max_core
=
ARCH_CONFIGS
[
arch_key
]
...
...
@@ -203,7 +202,8 @@ class Analyzer:
total_global_bytes
=
self
.
total_global_bytes
,
estimated_time
=
estimated_time
,
expected_tflops
=
peak_tflops
,
expected_bandwidth_GBps
=
bandwidth_GBps
)
expected_bandwidth_GBps
=
bandwidth_GBps
,
)
@
classmethod
def
analysis
(
cls
,
fn
,
device
):
...
...
tilelang/tools/plot_layout.py
View file @
29051439
...
...
@@ -2,12 +2,14 @@ from __future__ import annotations
import
tilelang.language
as
T
def
plot_layout
(
layout
:
T
.
Fragment
,
def
plot_layout
(
layout
:
T
.
Fragment
,
save_directory
=
"./tmp"
,
name
:
str
=
"layout"
,
colormap
:
str
=
"RdPu"
,
verbose
:
bool
=
False
,
formats
:
str
|
list
[
str
]
=
"png"
)
->
None
:
formats
:
str
|
list
[
str
]
=
"png"
,
)
->
None
:
"""
Plot the layout of a buffer.
...
...
@@ -90,11 +92,13 @@ def plot_layout(layout: T.Fragment,
# Warn if the number of threads is less than the warp size
if
num_threads
<
warp_size
:
import
warnings
warnings
.
warn
(
f
"Layout visualization has
{
num_threads
}
threads, which is less than the warp size (
{
warp_size
}
). "
f
"For the best viewing experience, it is recommended to have at least
{
warp_size
}
threads."
,
UserWarning
,
stacklevel
=
2
)
stacklevel
=
2
,
)
spectral_camp
=
plt
.
get_cmap
(
"hsv"
,
warp_size
*
6
)
for
i
in
range
(
min
(
warp_size
,
num_threads
)):
...
...
@@ -118,12 +122,7 @@ def plot_layout(layout: T.Fragment,
color
=
colors
[
thread_ids
[
0
]]
# Select color based on thread ID
# Create a rectangle patch for visualization
rect
=
patches
.
Rectangle
((
j
,
i
),
1
,
1
,
linewidth
=
0.5
,
edgecolor
=
'black'
,
facecolor
=
color
)
rect
=
patches
.
Rectangle
((
j
,
i
),
1
,
1
,
linewidth
=
0.5
,
edgecolor
=
"black"
,
facecolor
=
color
)
ax
.
add_patch
(
rect
)
# Add the rectangle to the plot
# Add text annotations inside the rectangles
...
...
@@ -139,41 +138,19 @@ def plot_layout(layout: T.Fragment,
thread_fontsize
=
min
(
font_size
,
font_size
*
(
4
/
len
(
thread_str
)))
# Add thread ID text with adjusted font size
ax
.
text
(
j
+
0.5
,
i
+
0.3
,
thread_str
,
ha
=
'center'
,
va
=
'center'
,
color
=
'black'
,
fontsize
=
thread_fontsize
)
ax
.
text
(
j
+
0.5
,
i
+
0.3
,
thread_str
,
ha
=
"center"
,
va
=
"center"
,
color
=
"black"
,
fontsize
=
thread_fontsize
)
# Add local ID text with original font size
ax
.
text
(
j
+
0.5
,
i
+
0.7
,
f
"L
{
local_id
}
"
,
ha
=
'center'
,
va
=
'center'
,
color
=
'black'
,
fontsize
=
font_size
)
ax
.
text
(
j
+
0.5
,
i
+
0.7
,
f
"L
{
local_id
}
"
,
ha
=
"center"
,
va
=
"center"
,
color
=
"black"
,
fontsize
=
font_size
)
# Add row labels to the left side of the plot
for
i
in
range
(
nrows
):
text
=
f
"row
{
i
}
"
ax
.
text
(
-
0.75
,
i
+
0.5
,
text
,
ha
=
'
center
'
,
va
=
'
center
'
,
color
=
'
black
'
,
fontsize
=
font_size
)
ax
.
text
(
-
0.75
,
i
+
0.5
,
text
,
ha
=
"
center
"
,
va
=
"
center
"
,
color
=
"
black
"
,
fontsize
=
font_size
)
# Add column labels at the top of the plot
for
j
in
range
(
ncols
):
text
=
f
"col
{
j
}
"
ax
.
text
(
j
+
0.5
,
-
0.5
,
text
,
ha
=
'center'
,
va
=
'center'
,
color
=
'black'
,
fontsize
=
font_size
,
rotation
=
45
)
ax
.
text
(
j
+
0.5
,
-
0.5
,
text
,
ha
=
"center"
,
va
=
"center"
,
color
=
"black"
,
fontsize
=
font_size
,
rotation
=
45
)
# Set the plot limits
ax
.
set_xlim
(
0
,
ncols
)
...
...
@@ -189,17 +166,15 @@ def plot_layout(layout: T.Fragment,
legend_x
=
1.0
+
(
0.5
/
fig_width
)
# Adjust x position based on figure width
legend_y
=
1.0
+
(
1.7
/
fig_height
)
# Adjust y position based on figure height
legend_patches
=
[
patches
.
Patch
(
color
=
'black'
,
label
=
"T: Thread ID"
),
patches
.
Patch
(
color
=
'black'
,
label
=
"L: Local ID"
)
]
legend_patches
=
[
patches
.
Patch
(
color
=
"black"
,
label
=
"T: Thread ID"
),
patches
.
Patch
(
color
=
"black"
,
label
=
"L: Local ID"
)]
ax
.
legend
(
handles
=
legend_patches
,
loc
=
"upper right"
,
fontsize
=
font_size
-
4
,
frameon
=
False
,
bbox_to_anchor
=
(
legend_x
,
legend_y
),
# Dynamic position
ncols
=
2
)
ncols
=
2
,
)
# Create the output directory if it does not exist
tmp_directory
=
pathlib
.
Path
(
save_directory
)
...
...
@@ -211,28 +186,29 @@ def plot_layout(layout: T.Fragment,
if
isinstance
(
formats
,
str
):
formats_str
=
formats
.
strip
().
lower
()
if
formats_str
==
'
all
'
:
formats_list
=
[
'
pdf
'
,
'
png
'
,
'
svg
'
]
if
formats_str
==
"
all
"
:
formats_list
=
[
"
pdf
"
,
"
png
"
,
"
svg
"
]
elif
","
in
formats_str
:
formats_list
=
[
f
.
strip
()
for
f
in
formats_str
.
split
(
','
)]
formats_list
=
[
f
.
strip
()
for
f
in
formats_str
.
split
(
","
)]
else
:
formats_list
=
[
formats_str
]
else
:
raise
TypeError
(
f
"Expected str, but got
{
type
(
formats
).
__name__
}
. "
f
"Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'."
)
raise
TypeError
(
f
"Expected str, but got
{
type
(
formats
).
__name__
}
. Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'."
)
# Save the figure
if
'
pdf
'
in
formats_list
:
if
"
pdf
"
in
formats_list
:
pdf_path
=
tmp_directory
/
f
"
{
name
}
.pdf"
plt
.
savefig
(
pdf_path
,
bbox_inches
=
"tight"
)
print
(
f
"Saved pdf format into
{
pdf_path
}
"
)
if
'
png
'
in
formats_list
:
if
"
png
"
in
formats_list
:
png_path
=
tmp_directory
/
f
"
{
name
}
.png"
plt
.
savefig
(
png_path
,
bbox_inches
=
"tight"
,
transparent
=
False
,
dpi
=
255
)
print
(
f
"Saved png format into
{
png_path
}
"
)
if
'
svg
'
in
formats_list
:
if
"
svg
"
in
formats_list
:
svg_path
=
tmp_directory
/
f
"
{
name
}
.svg"
plt
.
savefig
(
svg_path
,
bbox_inches
=
"tight"
,
format
=
"svg"
)
print
(
f
"Saved svg format into
{
svg_path
}
"
)
tilelang/transform/__init__.py
View file @
29051439
...
...
@@ -110,8 +110,7 @@ def LowerHopperIntrin():
fpass : tvm.transform.Pass
The result pass
"""
return
(
_ffi_api
.
LowerHopperIntrin
()
if
hasattr
(
_ffi_api
,
"LowerHopperIntrin"
)
else
lambda
f
:
f
)
# type: ignore
return
_ffi_api
.
LowerHopperIntrin
()
if
hasattr
(
_ffi_api
,
"LowerHopperIntrin"
)
else
lambda
f
:
f
# type: ignore
def
WarpSpecializedPipeline
():
...
...
@@ -365,8 +364,7 @@ def FlattenBuffer():
def
EliminateStorageSyncForMBarrier
():
"""EliminateStorageSyncForMBarrier
"""
"""EliminateStorageSyncForMBarrier"""
return
_ffi_api
.
EliminateStorageSyncForMBarrier
()
# type: ignore
...
...
@@ -378,19 +376,16 @@ def MergeSharedMemoryAllocations(enable_aggressive_merge: bool = False, align_by
fpass : tvm.transform.Pass
The result pass
"""
return
_ffi_api
.
MergeSharedMemoryAllocations
(
enable_aggressive_merge
,
align_bytes
)
# type: ignore
return
_ffi_api
.
MergeSharedMemoryAllocations
(
enable_aggressive_merge
,
align_bytes
)
# type: ignore
def
LowerL2Persistent
():
"""LowerL2Persistent
"""
"""LowerL2Persistent"""
return
_ffi_api
.
LowerL2Persistent
()
# type: ignore
def
PersistThreadblock
():
"""PersistThreadblock
"""
"""PersistThreadblock"""
return
_ffi_api
.
PersistThreadblock
()
# type: ignore
...
...
@@ -409,8 +404,7 @@ def AlignDynamicSharedMemoryAllocations(align_bytes: int = 16):
def
LowerSharedBarrier
():
"""LowerSharedBarrier
"""
"""LowerSharedBarrier"""
return
_ffi_api
.
LowerSharedBarrier
()
# type: ignore
...
...
@@ -437,20 +431,17 @@ def StorageRewrite():
def
LowerOpaqueBlock
():
"""LowerOpaqueBlock
"""
"""LowerOpaqueBlock"""
return
_ffi_api
.
LowerOpaqueBlock
()
# type: ignore
def
LowerThreadAllreduce
():
"""LowerThreadAllreduce
"""
"""LowerThreadAllreduce"""
return
_ffi_api
.
LowerThreadAllreduce
()
# type: ignore
def
LowerIntrin
():
"""LowerIntrin
"""
"""LowerIntrin"""
return
_ffi_api
.
LowerIntrin
()
# type: ignore
...
...
@@ -468,8 +459,7 @@ def LowerDeviceKernelLaunch():
def
LowerSharedTmem
():
"""LowerSharedTmem
"""
"""LowerSharedTmem"""
return
_ffi_api
.
LowerSharedTmem
()
# type: ignore
...
...
tilelang/transform/add_bufstore_wrapper.py
View file @
29051439
from
tvm.tir
import
(
BufferStore
,
For
,
AttrStmt
,
ForKind
,
Var
,
PrimFunc
,
BufferLoad
,
Buffer
,
IntImm
)
from
tvm.tir
import
BufferStore
,
For
,
AttrStmt
,
ForKind
,
Var
,
PrimFunc
,
BufferLoad
,
Buffer
,
IntImm
from
tvm.tir.stmt_functor
import
ir_transform
,
post_order_visit
from
tvm.tir.transform
import
prim_func_pass
...
...
@@ -97,7 +97,7 @@ def AddWrapperForSingleBufStore():
Returns:
True if the loop is a tile operation (parallel or has num_stages annotation)
"""
return
loop
.
kind
==
ForKind
.
PARALLEL
or
'
num_stages
'
in
loop
.
annotations
return
loop
.
kind
==
ForKind
.
PARALLEL
or
"
num_stages
"
in
loop
.
annotations
def
pre_visit
(
statement
):
"""
...
...
@@ -105,7 +105,7 @@ def AddWrapperForSingleBufStore():
"""
nonlocal
tile_operation_depth
if
isinstance
(
statement
,
AttrStmt
)
and
statement
.
attr_key
==
'
thread_extent
'
:
if
isinstance
(
statement
,
AttrStmt
)
and
statement
.
attr_key
==
"
thread_extent
"
:
thread_binding_vars
.
add
(
statement
.
node
.
var
)
elif
isinstance
(
statement
,
For
)
and
is_tile_operation_loop
(
statement
):
tile_operation_depth
+=
1
...
...
@@ -139,7 +139,8 @@ def AddWrapperForSingleBufStore():
if
isinstance
(
index
,
IntImm
)
and
index
!=
0
:
raise
ValueError
(
f
"Fragment buffer access with non-zero index [
{
index
}
] is not supported. "
"Only fragment[0] access is allowed."
)
"Only fragment[0] access is allowed."
)
# Wrap fragment[0] access with T.Parallel loop
return
For
(
Var
(
"_"
,
"int32"
),
0
,
1
,
ForKind
.
PARALLEL
,
statement
)
...
...
tilelang/transform/pass_config.py
View file @
29051439
...
...
@@ -5,6 +5,7 @@ from enum import Enum
class
PassConfigKey
(
str
,
Enum
):
"""Pass configuration keys for TileLang compiler."""
# TileLang specific configs
TL_SIMPLIFY
=
"tl.Simplify"
"""Enable/disable TileLang simplification passes. Default: True"""
...
...
Prev
1
…
19
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment