Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
8bf752ae
Commit
8bf752ae
authored
Jan 12, 2025
by
LeiWang1999
Browse files
test fix
parent
549416f7
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
205 additions
and
128 deletions
+205
-128
testing/python/amd/test_tilelang_test_amd.py
testing/python/amd/test_tilelang_test_amd.py
+1
-0
testing/python/kernel/test_tilelang_gemm.py
testing/python/kernel/test_tilelang_gemm.py
+16
-1
testing/python/primitives/test_tilelang_primitives_mma.py
testing/python/primitives/test_tilelang_primitives_mma.py
+47
-33
tilelang/intrinsics/mfma_macro_generator.py
tilelang/intrinsics/mfma_macro_generator.py
+2
-1
tilelang/intrinsics/mma_layout.py
tilelang/intrinsics/mma_layout.py
+32
-0
tilelang/intrinsics/mma_macro_generator.py
tilelang/intrinsics/mma_macro_generator.py
+78
-82
tilelang/language/__init__.py
tilelang/language/__init__.py
+1
-1
tilelang/layout/fragment.py
tilelang/layout/fragment.py
+12
-2
tilelang/primitives/gemm/__init__.py
tilelang/primitives/gemm/__init__.py
+1
-1
tilelang/primitives/gemm/gemm_mma.py
tilelang/primitives/gemm/gemm_mma.py
+6
-7
tilelang/utils/__init__.py
tilelang/utils/__init__.py
+8
-0
tilelang/utils/language.py
tilelang/utils/language.py
+1
-0
No files found.
testing/python/amd/test_tilelang_test_amd.py
View file @
8bf752ae
...
@@ -101,6 +101,7 @@ def run_gemm(
...
@@ -101,6 +101,7 @@ def run_gemm(
mod
.
assert_allclose
(
ref_program
,
atol
=
1e-2
,
rtol
=
1e-2
)
mod
.
assert_allclose
(
ref_program
,
atol
=
1e-2
,
rtol
=
1e-2
)
@
tilelang
.
testing
.
requires_rocm
def
test_gemm_f16f32f32_nt
():
def
test_gemm_f16f32f32_nt
():
run_gemm
(
1024
,
1024
,
1024
,
False
,
True
,
"float16"
,
"float32"
,
"float32"
,
128
,
128
,
32
)
run_gemm
(
1024
,
1024
,
1024
,
False
,
True
,
"float16"
,
"float32"
,
"float32"
,
128
,
128
,
32
)
run_gemm
(
1024
,
1024
,
1024
,
False
,
True
,
"float16"
,
"float32"
,
"float32"
,
128
,
128
,
32
,
k_pack
=
2
)
run_gemm
(
1024
,
1024
,
1024
,
False
,
True
,
"float16"
,
"float32"
,
"float32"
,
128
,
128
,
32
,
k_pack
=
2
)
...
...
testing/python/kernel/test_tilelang_gemm.py
View file @
8bf752ae
...
@@ -84,6 +84,7 @@ def run_gemm(
...
@@ -84,6 +84,7 @@ def run_gemm(
num_stages
,
num_stages
,
num_threads
,
num_threads
,
)
)
mod
,
params
=
tl
.
lower
(
program
)
mod
,
params
=
tl
.
lower
(
program
)
mod
=
tl
.
Profiler
(
mod
,
params
,
[
2
],
tl
.
TensorSupplyType
.
Integer
)
mod
=
tl
.
Profiler
(
mod
,
params
,
[
2
],
tl
.
TensorSupplyType
.
Integer
)
...
@@ -299,4 +300,18 @@ def test_pad_f16f16f32_nn():
...
@@ -299,4 +300,18 @@ def test_pad_f16f16f32_nn():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
tilelang
.
testing
.
main
()
# tilelang.testing.main()
run_gemm
(
512
,
1024
,
768
,
False
,
True
,
"float16"
,
"float16"
,
"float16"
,
128
,
256
,
32
,
2
,
)
testing/python/primitives/test_tilelang_primitives_mma.py
View file @
8bf752ae
...
@@ -26,7 +26,7 @@ def matmul_ssr(
...
@@ -26,7 +26,7 @@ def matmul_ssr(
B_shape
=
(
N
,
K
)
if
trans_B
else
(
K
,
N
)
B_shape
=
(
N
,
K
)
if
trans_B
else
(
K
,
N
)
A_shared_shape
=
(
block_K
,
block_M
)
if
trans_A
else
(
block_M
,
block_K
)
A_shared_shape
=
(
block_K
,
block_M
)
if
trans_A
else
(
block_M
,
block_K
)
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
shared_scope
=
"shared"
# or "shared.dyn" for dynamic shared memory
import
tilelang.language
as
T
import
tilelang.language
as
T
@
T
.
prim_func
@
T
.
prim_func
...
@@ -36,8 +36,8 @@ def matmul_ssr(
...
@@ -36,8 +36,8 @@ def matmul_ssr(
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
...
@@ -85,9 +85,9 @@ def run_matmul_ssr(
...
@@ -85,9 +85,9 @@ def run_matmul_ssr(
num_stages
,
num_stages
,
num_threads
,
num_threads
,
)
)
print
(
program
)
mod
,
params
=
tl
.
lower
(
program
)
mod
,
params
=
tl
.
lower
(
program
)
mod
=
tl
.
Profiler
(
mod
,
params
,
[
2
],
tl
.
TensorSupplyType
.
Integer
)
mod
=
tl
.
Profiler
(
mod
,
params
,
[
2
],
tl
.
TensorSupplyType
.
Integer
)
print
(
mod
.
get_kernel_source
())
def
ref_program
(
A
,
B
):
def
ref_program
(
A
,
B
):
import
torch
import
torch
...
@@ -140,6 +140,7 @@ def matmul_rsr(
...
@@ -140,6 +140,7 @@ def matmul_rsr(
A_shared_shape
=
(
block_K
,
block_M
)
if
trans_A
else
(
block_M
,
block_K
)
A_shared_shape
=
(
block_K
,
block_M
)
if
trans_A
else
(
block_M
,
block_K
)
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
B_shared_shape
=
(
block_N
,
block_K
)
if
trans_B
else
(
block_K
,
block_N
)
A_local_shape
=
A_shared_shape
A_local_shape
=
A_shared_shape
shared_scope
=
"shared"
# or "shared.dyn" for dynamic shared memory
import
tilelang.language
as
T
import
tilelang.language
as
T
@
T
.
prim_func
@
T
.
prim_func
...
@@ -149,23 +150,23 @@ def matmul_rsr(
...
@@ -149,23 +150,23 @@ def matmul_rsr(
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
,
scope
=
shared_scope
)
A_local
=
T
.
alloc_fragment
(
A_local_shape
,
in_dtype
)
A_local
=
T
.
alloc_fragment
(
A_local_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
clear
(
C_local
)
T
.
clear
(
C_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
num_stages
):
if
trans_A
:
if
trans_A
:
T
.
copy
(
A
[
k
*
block_K
,
by
*
block_M
],
A_shared
)
T
.
copy
(
A
[
k
*
block_K
,
by
*
block_M
],
A_shared
)
T
.
copy
(
A_shared
,
A_local
)
else
:
else
:
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_shared
)
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_shared
)
T
.
copy
(
A_shared
,
A_local
)
if
trans_B
:
if
trans_B
:
T
.
copy
(
B
[
bx
*
block_N
,
k
*
block_K
],
B_shared
)
T
.
copy
(
B
[
bx
*
block_N
,
k
*
block_K
],
B_shared
)
else
:
else
:
T
.
copy
(
B
[
k
*
block_K
,
bx
*
block_N
],
B_shared
)
T
.
copy
(
B
[
k
*
block_K
,
bx
*
block_N
],
B_shared
)
T
.
copy
(
A_shared
,
A_local
)
P
.
gemm
(
A_local
,
B_shared
,
C_local
,
trans_A
,
trans_B
)
P
.
gemm
(
A_local
,
B_shared
,
C_local
,
trans_A
,
trans_B
)
# T.gemm(A_local, B_shared, C_local, trans_A, trans_B)
T
.
copy
(
C_local
,
C
[
by
*
block_M
,
bx
*
block_N
])
T
.
copy
(
C_local
,
C
[
by
*
block_M
,
bx
*
block_N
])
return
main
return
main
...
@@ -203,6 +204,7 @@ def run_matmul_rsr(
...
@@ -203,6 +204,7 @@ def run_matmul_rsr(
)
)
mod
,
params
=
tl
.
lower
(
program
)
mod
,
params
=
tl
.
lower
(
program
)
mod
=
tl
.
Profiler
(
mod
,
params
,
[
2
],
tl
.
TensorSupplyType
.
Integer
)
mod
=
tl
.
Profiler
(
mod
,
params
,
[
2
],
tl
.
TensorSupplyType
.
Integer
)
print
(
mod
.
get_kernel_source
())
def
ref_program
(
A
,
B
):
def
ref_program
(
A
,
B
):
import
torch
import
torch
...
@@ -218,22 +220,24 @@ def run_matmul_rsr(
...
@@ -218,22 +220,24 @@ def run_matmul_rsr(
mod
.
assert_allclose
(
ref_program
,
atol
=
1e-2
,
rtol
=
1e-2
)
mod
.
assert_allclose
(
ref_program
,
atol
=
1e-2
,
rtol
=
1e-2
)
def
test_gemm_f16f16f16_nt_rsr
():
# TODO(lei): Fix the test case in future release
run_matmul_rsr
(
# Now it has some bugs related to is_m_first
1024
,
# def test_gemm_f16f16f16_nt_rsr():
1024
,
# run_matmul_rsr(
1024
,
# 1024,
False
,
# 1024,
True
,
# 1024,
"float16"
,
# False,
"float16"
,
# True,
"float16"
,
# "float16",
16
,
# "float16",
16
,
# "float16",
16
,
# 128,
0
,
# 128,
num_threads
=
32
,
# 32,
)
# 0,
# num_threads=128,
# )
def
matmul_rrr
(
def
matmul_rrr
(
...
@@ -338,8 +342,25 @@ def run_matmul_rrr(
...
@@ -338,8 +342,25 @@ def run_matmul_rrr(
mod
.
assert_allclose
(
ref_program
,
atol
=
1e-2
,
rtol
=
1e-2
)
mod
.
assert_allclose
(
ref_program
,
atol
=
1e-2
,
rtol
=
1e-2
)
def
test_gemm_f16f16f16_nt_rrr
():
# def test_gemm_f16f16f16_nt_rrr():
run_matmul_rrr
(
# run_matmul_rrr(
# 1024,
# 1024,
# 1024,
# False,
# True,
# "float16",
# "float16",
# "float16",
# 128,
# 128,
# 32,
# 2,
# )
if
__name__
==
"__main__"
:
# tilelang.testing.main()
run_matmul_ssr
(
1024
,
1024
,
1024
,
1024
,
1024
,
1024
,
...
@@ -353,10 +374,3 @@ def test_gemm_f16f16f16_nt_rrr():
...
@@ -353,10 +374,3 @@ def test_gemm_f16f16f16_nt_rrr():
32
,
32
,
2
,
2
,
)
)
if
__name__
==
"__main__"
:
# tilelang.testing.main()
# test_gemm_f16f16f16_nt_ssr()
test_gemm_f16f16f16_nt_rsr
()
# test_gemm_f16f16f16_nt_rrr()
tilelang/intrinsics/mfma_macro_generator.py
View file @
8bf752ae
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Licensed under the MIT License.
import
tvm.tl.language
as
T
from
tilelang
import
tvm
as
tvm
import
tilelang.language
as
T
from
typing
import
Tuple
from
typing
import
Tuple
from
tvm
import
DataType
from
tvm
import
DataType
from
tvm.tir
import
PrimExpr
from
tvm.tir
import
PrimExpr
...
...
tilelang/intrinsics/mma_layout.py
View file @
8bf752ae
...
@@ -48,6 +48,38 @@ def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id):
...
@@ -48,6 +48,38 @@ def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id):
return
row
,
col
return
row
,
col
# sr represents spatial + reduction layout
# the first axis is spatial while the second axis is reduction
def
shared_16x16_to_mma_32x8_layout_sr
(
i
,
j
):
thread_id
=
4
*
(
i
%
8
)
+
(
j
%
8
)
//
2
return
thread_id
,
4
*
(
j
//
8
)
+
(
i
//
8
)
*
2
+
(
j
%
2
)
def
shared_16x16_to_mma_32x8_layout_rs
(
i
,
j
):
thread_id
=
4
*
(
j
%
8
)
+
(
i
%
8
)
//
2
return
thread_id
,
4
*
(
i
//
8
)
+
(
j
//
8
)
*
2
+
(
i
%
2
)
shared_16x16_to_mma_32x8_layout
=
shared_16x16_to_mma_32x8_layout_sr
shared_16x16_to_mma_32x8_layout_trans
=
shared_16x16_to_mma_32x8_layout_rs
def
shared_16x32_to_mma_32x16_layout
(
i
,
j
):
thread_id
=
4
*
(
i
%
8
)
+
(
j
%
16
)
//
4
return
thread_id
,
8
*
(
j
//
16
)
+
(
i
//
8
)
*
4
+
j
%
4
def
shared_32x16_to_mma_32x16_layout
(
i
,
j
):
thread_id
=
(
i
%
16
)
//
4
+
4
*
(
j
%
8
)
return
thread_id
,
8
*
(
j
//
8
)
+
(
i
//
16
)
*
4
+
i
%
4
def
mma_32x8_to_shared_16x16_layout
(
thread_id
,
local_id
):
row
=
8
*
(
local_id
%
4
//
2
)
+
(
thread_id
//
4
)
col
=
8
*
(
local_id
//
4
)
+
(
thread_id
%
4
)
*
2
+
(
local_id
%
2
)
return
row
,
col
def
shared_16x16_to_mma_32x8_smoothlayout
(
i
,
j
):
def
shared_16x16_to_mma_32x8_smoothlayout
(
i
,
j
):
return
(
i
*
2
+
j
//
8
,
j
%
8
)
return
(
i
*
2
+
j
//
8
,
j
%
8
)
...
...
tilelang/intrinsics/mma_macro_generator.py
View file @
8bf752ae
...
@@ -11,6 +11,7 @@ from .utils import (
...
@@ -11,6 +11,7 @@ from .utils import (
mma_store_index_map
,
mma_store_index_map
,
get_ldmatrix_offset
,
get_ldmatrix_offset
,
)
)
from
tilelang.utils
import
is_fragment
lift
=
convert
lift
=
convert
...
@@ -97,7 +98,7 @@ class TensorCoreIntrinEmitter(object):
...
@@ -97,7 +98,7 @@ class TensorCoreIntrinEmitter(object):
self
.
b_dtype_abbrv
=
self
.
dtype_abbrv
[
b_dtype
]
self
.
b_dtype_abbrv
=
self
.
dtype_abbrv
[
b_dtype
]
self
.
accum_dtype_abbrv
=
self
.
dtype_abbrv
[
accum_dtype
]
self
.
accum_dtype_abbrv
=
self
.
dtype_abbrv
[
accum_dtype
]
def
_initialize_mma_prefix
(
self
,
k_dim
=
16
):
def
_initialize_mma_prefix
(
self
,
k_dim
:
int
=
16
):
if
k_dim
==
16
:
if
k_dim
==
16
:
self
.
mma_prefix
=
"m16n8k16"
self
.
mma_prefix
=
"m16n8k16"
elif
k_dim
==
32
:
elif
k_dim
==
32
:
...
@@ -105,7 +106,7 @@ class TensorCoreIntrinEmitter(object):
...
@@ -105,7 +106,7 @@ class TensorCoreIntrinEmitter(object):
else
:
else
:
raise
ValueError
(
"Unsupported k_dim"
)
raise
ValueError
(
"Unsupported k_dim"
)
def
_initialize_micro_size
(
self
,
m_dim
=
16
,
n_dim
=
16
,
k_dim
=
16
):
def
_initialize_micro_size
(
self
,
m_dim
:
int
=
16
,
n_dim
:
int
=
16
,
k_dim
:
int
=
16
):
self
.
micro_size_x
=
m_dim
self
.
micro_size_x
=
m_dim
self
.
micro_size_y
=
n_dim
self
.
micro_size_y
=
n_dim
self
.
micro_size_k
=
k_dim
self
.
micro_size_k
=
k_dim
...
@@ -122,9 +123,10 @@ class TensorCoreIntrinEmitter(object):
...
@@ -122,9 +123,10 @@ class TensorCoreIntrinEmitter(object):
inverse_index_map
=
index_map
.
inverse
([
warp_size
,
local_size_c
])
inverse_index_map
=
index_map
.
inverse
([
warp_size
,
local_size_c
])
return
inverse_index_map
return
inverse_index_map
def
extract_thread_binding
(
self
,
def
extract_thread_binding
(
thread_id
,
self
,
is_m_first
=
None
)
->
Tuple
[
PrimExpr
,
PrimExpr
,
PrimExpr
]:
thread_id
:
PrimExpr
,
is_m_first
:
Optional
[
bool
]
=
None
)
->
Tuple
[
PrimExpr
,
PrimExpr
,
PrimExpr
]:
"""
"""
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
...
@@ -153,7 +155,12 @@ class TensorCoreIntrinEmitter(object):
...
@@ -153,7 +155,12 @@ class TensorCoreIntrinEmitter(object):
)
)
return
lane_id
,
warp_n
,
warp_m
return
lane_id
,
warp_n
,
warp_m
def
ldmatrix_a
(
self
,
A_local_buf
,
A_shared_buf
,
ki
,
thread_bindings
,
rk
=
0
):
def
ldmatrix_a
(
self
,
A_local_buf
:
Buffer
,
A_shared_buf
:
Buffer
,
ki
:
PrimExpr
,
thread_bindings
:
PrimExpr
,
rk
:
Optional
[
PrimExpr
]
=
0
):
warp_row_tiles
=
self
.
warp_row_tiles
warp_row_tiles
=
self
.
warp_row_tiles
warp_rows
=
self
.
warp_rows
warp_rows
=
self
.
warp_rows
chunk
=
self
.
chunk
chunk
=
self
.
chunk
...
@@ -190,7 +197,12 @@ class TensorCoreIntrinEmitter(object):
...
@@ -190,7 +197,12 @@ class TensorCoreIntrinEmitter(object):
return
_warp_ldmatrix_a
(
A_local_buf
,
A_shared_buf
,
ki
,
thread_bindings
,
rk
)
return
_warp_ldmatrix_a
(
A_local_buf
,
A_shared_buf
,
ki
,
thread_bindings
,
rk
)
def
ldmatrix_b
(
self
,
B_local_buf
,
B_shared_buf
,
ki
,
thread_bindings
,
rk
=
0
):
def
ldmatrix_b
(
self
,
B_local_buf
:
Buffer
,
B_shared_buf
:
Buffer
,
ki
:
PrimExpr
,
thread_bindings
:
PrimExpr
,
rk
:
Optional
[
PrimExpr
]
=
0
):
warp_col_tiles
=
self
.
warp_col_tiles
warp_col_tiles
=
self
.
warp_col_tiles
warp_cols
=
self
.
warp_cols
warp_cols
=
self
.
warp_cols
chunk
=
self
.
chunk
chunk
=
self
.
chunk
...
@@ -232,7 +244,11 @@ class TensorCoreIntrinEmitter(object):
...
@@ -232,7 +244,11 @@ class TensorCoreIntrinEmitter(object):
return
_warp_ldmatrix_b
(
B_local_buf
,
B_shared_buf
,
ki
,
thread_bindings
,
rk
)
return
_warp_ldmatrix_b
(
B_local_buf
,
B_shared_buf
,
ki
,
thread_bindings
,
rk
)
def
mma
(
self
,
A_local_buf
,
B_local_buf
,
C_local_buf
,
k_inner
=
0
):
def
mma
(
self
,
A_local_buf
:
Buffer
,
B_local_buf
:
Buffer
,
C_local_buf
:
Buffer
,
k_inner
:
Optional
[
PrimExpr
]
=
0
):
warp_rows
=
self
.
warp_rows
warp_rows
=
self
.
warp_rows
warp_cols
=
self
.
warp_cols
warp_cols
=
self
.
warp_cols
local_size_a
=
self
.
local_size_a
local_size_a
=
self
.
local_size_a
...
@@ -244,6 +260,11 @@ class TensorCoreIntrinEmitter(object):
...
@@ -244,6 +260,11 @@ class TensorCoreIntrinEmitter(object):
accum_dtype_abbrv
=
self
.
accum_dtype_abbrv
accum_dtype_abbrv
=
self
.
accum_dtype_abbrv
mma_prefix
=
self
.
mma_prefix
mma_prefix
=
self
.
mma_prefix
a_is_fragment
=
is_fragment
(
A_local_buf
)
b_is_fragment
=
is_fragment
(
B_local_buf
)
a_local_stride
:
PrimExpr
=
k_inner
*
warp_rows
*
local_size_a
if
a_is_fragment
else
0
b_local_stride
:
PrimExpr
=
k_inner
*
warp_cols
*
local_size_b
if
b_is_fragment
else
0
@
T
.
macro
@
T
.
macro
def
_warp_mma
(
A_local_buf
,
B_local_buf
,
C_local_buf
):
def
_warp_mma
(
A_local_buf
,
B_local_buf
,
C_local_buf
):
for
i
,
j
in
T
.
grid
(
warp_rows
,
warp_cols
):
for
i
,
j
in
T
.
grid
(
warp_rows
,
warp_cols
):
...
@@ -256,9 +277,9 @@ class TensorCoreIntrinEmitter(object):
...
@@ -256,9 +277,9 @@ class TensorCoreIntrinEmitter(object):
b_dtype_abbrv
,
b_dtype_abbrv
,
accum_dtype_abbrv
,
accum_dtype_abbrv
,
A_local_buf
.
data
,
A_local_buf
.
data
,
k_inner
*
warp_rows
*
local_s
ize_a
+
i
*
local_size_a
,
a_
local_s
tride
+
i
*
local_size_a
,
B_local_buf
.
data
,
B_local_buf
.
data
,
k_inner
*
warp_cols
*
local_s
ize_b
+
j
*
local_size_b
,
b_
local_s
tride
+
j
*
local_size_b
,
C_local_buf
.
data
,
C_local_buf
.
data
,
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
,
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
,
T
.
bool
(
False
),
T
.
bool
(
False
),
...
@@ -273,9 +294,9 @@ class TensorCoreIntrinEmitter(object):
...
@@ -273,9 +294,9 @@ class TensorCoreIntrinEmitter(object):
b_dtype_abbrv
,
b_dtype_abbrv
,
accum_dtype_abbrv
,
accum_dtype_abbrv
,
A_local_buf
.
data
,
A_local_buf
.
data
,
k_inner
*
warp_rows
*
local_s
ize_a
+
i
*
local_size_a
,
a_
local_s
tride
+
i
*
local_size_a
,
B_local_buf
.
data
,
B_local_buf
.
data
,
k_inner
*
warp_cols
*
local_s
ize_b
+
j
*
local_size_b
+
lift
(
local_size_b
)
//
2
,
b_
local_s
tride
+
j
*
local_size_b
+
lift
(
local_size_b
)
//
2
,
C_local_buf
.
data
,
C_local_buf
.
data
,
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
lift
(
local_size_out
)
//
2
,
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
lift
(
local_size_out
)
//
2
,
T
.
bool
(
False
),
T
.
bool
(
False
),
...
@@ -352,105 +373,85 @@ class TensorCoreIntrinEmitter(object):
...
@@ -352,105 +373,85 @@ class TensorCoreIntrinEmitter(object):
AssertionError
AssertionError
If `local_buf` is not detected to be a fragment buffer.
If `local_buf` is not detected to be a fragment buffer.
"""
"""
from
tilelang.
primitives.
utils
import
is_fragment
from
tilelang.utils
import
is_fragment
from
tilelang.intrinsics.mma_layout
import
(
from
tilelang.intrinsics.mma_layout
import
(
ldmatrix_32x8_to_
shared_16x16_layout
,
shared_16x16_
to_mma_32x8_
layout
_sr
,
ldmatrix_trans_32x8_to_
shared_16x16_layout
,
shared_16x16_
to_mma_32x8_
layout
_rs
,
ldmatrix
_16x32_to_
shared_16x32
_layout
_a
,
shared
_16x32_to_
mma_32x16
_layout
,
ldmatrix_16x32_to_shared_16x32
_layout
_b
,
shared_32x16_to_mma_32x16
_layout
,
)
)
assert
matrix
in
[
"A"
,
"B"
],
"matrix should be either A or B"
assert
matrix
in
[
"A"
,
"B"
],
"matrix should be either A or B"
dtype
=
self
.
a_dtype
if
matrix
==
"A"
else
self
.
b_dtype
dtype
=
self
.
a_dtype
if
matrix
==
"A"
else
self
.
b_dtype
dtype_bits
=
DataType
(
dtype
).
bits
dtype_bits
=
DataType
(
dtype
).
bits
transposed
=
self
.
a_transposed
transposed
=
self
.
a_transposed
transform_func
:
Callable
=
None
assert
transposed
is
False
,
"transposed is not supported yet"
transform_func_trans
:
Callable
=
None
# s represents spatial axis
# r represents reduction axis
# sr represents the two dims are spatial + reduction
# rs represents the two dims are reduction + spatial
transform_func_sr
:
Callable
=
None
transform_func_rs
:
Callable
=
None
if
dtype_bits
==
16
:
if
dtype_bits
==
16
:
transform_func
=
ldmatrix_32x8_to_
shared_16x16_layout
transform_func
_sr
=
shared_16x16_
to_mma_32x8_
layout
_sr
transform_func_
trans
=
ldmatrix_trans_32x8_to_
shared_16x16_layout
transform_func_
rs
=
shared_16x16_
to_mma_32x8_
layout
_rs
elif
dtype_bits
==
8
:
elif
dtype_bits
==
8
:
if
matrix
==
"B"
and
transposed
:
transform_func_sr
=
shared_16x32_to_mma_32x16_layout
transform_func
=
ldmatrix_16x32_to_shared_16x32_layout_b
transform_func_rs
=
shared_32x16_to_mma_32x16_layout
elif
matrix
==
"A"
and
not
transposed
:
transform_func
=
ldmatrix_16x32_to_shared_16x32_layout_a
else
:
raise
ValueError
(
"ldmatrix only supports B transposed and A non-transposed for int8"
)
else
:
else
:
raise
ValueError
(
f
"Unsupported dtype
{
dtype
}
"
)
raise
ValueError
(
f
"Unsupported dtype
{
dtype
}
"
)
is_sr_conditions
=
[
False
]
is_sr_conditions
.
append
(
matrix
==
"A"
and
not
transposed
)
is_sr_conditions
.
append
(
matrix
==
"B"
and
transposed
)
is_sr_axis_order
=
any
(
is_sr_conditions
)
transform_func
:
Callable
=
transform_func_sr
if
is_sr_axis_order
else
transform_func_rs
shape
=
local_buf
.
shape
assert
is_fragment
(
local_buf
),
"local_buf must be a fragment, but got {}"
.
format
(
assert
is_fragment
(
local_buf
),
"local_buf must be a fragment, but got {}"
.
format
(
local_buf
.
scope
())
local_buf
.
scope
())
if
matrix
==
"A"
:
if
matrix
==
"A"
:
micro_size_
x
,
micro_size_
y
=
self
.
micro_size_x
,
self
.
micro_size_k
micro_size_
s
,
micro_size_
r
=
self
.
micro_size_x
,
self
.
micro_size_k
else
:
else
:
micro_size_x
,
micro_size_y
=
self
.
micro_size_k
,
self
.
micro_size_y
micro_size_r
,
micro_size_s
=
self
.
micro_size_k
,
self
.
micro_size_y
if
transposed
:
micro_size_x
,
micro_size_y
=
micro_size_y
,
micro_size_x
local_size_out
=
self
.
local_size_out
block_row_warps
,
block_col_warps
=
(
block_row_warps
,
block_col_warps
=
(
self
.
block_row_warps
,
self
.
block_row_warps
,
self
.
block_col_warps
,
self
.
block_col_warps
,
)
)
warp_rows
,
warp_cols
=
self
.
warp_rows
,
self
.
warp_cols
warp_rows
,
warp_cols
=
self
.
warp_rows
,
self
.
warp_cols
warp_size
=
self
.
WARP_SIZE
warp_s
=
warp_rows
if
matrix
==
"A"
else
warp_cols
is_m_first
=
self
.
is_m_first
chunk
=
self
.
chunk
transform_func
=
transform_func
if
not
transposed
else
transform_func_trans
transform_func
=
transform_func
warp_size
,
local_size_a
,
local_size_b
=
self
.
WARP_SIZE
,
self
.
local_size_a
,
self
.
local_size_b
inverse_mma_load_layout
=
IndexMap
.
from_func
(
transform_func
,
index_dtype
=
"int32"
)
local_size
=
local_size_a
if
matrix
==
"A"
else
local_size_b
inverse_mma_load_layout
=
IndexMap
.
from_func
(
transform_func
,
index_dtype
=
"int32"
).
inverse
([
warp_size
,
local_size
])
def
forward_thread
(
i
:
int
,
j
:
int
)
->
int
:
def
forward_thread
(
i
:
int
,
j
:
int
)
->
int
:
"""
"""
Given the row index `i` and column index `j` in the fragment,
Given the row index `i` and column index `j` in the fragment,
map them to a thread index according to `inverse_mma_store_layout`.
"""
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
lane_id
,
_
=
inverse_mma_load_layout
.
map_indices
([
i
,
j
])
# the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
return
lane_id
block_i
,
block_j
=
(
i
//
micro_size_x
)
//
warp_rows
,
(
j
//
micro_size_y
)
//
warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i
,
warp_j
=
(
i
//
micro_size_x
)
%
warp_rows
,
(
j
//
micro_size_y
)
%
warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i
,
mma_j
=
i
%
micro_size_x
,
j
%
micro_size_y
lane_id
,
_
=
inverse_mma_load_layout
.
map_indices
([
mma_i
,
mma_j
])
if
is_m_first
:
thread_id
=
(
block_i
*
(
block_col_warps
*
warp_cols
)
+
block_j
*
warp_rows
+
warp_i
*
warp_cols
+
warp_j
)
else
:
thread_id
=
(
block_j
*
(
block_row_warps
*
warp_size
)
+
block_i
*
warp_size
+
lane_id
)
return
thread_id
def
forward_index
(
i
:
int
,
j
:
int
)
->
int
:
def
forward_index
(
i
:
int
,
j
:
int
)
->
int
:
"""
"""
Given the row index `i` and column index `j` in the fragment,
Given the row index `i` and column index `j` in the fragment,
map them to a local index in a single thread according
to `inverse_mma_store_layout`.
"""
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
_
,
local_id
=
inverse_mma_load_layout
.
map_indices
([
i
,
j
])
# the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
return
local_id
block_i
,
block_j
=
(
i
//
micro_size_x
)
//
warp_rows
,
(
j
//
micro_size_y
)
//
warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i
,
warp_j
=
(
i
//
micro_size_x
)
%
warp_rows
,
(
j
//
micro_size_y
)
%
warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i
,
mma_j
=
i
%
micro_size_x
,
j
%
micro_size_y
_
,
local_id
=
inverse_mma_load_layout
.
map_indices
([
mma_i
,
mma_j
])
return
(
warp_i
*
(
warp_cols
*
local_size_out
)
+
warp_j
*
local_size_out
+
local_id
)
fragment
=
T
.
Fragment
(
base_
fragment
=
T
.
Fragment
(
shape
,
[
micro_size_r
,
micro_size_s
]
,
forward_thread_fn
=
forward_thread
,
forward_thread_fn
=
forward_thread
,
forward_index_fn
=
forward_index
,
forward_index_fn
=
forward_index
,
)
)
print
(
f
"fragment.shape:
{
local_buf
.
shape
}
"
)
warp_fragment
=
base_fragment
.
repeat
([
block_row_warps
,
1
],
print
(
f
"fragment.thread:
{
fragment
.
thread
}
"
)
repeat_on_thread
=
True
).
replicate
(
block_col_warps
)
print
(
f
"fragment.index:
{
fragment
.
index
}
"
)
block_fragment
=
warp_fragment
.
repeat
([
warp_s
,
chunk
//
micro_size_r
],
return
fragment
repeat_on_thread
=
False
,
lower_dim_first
=
False
)
print
(
f
"base_fragment:
{
base_fragment
}
"
)
print
(
f
"warp_fragment:
{
warp_fragment
}
"
)
print
(
f
"block_fragment:
{
block_fragment
}
"
)
return
block_fragment
def
make_mma_store_layout
(
self
,
local_buf
:
Buffer
)
->
T
.
Fragment
:
def
make_mma_store_layout
(
self
,
local_buf
:
Buffer
)
->
T
.
Fragment
:
"""
"""
...
@@ -474,7 +475,7 @@ class TensorCoreIntrinEmitter(object):
...
@@ -474,7 +475,7 @@ class TensorCoreIntrinEmitter(object):
AssertionError
AssertionError
If `local_buf` is not detected to be a fragment buffer.
If `local_buf` is not detected to be a fragment buffer.
"""
"""
from
tilelang.
primitives.
utils
import
is_fragment
from
tilelang.utils
import
is_fragment
shape
=
local_buf
.
shape
shape
=
local_buf
.
shape
inverse_mma_store_layout
=
self
.
get_store_index_map
(
inverse
=
True
)
inverse_mma_store_layout
=
self
.
get_store_index_map
(
inverse
=
True
)
...
@@ -494,14 +495,11 @@ class TensorCoreIntrinEmitter(object):
...
@@ -494,14 +495,11 @@ class TensorCoreIntrinEmitter(object):
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
# the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
block_i
,
block_j
=
(
i
//
micro_size_x
)
//
warp_rows
,
(
j
//
micro_size_y
)
//
warp_cols
block_i
,
block_j
=
(
i
//
micro_size_x
)
//
warp_rows
,
(
j
//
micro_size_y
)
//
warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i
,
warp_j
=
(
i
//
micro_size_x
)
%
warp_rows
,
(
j
//
micro_size_y
)
%
warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i
,
mma_j
=
i
%
micro_size_x
,
j
%
micro_size_y
mma_i
,
mma_j
=
i
%
micro_size_x
,
j
%
micro_size_y
lane_id
,
_
=
inverse_mma_store_layout
.
map_indices
([
mma_i
,
mma_j
])
lane_id
,
_
=
inverse_mma_store_layout
.
map_indices
([
mma_i
,
mma_j
])
if
is_m_first
:
if
is_m_first
:
thread_id
=
block_i
*
(
thread_id
=
block_i
*
(
block_col_warps
*
warp_cols
)
+
block_j
*
warp_size
+
lane_id
block_col_warps
*
warp_cols
)
+
block_j
*
warp_rows
+
warp_i
*
warp_cols
+
warp_j
else
:
else
:
thread_id
=
block_j
*
(
block_row_warps
*
warp_size
)
+
block_i
*
warp_size
+
lane_id
thread_id
=
block_j
*
(
block_row_warps
*
warp_size
)
+
block_i
*
warp_size
+
lane_id
return
thread_id
return
thread_id
...
@@ -513,8 +511,6 @@ class TensorCoreIntrinEmitter(object):
...
@@ -513,8 +511,6 @@ class TensorCoreIntrinEmitter(object):
to `inverse_mma_store_layout`.
to `inverse_mma_store_layout`.
"""
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
block_i
,
block_j
=
(
i
//
micro_size_x
)
//
warp_rows
,
(
j
//
micro_size_y
)
//
warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i
,
warp_j
=
(
i
//
micro_size_x
)
%
warp_rows
,
(
j
//
micro_size_y
)
%
warp_cols
warp_i
,
warp_j
=
(
i
//
micro_size_x
)
%
warp_rows
,
(
j
//
micro_size_y
)
%
warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
...
...
tilelang/language/__init__.py
View file @
8bf752ae
...
@@ -8,7 +8,7 @@ from tvm.script.parser.tir import *
...
@@ -8,7 +8,7 @@ from tvm.script.parser.tir import *
from
tilelang.layout
import
Layout
,
Fragment
# noqa: F401
from
tilelang.layout
import
Layout
,
Fragment
# noqa: F401
from
.parallel
import
Parallel
# noqa: F401
from
.parallel
import
Parallel
# noqa: F401
from
.pipeline
import
Pipelined
# noqa: F401
from
.pipeline
import
Pipelined
# noqa: F401
from
.kernel
import
Kernel
# noqa: F401
from
.kernel
import
Kernel
,
KernelLaunchFrame
# noqa: F401
from
.allocate
import
(
from
.allocate
import
(
alloc_local
,
# noqa: F401
alloc_local
,
# noqa: F401
alloc_shared
,
# noqa: F401
alloc_shared
,
# noqa: F401
...
...
tilelang/layout/fragment.py
View file @
8bf752ae
...
@@ -30,6 +30,7 @@ class Fragment(Layout):
...
@@ -30,6 +30,7 @@ class Fragment(Layout):
else
:
else
:
thread_replicate
=
None
thread_replicate
=
None
forward_thread
=
forward_thread_fn
(
*
vars
)
forward_thread
=
forward_thread_fn
(
*
vars
)
self
.
__init_handle_by_constructor__
(
self
.
__init_handle_by_constructor__
(
_ffi_api
.
Fragment
,
_ffi_api
.
Fragment
,
forward_vars
,
forward_vars
,
...
@@ -45,12 +46,21 @@ class Fragment(Layout):
...
@@ -45,12 +46,21 @@ class Fragment(Layout):
def
get_thread_size
(
self
):
def
get_thread_size
(
self
):
return
_ffi_api
.
Fragment_thread_size
(
self
)
return
_ffi_api
.
Fragment_thread_size
(
self
)
def
repeat
(
self
,
repeats
,
repeat_on_thread
:
bool
=
False
)
->
"Fragment"
:
def
repeat
(
self
,
return
_ffi_api
.
Fragment_repeat
(
self
,
repeats
,
repeat_on_thread
)
repeats
,
repeat_on_thread
:
bool
=
False
,
lower_dim_first
:
bool
=
True
)
->
"Fragment"
:
return
_ffi_api
.
Fragment_repeat
(
self
,
repeats
,
repeat_on_thread
,
lower_dim_first
)
def
replicate
(
self
,
replicate
:
int
)
->
"Fragment"
:
return
_ffi_api
.
Fragment_replicate
(
self
,
replicate
)
def
condense_rep_var
(
self
)
->
"Fragment"
:
def
condense_rep_var
(
self
)
->
"Fragment"
:
return
_ffi_api
.
Fragment_condense_rep_var
(
self
)
return
_ffi_api
.
Fragment_condense_rep_var
(
self
)
def
__repr__
(
self
):
return
f
"Fragment<thread=
{
self
.
thread
}
, index=
{
self
.
index
}
>"
def
make_swizzled_layout
(
buffer
:
tvm
.
tir
.
Buffer
):
def
make_swizzled_layout
(
buffer
:
tvm
.
tir
.
Buffer
):
assert
len
(
buffer
.
shape
)
==
2
assert
len
(
buffer
.
shape
)
==
2
...
...
tilelang/primitives/gemm/__init__.py
View file @
8bf752ae
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
from
typing
import
Optional
from
typing
import
Optional
from
tvm
import
tir
from
tvm
import
tir
from
tilelang.
primitives.
utils
import
is_local
,
is_fragment
,
is_shared
from
tilelang.utils
import
is_local
,
is_fragment
,
is_shared
from
tilelang.primitives.gemm.base
import
GemmWarpPolicy
from
tilelang.primitives.gemm.base
import
GemmWarpPolicy
from
tilelang.primitives.gemm.gemm_mma
import
(
from
tilelang.primitives.gemm.gemm_mma
import
(
GemmPrimitiveMMA
,)
GemmPrimitiveMMA
,)
...
...
tilelang/primitives/gemm/gemm_mma.py
View file @
8bf752ae
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Licensed under the MIT License.
from
__future__
import
annotations
from
typing
import
Optional
,
Dict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
tvm
import
tir
from
tvm
import
tir
import
tilelang.language
as
T
import
tilelang.language
as
T
from
tilelang.
primitives.
utils
import
is_fragment
,
array_reduce
from
tilelang.utils
import
is_fragment
from
tilelang.primitives.gemm.base
import
GemmBaseParams
from
tilelang.primitives.gemm.base
import
GemmBaseParams
from
tilelang.intrinsics.mma_macro_generator
import
TensorCoreIntrinEmitter
from
tilelang.intrinsics.mma_macro_generator
import
TensorCoreIntrinEmitter
...
@@ -39,9 +36,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
...
@@ -39,9 +36,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
)
->
tir
.
PrimExpr
:
)
->
tir
.
PrimExpr
:
in_dtype
=
self
.
in_dtype
in_dtype
=
self
.
in_dtype
warp_rows
=
mma_emitter
.
warp_rows
warp_cols
=
mma_emitter
.
warp_cols
warp_cols
=
mma_emitter
.
warp_cols
local_size_a
=
mma_emitter
.
local_size_a
local_size_b
=
mma_emitter
.
local_size_b
local_size_b
=
mma_emitter
.
local_size_b
block_K
=
mma_emitter
.
chunk
block_K
=
mma_emitter
.
chunk
micro_size_k
=
mma_emitter
.
micro_size_k
micro_size_k
=
mma_emitter
.
micro_size_k
...
@@ -71,6 +66,10 @@ class GemmPrimitiveMMA(GemmBaseParams):
...
@@ -71,6 +66,10 @@ class GemmPrimitiveMMA(GemmBaseParams):
C_local
:
mma_emitter
.
make_mma_store_layout
(
C_local
),
C_local
:
mma_emitter
.
make_mma_store_layout
(
C_local
),
})
})
# Make default swizzle layout for shared memory
# T.annotate_layout({
# B_shared: make_mma_swizzle_layout(B_shared),
# })
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
for
ki
in
T
.
serial
(
0
,
(
block_K
//
micro_size_k
)):
# Load B into fragment
# Load B into fragment
...
@@ -197,7 +196,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
...
@@ -197,7 +196,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
"""
"""
# Infer block partition if necessary
# Infer block partition if necessary
current_frame
=
T
.
kernel
.
KernelLaunchFrame
.
Current
()
current_frame
=
T
.
KernelLaunchFrame
.
Current
()
threads
=
current_frame
.
num_threads
threads
=
current_frame
.
num_threads
self
.
infer_block_partition
(
threads
)
self
.
infer_block_partition
(
threads
)
...
...
tilelang/utils/__init__.py
View file @
8bf752ae
...
@@ -5,3 +5,11 @@
...
@@ -5,3 +5,11 @@
from
.target
import
determine_target
# noqa: F401
from
.target
import
determine_target
# noqa: F401
from
.profiler
import
Profiler
# noqa: F401
from
.profiler
import
Profiler
# noqa: F401
from
.tensor
import
TensorSupplyType
,
torch_assert_close
# noqa: F401
from
.tensor
import
TensorSupplyType
,
torch_assert_close
# noqa: F401
from
.language
import
(
is_global
,
# noqa: F401
is_shared
,
# noqa: F401
is_shared_dynamic
,
# noqa: F401
is_fragment
,
# noqa: F401
is_local
,
# noqa: F401
array_reduce
,
# noqa: F401
)
tilelang/
primitives/utils
.py
→
tilelang/
utils/language
.py
View file @
8bf752ae
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
from
tvm.tir
import
Buffer
from
tvm.tir
import
Buffer
from
typing
import
List
from
typing
import
List
from
functools
import
reduce
from
functools
import
reduce
# Scope Checkers for TVM Buffers
# Scope Checkers for TVM Buffers
# These utility functions check the memory scope of a given TVM buffer.
# These utility functions check the memory scope of a given TVM buffer.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment