OpenDAS / tilelang · Commit 2c490782

[Feature] Add support for Hygon DCU

Authored Oct 28, 2025 by Lukinon; committed by qisan, Oct 28, 2025.
Parent: 7d389a43
Showing 2 changed files with 606 additions and 0 deletions:

  tilelang/engine/lower.py                       +1    -0
  tilelang/intrinsics/mmac_macro_generator.py    +605  -0
tilelang/engine/lower.py (view file @ 2c490782)
@@ -106,6 +106,7 @@ def tilelang_callback_hip_compile(code, target):
         target_format="hsaco",
         options=[
             "-std=c++17",
+            "-O1",
             "-I" + tl_template_path,
             "-I" + ck_path,
         ],
tilelang/intrinsics/mmac_macro_generator.py (new file, 0 → 100644, view file @ 2c490782)
from tilelang import tvm as tvm
import tilelang.language as T
from typing import Optional, Tuple
from tvm import DataType
from tvm.tir import PrimExpr
from tvm.runtime import convert

from .utils import (mfma_store_index_map,)

lift = convert


class MatrixCoreIntrinEmitter(object):
    """
    Emit Matrix Core (MMAC) intrinsics for Hygon DCU GEMM kernels while
    keeping plain Python syntax out of the TIR macros.
    """
    M_DIM = 16
    N_DIM = 16
    WARP_SIZE = 64

    dtype_abbrv = {
        "float16": "fp16",
        "bfloat16": "bf16",
        "float32": "fp32",
        "int8": "int8",
        "int32": "int32",
        "float8_e4m3": "e4m3",
        "float8_e5m2": "e5m2",
        "float8_e4m3fnuz": "e4m3fnuz",
    }

    # k_pack represents the number of elements in a vectorized instruction.
    # Detailed information can be found in the Triton source:
    # https://github.com/triton-lang/triton/blob/433037206d8870f0b82a3cd669097001084a29ed/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp#L419
    k_pack = 1

    # Represent the thread binding in the form of (tx, warp_n, warp_m).
    is_m_first = False
    def __init__(
        self,
        a_dtype: str = "float16",
        b_dtype: str = "float16",
        accum_dtype: str = "float16",
        a_transposed: bool = False,
        b_transposed: bool = False,
        block_row_warps: int = 2,
        block_col_warps: int = 2,
        warp_row_tiles: int = 8,
        warp_col_tiles: int = 8,
        chunk: int = 16,
        reduce_k: int = 1,
        num_elems_per_byte: int = 1,
        k_pack: Optional[int] = None,
        is_m_first: Optional[bool] = False,
        b_preshuffle: Optional[bool] = False,
    ):
        self.a_dtype = a_dtype
        self.b_dtype = b_dtype
        self.accum_dtype = accum_dtype
        self.a_transposed = a_transposed
        self.b_transposed = b_transposed
        # Hint information
        self.block_row_warps = block_row_warps
        self.block_col_warps = block_col_warps
        self.warp_row_tiles = warp_row_tiles
        self.warp_col_tiles = warp_col_tiles
        self.chunk = chunk
        self._initialize_k_dim(a_dtype)
        self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
        self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
        self._initialize_mmac_prefix(self.k_dim)
        self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
        self._initialize_k_pack(k_pack)
        self._initialize_is_m_first(is_m_first)
        self._initialize_b_preshuffle(b_preshuffle)
        self.warp_rows = warp_row_tiles // self.micro_size_x
        self.warp_cols = warp_col_tiles // self.micro_size_y
        self.reduce_k = reduce_k
        self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
        self.num_elems_per_byte = num_elems_per_byte
    def _initialize_k_dim(self, a_dtype="float16"):
        if isinstance(a_dtype, str):
            if a_dtype in ["float8_e4m3fnuz", "int8"]:
                self.k_dim = 32
                return
            a_dtype = DataType(a_dtype)
        if a_dtype.bits == 32:
            self.k_dim = 4
        elif a_dtype.bits in {16, 8}:
            self.k_dim = 16
        else:
            raise ValueError(f"Unsupported a_dtype = {a_dtype}")
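Reviewer note: the dtype → k_dim mapping above can be exercised directly. A hypothetical sanity check, not part of the commit (restricted to dtypes that the abbreviation tables below also cover):

# Hypothetical check (not from this commit) of _initialize_k_dim:
for a_dtype, accum, expected_k in [("float16", "float32", 16),
                                   ("float32", "float32", 4),
                                   ("int8", "int32", 32)]:
    emitter = MatrixCoreIntrinEmitter(a_dtype=a_dtype, b_dtype=a_dtype, accum_dtype=accum)
    assert emitter.k_dim == expected_k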
    def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
        self.local_size_a = (m_dim * k_dim) // warp_size
        self.local_size_b = (n_dim * k_dim) // warp_size
        self.local_size_out = (m_dim * n_dim) // warp_size
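With the 64-lane warps this emitter uses and the 16x16x16 fp16 shape, each lane ends up holding four elements of every fragment. Worked arithmetic mirroring the formulas above (illustration only):

# Worked example of _initialize_local_size for m_dim = n_dim = k_dim = 16
# and warp_size = 64 (the WARP_SIZE of this emitter):
local_size_a = (16 * 16) // 64    # 4 elements of the A fragment per lane
local_size_b = (16 * 16) // 64    # 4 elements of the B fragment per lane
local_size_out = (16 * 16) // 64  # 4 accumulator elements per lane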
    def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype):
        self.a_dtype_abbrv = self.dtype_abbrv[a_dtype]
        self.b_dtype_abbrv = self.dtype_abbrv[b_dtype]
        self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype]
    def _initialize_mmac_prefix(self, k_dim=16):
        in_dtype, out_dtype = self.a_dtype, self.accum_dtype
        M_DIM, N_DIM = self.M_DIM, self.N_DIM

        out_dtype_abbrv = {
            "float16": "f16",
            "float32": "f32",
            "int8": "i8",
            "int32": "i32",
        }[out_dtype]

        in_dtype_abbrv = {
            "float16": "f16",
            "float32": "f32",
            "int8": "i8",
        }[in_dtype]

        if in_dtype_abbrv == "i8":
            self.mmac_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_i8"
        else:
            self.mmac_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
    def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16):
        self.micro_size_x = m_dim
        self.micro_size_y = n_dim
        self.micro_size_k = k_dim
    def _initialize_k_pack(self, k_pack: Optional[int] = None):
        if k_pack is not None:
            self.k_pack = k_pack

    def _initialize_is_m_first(self, is_m_first: Optional[bool] = False):
        if is_m_first is not None:
            self.is_m_first = is_m_first

    def _initialize_b_preshuffle(self, b_preshuffle: Optional[bool] = False):
        if b_preshuffle is not None:
            self.b_preshuffle = b_preshuffle
    def get_ldmatrix_index_map(self, is_b=False):
        from .mfma_layout import (
            shared_16x4_to_local_64x1_layout_A,
            shared_4x16_to_local_64x1_layout_B,
            shared_16x16_to_local_64x4_layout_A,
            shared_16x16_to_local_64x4_layout_B,
            shared_16x32_to_local_64x8_layout_A,
            shared_16x32_to_local_64x8_layout_B,
            shared_16x64_to_local_64x16_layout_A,
            shared_16x64_to_local_64x16_layout_B,
            thread_id_shared_access_64x1_to_16x4_layout_A,
            thread_id_shared_access_64x1_to_4x16_layout_B,
            thread_id_shared_access_64x4_to_16x16_layout_A,
            thread_id_shared_access_64x4_to_16x16_layout_B,
            thread_id_shared_access_64x8_to_16x32_layout_A,
            thread_id_shared_access_64x8_to_16x32_layout_B,
            thread_id_shared_access_64x16_to_16x64_layout_A,
            thread_id_shared_access_64x16_to_16x64_layout_B,
        )

        k_dim = self.k_dim * self.k_pack
        transposed = self.a_transposed if not is_b else self.b_transposed
        if k_dim == 4:
            index_map = shared_16x4_to_local_64x1_layout_A
            reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A
            if is_b:
                index_map = (shared_16x4_to_local_64x1_layout_A
                             if transposed else shared_4x16_to_local_64x1_layout_B)
                reverse_index_map = (thread_id_shared_access_64x1_to_16x4_layout_A
                                     if transposed else thread_id_shared_access_64x1_to_4x16_layout_B)
        elif k_dim == 16:
            index_map = (shared_16x16_to_local_64x4_layout_B
                         if transposed else shared_16x16_to_local_64x4_layout_A)
            reverse_index_map = (thread_id_shared_access_64x4_to_16x16_layout_B
                                 if transposed else thread_id_shared_access_64x4_to_16x16_layout_A)
            if is_b:
                index_map = (shared_16x16_to_local_64x4_layout_A
                             if transposed else shared_16x16_to_local_64x4_layout_B)
                reverse_index_map = (thread_id_shared_access_64x4_to_16x16_layout_A
                                     if transposed else thread_id_shared_access_64x4_to_16x16_layout_B)
        elif k_dim == 32:
            index_map = (shared_16x32_to_local_64x8_layout_B
                         if transposed else shared_16x32_to_local_64x8_layout_A)
            reverse_index_map = (thread_id_shared_access_64x8_to_16x32_layout_B
                                 if transposed else thread_id_shared_access_64x8_to_16x32_layout_A)
            if is_b:
                index_map = (shared_16x32_to_local_64x8_layout_A
                             if transposed else shared_16x32_to_local_64x8_layout_B)
                reverse_index_map = (thread_id_shared_access_64x8_to_16x32_layout_A
                                     if transposed else thread_id_shared_access_64x8_to_16x32_layout_B)
        elif k_dim == 64:
            index_map = (shared_16x64_to_local_64x16_layout_B
                         if transposed else shared_16x64_to_local_64x16_layout_A)
            reverse_index_map = (thread_id_shared_access_64x16_to_16x64_layout_B
                                 if transposed else thread_id_shared_access_64x16_to_16x64_layout_A)
            if is_b:
                index_map = (shared_16x64_to_local_64x16_layout_A
                             if transposed else shared_16x64_to_local_64x16_layout_B)
                reverse_index_map = (thread_id_shared_access_64x16_to_16x64_layout_A
                                     if transposed else thread_id_shared_access_64x16_to_16x64_layout_B)
        else:
            raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently")
        return index_map, reverse_index_map
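The .mfma_layout helpers are imported, not defined in this diff. As a mental model only, here is a sketch patterned after the common CDNA-style 16x16 fragment layout; this is an assumption on my part, and the real helpers may differ:

# ASSUMED illustration; the actual thread_id_shared_access_64x4_to_16x16_layout_A
# lives in .mfma_layout and is not shown in this commit.
def assumed_64x4_to_16x16_layout_A(thread_id, local_id):
    row = thread_id % 16                     # 16 rows cycle across the 64 lanes
    col = (thread_id // 16) * 4 + local_id   # each lane owns 4 consecutive columns
    return row, col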
    def extract_thread_binding(self,
                               thread_id,
                               is_m_first=None) -> Tuple[PrimExpr, PrimExpr, PrimExpr]:
        '''
        is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m),
        which represents [warp_size, block_row_warps (split n), block_col_warps (split m)];
        otherwise, it is in the form of [warp_size, block_col_warps (split m), block_row_warps (split n)].
        '''
        WARP_SIZE = self.WARP_SIZE
        block_row_warps = self.block_row_warps
        block_col_warps = self.block_col_warps

        # If is_m_first is None, use the instance default.
        if is_m_first is None:
            is_m_first = self.is_m_first

        if is_m_first:
            lane_id, warp_n, warp_m = (
                thread_id % WARP_SIZE,
                (thread_id // WARP_SIZE) % block_col_warps,
                (thread_id // (WARP_SIZE * block_col_warps)) % block_row_warps,
            )
            return lane_id, warp_n, warp_m
        else:
            lane_id, warp_m, warp_n = (
                thread_id % WARP_SIZE,
                (thread_id // WARP_SIZE) % block_row_warps,
                (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps,
            )
            return lane_id, warp_n, warp_m
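With the default ordering (is_m_first = False) and 2x2 warps per block, the 256 threads decompose as below; hypothetical host-side arithmetic that simply mirrors the expressions above:

# Hypothetical mirror of the default (is_m_first = False) decomposition:
WARP_SIZE, block_row_warps, block_col_warps = 64, 2, 2
for thread_id in (0, 63, 64, 128, 192):
    lane_id = thread_id % WARP_SIZE
    warp_m = (thread_id // WARP_SIZE) % block_row_warps
    warp_n = (thread_id // (WARP_SIZE * block_row_warps)) % block_col_warps
    # thread 64  -> lane 0, warp_m = 1, warp_n = 0
    # thread 128 -> lane 0, warp_m = 0, warp_n = 1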
    def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, rk=0):
        warp_row_tiles = self.warp_row_tiles
        warp_rows = self.warp_rows
        chunk = self.chunk
        micro_size_x = self.micro_size_x
        micro_size_k = self.micro_size_k
        local_size_a = self.local_size_a
        k_pack = self.k_pack
        is_transposed = self.a_transposed
        current_frame = T.KernelLaunchFrame.Current()
        thread_binding = current_frame.get_thread_binding()
        _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False)

        @T.macro
        def _warp_ldmatrix_a(
            A_local_buf,
            A_shared_buf,
            ki,
            thread_binding,
            rk=0,
        ):
            tx, _, warp_m = self.extract_thread_binding(thread_binding)
            if is_transposed:
                for i in T.serial(warp_rows):
                    for local_id in T.vectorized(k_pack * local_size_a):
                        row, col = T.meta_var(reverse_index_map(tx, local_id))
                        l, r = (
                            rk * chunk + ki * (k_pack * micro_size_k),
                            warp_m * warp_row_tiles + i * micro_size_x,
                        )
                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row, r + col]
            else:
                for i in T.serial(warp_rows):
                    for local_id in T.vectorized(k_pack * local_size_a):
                        row, col = T.meta_var(reverse_index_map(tx, local_id))
                        l, r = (
                            warp_m * warp_row_tiles + i * micro_size_x,
                            rk * chunk + ki * (k_pack * micro_size_k),
                        )
                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row, r + col]

        return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
    def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, rk=0):
        warp_col_tiles = self.warp_col_tiles
        warp_cols = self.warp_cols
        chunk = self.chunk
        micro_size_y = self.micro_size_y
        micro_size_k = self.micro_size_k
        local_size_b = self.local_size_b
        k_pack = self.k_pack
        is_transposed = self.b_transposed
        current_frame = T.KernelLaunchFrame.Current()
        thread_binding = current_frame.get_thread_binding()
        _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True)

        @T.macro
        def _warp_ldmatrix_b(
            B_local_buf,
            B_shared_buf,
            ki,
            thread_binding,
            rk=0,
        ):
            tx, warp_n, _ = self.extract_thread_binding(thread_binding)
            if is_transposed:
                for j in T.serial(warp_cols):
                    for local_id in T.vectorized(k_pack * local_size_b):
                        # Lanes are re-indexed before the layout lookup:
                        # (tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16
                        row, col = T.meta_var(
                            reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16,
                                              local_id))
                        l, r = (
                            warp_n * warp_col_tiles + j * micro_size_y,
                            rk * chunk + ki * (k_pack * micro_size_k),
                        )
                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row, r + col]
            else:
                for j in T.serial(warp_cols):
                    for local_id in T.vectorized(k_pack * local_size_b):
                        row, col = T.meta_var(
                            reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16,
                                              local_id))
                        l, r = (
                            rk * chunk + ki * (k_pack * micro_size_k),
                            warp_n * warp_col_tiles + j * micro_size_y,
                        )
                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row, r + col]

        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
    def mmac(self, A_local_buf, B_local_buf, C_local_buf):
        warp_rows = self.warp_rows
        warp_cols = self.warp_cols
        local_size_a = self.local_size_a
        local_size_b = self.local_size_b
        local_size_out = self.local_size_out
        k_pack = self.k_pack
        mmac_suffix = self.mmac_suffix
        a_dtype, b_dtype, out_dtype = self.a_dtype, self.b_dtype, self.accum_dtype
        compute_a_dtype = a_dtype if local_size_a == 1 else f"{a_dtype}x{local_size_a}"
        compute_b_dtype = b_dtype if local_size_b == 1 else f"{b_dtype}x{local_size_b}"
        compute_out_dtype = out_dtype if local_size_out == 1 else f"{out_dtype}x{local_size_out}"

        @T.macro
        def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
            for kp, i, j in T.grid(k_pack, warp_rows, warp_cols):
                T.tvm_mmac(
                    mmac_suffix,
                    "row",
                    "row",
                    compute_a_dtype,
                    compute_b_dtype,
                    compute_out_dtype,
                    A_local_buf.data,
                    ((j * k_pack + kp) * local_size_a) // local_size_a,
                    B_local_buf.data,
                    ((i * k_pack + kp) * local_size_b) // local_size_b,
                    C_local_buf.data,
                    (i * warp_cols * local_size_out + j * local_size_out) // local_size_out,
                    dtype=compute_out_dtype,
                )

        return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
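For warp_rows = warp_cols = 2 and k_pack = 1, the grid above issues four MMAC intrinsics whose vector-register offsets simplify neatly; hypothetical expansion of the index expressions, illustration only:

# Hypothetical expansion of the mmac offsets (warp_rows = warp_cols = 2,
# k_pack = 1, all local sizes 4):
k_pack, warp_cols, local_size = 1, 2, 4
for i in range(2):
    for j in range(2):
        a_off = ((j * k_pack + 0) * local_size) // local_size                # == j
        b_off = ((i * k_pack + 0) * local_size) // local_size                # == i
        c_off = (i * warp_cols * local_size + j * local_size) // local_size  # == 2 * i + j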
    def stmatrix(self, C_local_buf, C_buf, pid_m=None, pid_n=None):
        block_row_warps = self.block_row_warps
        block_col_warps = self.block_col_warps
        warp_rows = self.warp_rows
        warp_cols = self.warp_cols
        local_size_out = self.local_size_out
        current_frame = T.KernelLaunchFrame.Current()
        thread_binding = current_frame.get_thread_binding()
        is_global = pid_m is not None and pid_n is not None
        BLOCK_M = block_row_warps * warp_rows
        BLOCK_N = block_col_warps * warp_cols
        M_DIM, N_DIM = self.M_DIM, self.N_DIM
        C_buf_dims = len(C_buf.shape)
        assert C_buf_dims in {2, 4}, "C_buf should be 2D or 4D"

        # STS
        # The MMA store must be simulated rather than expressed with TVM
        # intrinsics, because the TVM intrinsic is a hack that assumes
        # threadIdx.x always equals the warp size.
        @T.macro
        def _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding):
            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
            for i, j in T.grid(warp_rows, warp_cols):
                for local_id in T.vectorized(local_size_out):
                    row, col = T.meta_var(mfma_store_index_map(tx, local_id))
                    if C_buf_dims == 2:
                        C_buf[(warp_m * warp_rows + i) * M_DIM + row,
                              (warp_n * warp_cols + j) * N_DIM + col] = C_local_buf[
                                  j * (warp_rows * local_size_out) + i * local_size_out + local_id]
                    else:
                        C_buf[warp_n * warp_cols + j, warp_m * warp_rows + i, row,
                              col] = C_local_buf[j * warp_rows * local_size_out +
                                                 i * local_size_out + local_id]

        @T.macro
        def _warp_stmatrix_global(C_local_buf, C_buf, thread_binding):
            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
            for i, j in T.grid(warp_rows, warp_cols):
                for local_id in T.vectorized(local_size_out):
                    row, col = T.meta_var(mfma_store_index_map(tx, local_id))
                    C_buf[(pid_m * BLOCK_M + warp_m * warp_rows + i) * M_DIM + row,
                          (pid_n * BLOCK_N + warp_n * warp_cols + j) * N_DIM + col] = C_local_buf[
                              i * warp_cols * local_size_out + j * local_size_out + local_id]

        return (_warp_stmatrix_global(C_local_buf, C_buf, thread_binding)
                if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_binding))
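Reviewer note: the emitter is presumably driven the same way tilelang's existing AMD MFMA emitter is. The sketch below is illustrative only; the kernel scaffolding, buffer names, and tile sizes are my assumptions, not part of this commit:

# Illustrative sketch, NOT from this commit: a tiled GEMM driving the emitter.
emitter = MatrixCoreIntrinEmitter(
    a_dtype="float16", b_dtype="float16", accum_dtype="float32",
    a_transposed=False, b_transposed=True,
    block_row_warps=2, block_col_warps=2,
    warp_row_tiles=32, warp_col_tiles=32, chunk=32)

# Inside the kernel body (allocation of A_shared/B_shared/A_local/B_local/
# C_local and the outer K pipeline loop are assumed):
#     for ki in T.serial(block_K // (emitter.k_pack * emitter.micro_size_k)):
#         emitter.ldmatrix_a(A_local, A_shared, ki)  # shared memory -> registers
#         emitter.ldmatrix_b(B_local, B_shared, ki)
#         emitter.mmac(A_local, B_local, C_local)    # issue tvm_mmac intrinsics
#     emitter.stmatrix(C_local, C_shared)            # registers -> shared (or global with pid_m/pid_n)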
class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):

    def __init__(
        self,
        a_dtype: str = "float16",
        b_dtype: str = "float16",
        accum_dtype: str = "float16",
        a_transposed: bool = False,
        b_transposed: bool = False,
        block_row_warps: int = 2,
        block_col_warps: int = 2,
        warp_row_tiles: int = 8,
        warp_col_tiles: int = 8,
        chunk: int = 16,
        reduce_k: int = 1,
        num_elems_per_byte: int = 1,
        k_pack: Optional[int] = None,
        is_m_first: Optional[bool] = False,
        a_preshuffle: Optional[bool] = False,
        b_preshuffle: Optional[bool] = False,
    ):
        self.a_dtype = a_dtype
        self.b_dtype = b_dtype
        self.accum_dtype = accum_dtype
        self.a_transposed = a_transposed
        self.b_transposed = b_transposed
        # Hint information
        self.block_row_warps = block_row_warps
        self.block_col_warps = block_col_warps
        self.warp_row_tiles = warp_row_tiles
        self.warp_col_tiles = warp_col_tiles
        self.chunk = chunk
        self._initialize_k_dim(a_dtype)
        self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
        self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
        self._initialize_mmac_prefix(self.k_dim)
        self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
        self._initialize_k_pack(k_pack)
        self._initialize_is_m_first(is_m_first)
        self._initialize_preshuffle(a_preshuffle, b_preshuffle)
        self.warp_rows = warp_row_tiles // self.micro_size_x
        self.warp_cols = warp_col_tiles // self.micro_size_y
        self.reduce_k = reduce_k
        self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k)
        self.num_elems_per_byte = num_elems_per_byte
    def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool):
        if a_preshuffle is not None:
            self.a_preshuffle = a_preshuffle
        if b_preshuffle is not None:
            self.b_preshuffle = b_preshuffle
    def ldmatrix_a(self, A_local_buf, A_buf, ki, rk=0, pid_m=None, pid_n=None):
        warp_rows = self.warp_rows
        chunk = self.chunk
        micro_size_k = self.micro_size_k
        local_size_a = self.local_size_a
        k_pack = self.k_pack
        is_transposed = self.a_transposed
        current_frame = T.KernelLaunchFrame.Current()
        thread_binding = current_frame.get_thread_binding()
        _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False)
        is_global = pid_m is not None and pid_n is not None

        # No preshuffle: use the default implementation.
        if self.a_preshuffle is False:
            return super().ldmatrix_a(A_local_buf, A_buf, ki, rk)

        @T.macro
        def _warp_ldmatrix_a_global(
            A_local_buf,
            A_buf,
            ki,
            thread_binding,
            rk=0,
        ):
            tx, _, warp_m = self.extract_thread_binding(thread_binding)
            if is_transposed:
                for i in T.serial(warp_rows):
                    for local_id in T.vectorized(k_pack * local_size_a):
                        row, col = T.meta_var(reverse_index_map(tx, local_id))
                        l, r = (
                            rk * (chunk // micro_size_k) + ki,
                            (pid_m * self.block_row_warps + warp_m) * warp_rows + i,
                        )
                        A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[l, r, row, col]
            else:
                for i in T.serial(warp_rows):
                    for local_id in T.vectorized(k_pack * local_size_a):
                        row, col = T.meta_var(reverse_index_map(tx, local_id))
                        l, r = (
                            (pid_m * self.block_row_warps + warp_m) * warp_rows + i,
                            rk * (chunk // micro_size_k) + ki,
                        )
                        A_local_buf[i * k_pack * local_size_a + local_id] = A_buf[l, r, row, col]

        @T.macro
        def _warp_ldmatrix_a_shared(
            A_local_buf,
            A_shared_buf,
            ki,
            thread_binding,
            rk=0,
        ):
            tx, _, warp_m = self.extract_thread_binding(thread_binding)
            if is_transposed:
                for i in T.serial(warp_rows):
                    for local_id in T.vectorized(k_pack * local_size_a):
                        row, col = T.meta_var(reverse_index_map(tx, local_id))
                        l, r = (
                            rk * (chunk // micro_size_k) + ki,
                            warp_m * warp_rows + i,
                        )
                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col]
            else:
                for i in T.serial(warp_rows):
                    for local_id in T.vectorized(k_pack * local_size_a):
                        row, col = T.meta_var(reverse_index_map(tx, local_id))
                        l, r = (warp_m * warp_rows + i, rk * (chunk // micro_size_k) + ki)
                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, col]

        return (_warp_ldmatrix_a_global(A_local_buf, A_buf, ki, thread_binding, rk)
                if is_global else _warp_ldmatrix_a_shared(A_local_buf, A_buf, ki, thread_binding, rk))
    def ldmatrix_b(self, B_local_buf, B_buf, ki, rk=0, pid_m=None, pid_n=None):
        warp_cols = self.warp_cols
        chunk = self.chunk
        micro_size_k = self.micro_size_k
        local_size_b = self.local_size_b
        k_pack = self.k_pack
        is_transposed = self.b_transposed
        current_frame = T.KernelLaunchFrame.Current()
        thread_binding = current_frame.get_thread_binding()
        _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True)
        is_global = pid_m is not None and pid_n is not None

        # No preshuffle: use the default implementation (which takes no pid arguments).
        if self.b_preshuffle is False:
            return super().ldmatrix_b(B_local_buf, B_buf, ki, rk)

        @T.macro
        def _warp_ldmatrix_b_global(
            B_local_buf,
            B_buf,
            ki,
            thread_binding,
            rk=0,
        ):
            tx, warp_n, _ = self.extract_thread_binding(thread_binding)
            if is_transposed:
                for j in T.serial(warp_cols):
                    for local_id in T.vectorized(k_pack * local_size_b):
                        row, col = T.meta_var(reverse_index_map(tx, local_id))
                        l, r = (
                            (pid_n * self.block_col_warps + warp_n) * warp_cols + j,
                            rk * (chunk // micro_size_k) + ki,
                        )
                        B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[l, r, row, col]
            else:
                for j in T.serial(warp_cols):
                    for local_id in T.vectorized(k_pack * local_size_b):
                        row, col = T.meta_var(reverse_index_map(tx, local_id))
                        l, r = (
                            rk * (chunk // micro_size_k) + ki,
                            (pid_n * self.block_col_warps + warp_n) * warp_cols + j,
                        )
                        B_local_buf[j * k_pack * local_size_b + local_id] = B_buf[l, r, row, col]

        @T.macro
        def _warp_ldmatrix_b_shared(
            B_local_buf,
            B_shared_buf,
            ki,
            thread_binding,
            rk=0,
        ):
            tx, warp_n, _ = self.extract_thread_binding(thread_binding)
            if is_transposed:
                for j in T.serial(warp_cols):
                    for local_id in T.vectorized(k_pack * local_size_b):
                        row, col = T.meta_var(
                            reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16,
                                              local_id))
                        l, r = (
                            warp_n * warp_cols + j,
                            rk * (chunk // micro_size_k) + ki,
                        )
                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]
            else:
                for j in T.serial(warp_cols):
                    for local_id in T.vectorized(k_pack * local_size_b):
                        row, col = T.meta_var(
                            reverse_index_map((tx & 15) // 4 + (tx & 3) * 4 + (tx // 16) * 16,
                                              local_id))
                        l, r = (
                            rk * (chunk // micro_size_k) + ki,
                            warp_n * warp_cols + j,
                        )
                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]

        return (_warp_ldmatrix_b_global(B_local_buf, B_buf, ki, thread_binding, rk)
                if is_global else _warp_ldmatrix_b_shared(B_local_buf, B_buf, ki, thread_binding, rk))
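The preshuffled operands are addressed as 4-D tiles: B_buf[l, r, row, col] picks a block along K and a block along N (order depending on b_transposed), then a position inside the 16-wide micro tile. A hypothetical construction, with names that are illustrative rather than taken from this commit:

# Hypothetical sketch, NOT from this commit: preshuffled-B loads.
emitter = MatrixCorePreshuffleIntrinEmitter(
    a_dtype="float16", b_dtype="float16", accum_dtype="float32",
    b_transposed=True, b_preshuffle=True,
    block_row_warps=2, block_col_warps=2,
    warp_row_tiles=32, warp_col_tiles=32, chunk=32)

# With pid_m/pid_n the load reads straight from the preshuffled global buffer;
# without them it falls back to the shared-memory path:
#     emitter.ldmatrix_b(B_local, B_global, ki, pid_m=bx, pid_n=by)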