Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
3852d58b
Commit
3852d58b
authored
Apr 03, 2026
by
wangziyang
Browse files
update cp_async & init inject_ds_read
parent
19cdf0ca
Changes
22
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
619 additions
and
40 deletions
+619
-40
examples/gemm/example_gemm_small.py
examples/gemm/example_gemm_small.py
+65
-0
src/layout/gemm_layouts.cc
src/layout/gemm_layouts.cc
+91
-1
src/layout/layout.h
src/layout/layout.h
+5
-0
src/op/builtin.cc
src/op/builtin.cc
+6
-0
src/op/builtin.h
src/op/builtin.h
+10
-0
src/op/gemm.cc
src/op/gemm.cc
+37
-13
src/target/codegen_hip.cc
src/target/codegen_hip.cc
+44
-1
src/target/utils.cc
src/target/utils.cc
+7
-0
src/tl_templates/dcu_hip/common.h
src/tl_templates/dcu_hip/common.h
+22
-1
src/tl_templates/dcu_hip/copy.h
src/tl_templates/dcu_hip/copy.h
+32
-0
src/tl_templates/dcu_hip/gemm.h
src/tl_templates/dcu_hip/gemm.h
+24
-8
src/transform/inject_ds_read.cc
src/transform/inject_ds_read.cc
+205
-0
src/transform/inject_pipeline.cc
src/transform/inject_pipeline.cc
+3
-0
tilelang/engine/phase.py
tilelang/engine/phase.py
+2
-0
tilelang/env.py
tilelang/env.py
+2
-1
tilelang/intrinsics/mfma_macro_generator.py
tilelang/intrinsics/mfma_macro_generator.py
+10
-14
tilelang/language/builtin.py
tilelang/language/builtin.py
+38
-0
tilelang/language/gemm_op.py
tilelang/language/gemm_op.py
+2
-0
tilelang/tileop/gemm/__init__.py
tilelang/tileop/gemm/__init__.py
+6
-1
tilelang/tileop/gemm/gemm_mfma.py
tilelang/tileop/gemm/gemm_mfma.py
+8
-0
No files found.
examples/gemm/example_gemm_small.py
0 → 100644
View file @
3852d58b
import
tilelang
import
tilelang.language
as
T
from
tilelang
import
disable_cache
disable_cache
()
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
    """Build a JIT-compiled tiled GEMM kernel computing C = A @ B.

    Args:
        M, N, K: Global problem sizes (A is MxK, B is KxN, C is MxN).
        block_M, block_N, block_K: Tile sizes per thread block.
        dtype: Element type of A, B, and C (default fp16).
        accum_dtype: Accumulator type for the local fragment (default fp32).

    Returns:
        The tilelang prim_func implementing the kernel. `out_idx=[-1]`
        marks the last tensor argument (C) as the kernel output.
    """

    @T.prim_func
    def gemm(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        # Grid: bx tiles the N dimension, by tiles the M dimension.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            # Zero the accumulator fragment before the K-reduction loop.
            T.clear(C_local)
            # Software-pipelined reduction over K tiles (3 stages overlap
            # the global->shared copies with the MFMA compute).
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Scalar start indices: T.copy infers the tile extent from
                # the destination shared buffer (tilelang convention).
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            # Write the accumulated tile back to global memory.
            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm
def main():
    """Run a 32x32x32 GEMM through the tilelang kernel, verify it against
    torch, dump the generated kernel source, and benchmark latency.

    Fix over the original: the correctness check was commented out but the
    script still printed "All check passed." unconditionally. The check is
    now executed, and a mismatch is reported instead of silently claiming
    success (the script keeps going so the kernel source/latency can still
    be inspected while debugging).
    """
    kernel = matmul(32, 32, 32, 32, 32, 32)
    import torch
    a = torch.randn(32, 32).cuda().half()
    b = torch.randn(32, 32).cuda().half()
    c = kernel(a, b)
    ref_c = a @ b
    print("c:")
    print(c)
    print("ref_c:")
    print(ref_c)
    try:
        torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
        print("All check passed.")
    except AssertionError as err:
        print(f"Mismatch against torch reference: {err}")

    # Dump the generated device source for inspection.
    print("CUDA Source:")
    print(kernel.get_kernel_source())

    # Benchmark kernel latency via the CUPTI backend.
    profiler = kernel.get_profiler()
    latency = profiler.do_bench(backend="cupti")
    print(f"tilelang Latency: {latency} ms")


if __name__ == "__main__":
    main()
src/layout/gemm_layouts.cc
View file @
3852d58b
...
...
@@ -104,6 +104,20 @@ Fragment makeGemmFragmentC16x16CDNA() {
return
Fragment
({
i
,
j
},
{
index
},
forward_thread
,
rep
);
}
// Tiled layout for DCU: each thread handles consecutive data in shared memory.
// This layout is compatible with ds_read_m32x16_b16, which reads continuous
// memory per lane.
//
// Covers a 16x16 C fragment with 64 threads (no replication, rep extent 1):
//   thread = i * 4 + j / 4   -> each row i is split across 4 threads
//   index  = j % 4           -> each thread owns 4 consecutive columns
// So thread t holds elements (i, 4*(t%4) .. 4*(t%4)+3) of row i = t/4.
Fragment makeGemmFragmentC16x16CDNATiled() {
  IterVar i = make_itervar("i", 16);
  IterVar j = make_itervar("j", 16);
  IterVar rep = make_itervar("rep", 1);
  // Tiled layout: thread ID = i*4+(j/4), each thread handles 4 consecutive
  // columns.
  // forward_thread: threads are assigned by (4 columns)
  // index: each thread handles 4 elements at column 0-3, 4-7, 8-11, 12-15
  PrimExpr forward_thread = i * 4 + FloorDiv(j->var, 4);
  PrimExpr index = FloorMod(j->var, 4);
  return Fragment({i, j}, {index}, forward_thread, rep);
}
Fragment
makeGemmFragmentC_F64
(
const
int
block_m
,
const
int
block_n
,
const
int
warp_m
,
const
int
warp_n
)
{
ICHECK
(
block_m
%
warp_m
==
0
);
...
...
@@ -165,11 +179,20 @@ Fragment makeGemmFragmentCDCU(const int block_m, const int block_n,
ICHECK
(
block_n
%
warp_n
==
0
);
ICHECK
(
warp_m
%
16
==
0
)
<<
"warp_m="
<<
warp_m
;
ICHECK
(
warp_n
%
16
==
0
)
<<
"warp_n="
<<
warp_n
;
auto
base_layout
=
makeGemmFragmentC16x16CDNA
()
->
Repeat
({
1
,
1
},
false
);
// Use tiled layout for DCU: compatible with ds_read_m32x16_b16
// auto base_layout = makeGemmFragmentC16x16CDNA()->Repeat({1, 1}, false);
auto
base_layout
=
makeGemmFragmentC16x16CDNATiled
()
->
Repeat
({
1
,
1
},
false
);
auto
warp_layout
=
base_layout
->
Repeat
({
warp_m
/
16
,
warp_n
/
16
},
false
,
false
);
auto
block_layout
=
warp_layout
->
Repeat
({
block_m
/
warp_m
,
block_n
/
warp_n
},
true
,
false
);
LOG
(
INFO
)
<<
"FragmentC warp_m: "
<<
warp_m
;
LOG
(
INFO
)
<<
"FragmentC warp_n: "
<<
warp_n
;
LOG
(
INFO
)
<<
"FragmentC block_m: "
<<
block_m
;
LOG
(
INFO
)
<<
"FragmentC block_n: "
<<
block_n
;
LOG
(
INFO
)
<<
"FragmentC base_layout: "
<<
base_layout
->
DebugOutput
();
LOG
(
INFO
)
<<
"FragmentC warp_layout: "
<<
warp_layout
->
DebugOutput
();
LOG
(
INFO
)
<<
"FragmentC block_layout: "
<<
block_layout
->
DebugOutput
();
return
block_layout
;
}
...
...
@@ -265,6 +288,13 @@ Fragment makeGemmFragmentB(const int block_m, const int block_n,
->
Repeat
({
block_n
/
warp_n
,
1
},
true
,
false
);
auto
block_layout
=
warp_layout
->
Repeat
({
warp_n
/
8
,
block_k
/
16
},
false
,
false
);
LOG
(
INFO
)
<<
"FragmentB warp_m: "
<<
warp_m
;
LOG
(
INFO
)
<<
"FragmentB warp_n: "
<<
warp_n
;
LOG
(
INFO
)
<<
"FragmentB block_m: "
<<
block_m
;
LOG
(
INFO
)
<<
"FragmentB block_n: "
<<
block_n
;
LOG
(
INFO
)
<<
"FragmentB base_layout: "
<<
base_layout
->
DebugOutput
();
LOG
(
INFO
)
<<
"FragmentB warp_layout: "
<<
warp_layout
->
DebugOutput
();
LOG
(
INFO
)
<<
"FragmentB block_layout: "
<<
block_layout
->
DebugOutput
();
return
block_layout
;
}
else
{
auto
base_layout
=
...
...
@@ -273,8 +303,16 @@ Fragment makeGemmFragmentB(const int block_m, const int block_n,
->
Repeat
({
1
,
block_n
/
warp_n
},
true
);
auto
block_layout
=
warp_layout
->
Repeat
({
block_k
/
16
,
warp_n
/
8
},
false
,
true
);
LOG
(
INFO
)
<<
"FragmentB warp_m: "
<<
warp_m
;
LOG
(
INFO
)
<<
"FragmentB warp_n: "
<<
warp_n
;
LOG
(
INFO
)
<<
"FragmentB block_m: "
<<
block_m
;
LOG
(
INFO
)
<<
"FragmentB block_n: "
<<
block_n
;
LOG
(
INFO
)
<<
"FragmentB base_layout: "
<<
base_layout
->
DebugOutput
();
LOG
(
INFO
)
<<
"FragmentB warp_layout: "
<<
warp_layout
->
DebugOutput
();
LOG
(
INFO
)
<<
"FragmentB block_layout: "
<<
block_layout
->
DebugOutput
();
return
block_layout
;
}
}
Fragment
makeGemmFragmentACDNA
(
const
int
block_m
,
const
int
block_n
,
...
...
@@ -314,6 +352,58 @@ Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
}
}
/*!
 * \brief Build the per-thread fragment layout of GEMM operand A for DCU.
 *
 * Mirrors makeGemmFragmentACDNA but targets the DCU MFMA path: the base
 * 16x16 (fp16) or 16x32 (8-bit) operand tile is repeated across the warp's
 * K and M extents, then across warps in M, and replicated across warps in N
 * (every N-warp needs the full A tile).
 *
 * \param block_m/block_n/block_k  Block tile sizes.
 * \param warp_m/warp_n            Per-warp tile sizes; warp_m must be a
 *                                 multiple of 16.
 * \param element_size             Element bit width; only 8 and 16 supported.
 * \param k_pack                   Software K packing factor.
 * \param transposed               Whether A is stored K-major (transposed).
 *
 * Fix over the original: removed the unconditional LOG(INFO) debug dump
 * (14 log lines per call during layout inference) and the stale
 * "assume not transposed" comment — both branches are handled.
 */
Fragment makeGemmFragmentADCU(const int block_m, const int block_n,
                              const int block_k, const int warp_m,
                              const int warp_n, const int element_size,
                              const int k_pack, bool transposed) {
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 16 == 0);
  // One MFMA consumes 16 (fp16) or 32 (8-bit) elements along K, scaled by
  // the k_pack packing factor.
  const int mfma_k = k_pack * (element_size == 16 ? 16 : 32);
  ICHECK(block_k % mfma_k == 0);
  ICHECK(element_size == 8 || element_size == 16)
      << "element bitwidth=" << element_size;
  if (transposed) {
    // K-major A: first axis iterates K, second iterates M.
    auto base_layout =
        element_size == 16
            ? makeGemmFragmentAB16x16CDNATransposed(k_pack)->Repeat({1, 1},
                                                                    false,
                                                                    false)
            : makeGemmFragmentAB16x32CDNATransposed(k_pack)->Repeat({1, 1},
                                                                    false,
                                                                    false);
    auto warp_layout =
        base_layout->Repeat({block_k / mfma_k, warp_m / 16}, false, true);
    auto block_layout = warp_layout->Repeat({1, block_m / warp_m}, true, true)
                            ->Replicate(block_n / warp_n);
    return block_layout;
  } else {
    // M-major A: first axis iterates M, second iterates K.
    auto base_layout =
        element_size == 16
            ? makeGemmFragmentAB16x16CDNA(k_pack)->Repeat({1, 1}, false, false)
            : makeGemmFragmentAB16x32CDNA(k_pack)->Repeat({1, 1}, false,
                                                          false);
    auto warp_layout =
        base_layout->Repeat({warp_m / 16, block_k / mfma_k}, false, false);
    auto block_layout = warp_layout->Repeat({block_m / warp_m, 1}, true, true)
                            ->Replicate(block_n / warp_n);
    return block_layout;
  }
}
Fragment
makeGemmFragment32x32
(
int
element_size
)
{
IterVar
i
=
make_itervar
(
"i"
,
32
);
IterVar
j
=
make_itervar
(
"j"
,
32
);
...
...
src/layout/layout.h
View file @
3852d58b
...
...
@@ -226,6 +226,11 @@ Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
const
int
warp_n
,
const
int
element_size
,
const
int
k_pack
,
bool
transposed
=
false
);
Fragment
makeGemmFragmentADCU
(
const
int
block_m
,
const
int
block_n
,
const
int
block_k
,
const
int
warp_m
,
const
int
warp_n
,
const
int
element_size
,
const
int
k_pack
,
bool
transposed
=
false
);
// Default Memory Layout
Layout
makeGemmLayoutLinear
(
int
stride
,
int
continuous
);
Layout
makeGemmABLayoutPadded
(
int
stride
,
int
continuous
,
int
element_size
);
...
...
src/op/builtin.cc
View file @
3852d58b
...
...
@@ -216,6 +216,12 @@ TIR_DEFINE_TL_BUILTIN(tma_store_wait)
.
set_num_inputs
(
0
)
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
TIR_DEFINE_TL_BUILTIN
(
ds_read_vector
)
.
set_num_inputs
(
5
)
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
TIR_DEFINE_TL_BUILTIN
(
set_max_nreg
)
.
set_num_inputs
(
2
)
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
...
...
src/op/builtin.h
View file @
3852d58b
...
...
@@ -335,6 +335,16 @@ TVM_DLL const Op &tma_store_arrive();
*/
TVM_DLL
const
Op
&
tma_store_wait
();
/*!
* \brief DS read from shared memory to register
*
* ds_read_vector(dst, lds_base_ptr, m, n, offset)
*
 * This is a tilelang intrinsic for the DCU ds_read hardware instruction.
 * The generated code calls tl::ds_read_vector.
*/
TVM_DLL
const
Op
&
ds_read_vector
();
/*!
* \brief Set reg hint for warp-specialized branched
*
...
...
src/op/gemm.cc
View file @
3852d58b
...
...
@@ -787,27 +787,51 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
results
.
Set
(
c_
,
fragment
->
BindThreadRange
(
thread_range
));
}
if
(
a_
.
scope
()
==
"shared"
||
a_
.
scope
()
==
"shared.dyn"
)
{
LOG
(
INFO
)
<<
"Using CDNA shared memory layout for A"
;
int
dim_A
=
a_
->
shape
.
size
();
if
(
TargetIsDCU
(
T
.
target
))
{
auto
shared_layout
=
makeGemmLayoutLinear
(
*
as_const_int
(
a_
->
shape
[
dim_A
-
2
]),
*
as_const_int
(
a_
->
shape
[
dim_A
-
1
]));
results
.
Set
(
a_
,
shared_layout
);
}
else
{
auto
shared_layout
=
makeGemmABLayoutCDNA
(
*
as_const_int
(
a_
->
shape
[
dim_A
-
2
]),
*
as_const_int
(
a_
->
shape
[
dim_A
-
1
]),
a_
->
dtype
.
bits
(),
kPack_
);
results
.
Set
(
a_
,
shared_layout
);
}
}
else
if
(
a_
.
scope
()
==
"local.fragment"
)
{
LOG
(
INFO
)
<<
"Using CDNA local fragment layout for A"
;
if
(
TargetIsDCU
){
auto
fragment
=
makeGemmFragmentADCU
(
m_
,
n_
,
k_
,
m_
/
warp_m
,
n_
/
warp_n
,
a_
->
dtype
.
bits
(),
kPack_
,
transA_
);
results
.
Set
(
a_
,
fragment
->
BindThreadRange
(
thread_range
));
}
else
{
auto
fragment
=
makeGemmFragmentACDNA
(
m_
,
n_
,
k_
,
m_
/
warp_m
,
n_
/
warp_n
,
a_
->
dtype
.
bits
(),
kPack_
,
transA_
);
results
.
Set
(
a_
,
fragment
->
BindThreadRange
(
thread_range
));
}
}
else
{
ICHECK
(
0
);
}
if
(
b_
.
scope
()
==
"shared"
||
b_
.
scope
()
==
"shared.dyn"
)
{
LOG
(
INFO
)
<<
"Using CDNA shared memory layout for B"
;
int
dim_B
=
b_
->
shape
.
size
();
if
(
TargetIsDCU
(
T
.
target
))
{
auto
shared_layout
=
makeGemmLayoutLinear
(
*
as_const_int
(
b_
->
shape
[
dim_B
-
2
]),
*
as_const_int
(
b_
->
shape
[
dim_B
-
1
]));
results
.
Set
(
b_
,
shared_layout
);
}
else
{
auto
shared_layout
=
makeGemmABLayoutCDNA
(
*
as_const_int
(
b_
->
shape
[
dim_B
-
2
]),
*
as_const_int
(
b_
->
shape
[
dim_B
-
1
]),
b_
->
dtype
.
bits
(),
kPack_
);
results
.
Set
(
b_
,
shared_layout
);
}
}
else
if
(
b_
.
scope
()
==
"local.fragment"
)
{
LOG
(
INFO
)
<<
"Using CDNA local fragment layout for B"
;
auto
fragment
=
makeGemmFragmentB
(
m_
,
n_
,
k_
,
m_
/
warp_m
,
n_
/
warp_n
,
transB_
);
results
.
Set
(
b_
,
fragment
->
BindThreadRange
(
thread_range
));
...
...
src/target/codegen_hip.cc
View file @
3852d58b
...
...
@@ -759,6 +759,7 @@ std::string CodeGenTileLangHIP::GetBufferRef(DataType t,
void
CodeGenTileLangHIP
::
VisitExpr_
(
const
CallNode
*
op
,
std
::
ostream
&
os
)
{
auto
print_extern_call_stmt
=
[
&
](
std
::
string
name
,
size_t
offset
=
0
)
{
printf
(
"[DEBUG VisitExpr_] Branch: print_extern_call_stmt -> %s
\n
"
,
name
.
c_str
());
this
->
PrintIndent
();
this
->
stream
<<
name
<<
"("
;
for
(
size_t
i
=
offset
;
i
<
op
->
args
.
size
();
i
++
)
{
...
...
@@ -768,7 +769,9 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
}
this
->
stream
<<
");
\n
"
;
};
if
(
op
->
op
.
same_as
(
builtin
::
ptx_cp_async
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: ptx_cp_async
\n
"
);
std
::
string
dst
=
this
->
PrintExpr
(
op
->
args
[
0
]);
std
::
string
dst_offset
=
this
->
PrintExpr
(
op
->
args
[
1
]);
std
::
string
src
=
this
->
PrintExpr
(
op
->
args
[
2
]);
...
...
@@ -788,48 +791,75 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
<<
", "
<<
condition
<<
");
\n
"
;
}
}
else
if
(
op
->
op
.
same_as
(
builtin
::
ptx_commit_group
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: ptx_commit_group
\n
"
);
print_extern_call_stmt
(
"tl::cp_async_commit"
);
}
else
if
(
op
->
op
.
same_as
(
builtin
::
ptx_wait_group
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: ptx_wait_group
\n
"
);
int
n
=
Downcast
<
IntImm
>
(
op
->
args
[
0
])
->
value
;
std
::
string
func_name
=
"tl::cp_async_wait<"
+
std
::
to_string
(
n
)
+
">"
;
print_extern_call_stmt
(
func_name
,
1
);
}
else
if
(
op
->
op
.
same_as
(
builtin
::
create_barriers
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: create_barriers
\n
"
);
this
->
PrintIndent
();
int
barrier_count
=
Downcast
<
IntImm
>
(
op
->
args
[
0
])
->
value
;
std
::
string
barrier_name
=
"_mbarrier"
;
this
->
stream
<<
"__shared__ uint64_t "
<<
barrier_name
<<
"["
<<
barrier_count
<<
"];
\n
"
;
}
else
if
(
op
->
op
.
same_as
(
tl
::
get_mbarrier
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: get_mbarrier
\n
"
);
std
::
string
barrier_name
=
"_mbarrier"
;
std
::
string
barrier_id
=
this
->
PrintExpr
(
op
->
args
[
0
]);
os
<<
barrier_name
+
"["
+
barrier_id
+
"]"
;
}
else
if
(
op
->
op
.
same_as
(
builtin
::
ptx_arrive_barrier
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: ptx_arrive_barrier
\n
"
);
print_extern_call_stmt
(
"tl::mbarrier_arrive"
);
}
else
if
(
op
->
op
.
same_as
(
builtin
::
ptx_init_barrier_thread_count
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: ptx_init_barrier_thread_count
\n
"
);
print_extern_call_stmt
(
"tl::mbarrier_init"
);
}
else
if
(
op
->
op
.
same_as
(
builtin
::
ptx_arrive_barrier_expect_tx
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: ptx_arrive_barrier_expect_tx
\n
"
);
print_extern_call_stmt
(
"tl::mbarrier_arrive_expect_tx"
);
}
else
if
(
op
->
op
.
same_as
(
builtin
::
ptx_cp_async_barrier
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: ptx_cp_async_barrier
\n
"
);
print_extern_call_stmt
(
"tl::mbarrier_cp_async_arrive"
);
}
else
if
(
op
->
op
.
same_as
(
tl
::
mbarrier_expect_tx
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: mbarrier_expect_tx
\n
"
);
print_extern_call_stmt
(
"tl::mbarrier_expect_tx"
);
}
else
if
(
op
->
op
.
same_as
(
tl
::
mbarrier_wait_parity
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: mbarrier_wait_parity
\n
"
);
print_extern_call_stmt
(
"tl::mbarrier_wait"
);
}
else
if
(
op
->
op
.
same_as
(
tl
::
ptx_stmatrix
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: ptx_stmatrix
\n
"
);
int
trans
=
Downcast
<
IntImm
>
(
op
->
args
[
0
])
->
value
;
int
num
=
Downcast
<
IntImm
>
(
op
->
args
[
1
])
->
value
;
std
::
string
func_name
=
"tl::ptx_stmatrix_x"
+
std
::
to_string
(
num
);
if
(
trans
==
1
)
func_name
+=
"_trans"
;
print_extern_call_stmt
(
func_name
,
2
);
}
else
if
(
op
->
op
.
same_as
(
tl
::
wait_wgmma
()))
{
}
else
if
(
op
->
op
.
same_as
(
tl
::
ds_read_vector
())){
//ds_read_b64 %1, %2 offset:%3
// ds_read_m32x16_b16 %0, %1 offset:%2
printf
(
"[DEBUG VisitExpr_] Branch: ds_read_vector
\n
"
);
std
::
string
dst
=
this
->
PrintExpr
(
op
->
args
[
0
]);
std
::
string
lds_base_ptr
=
this
->
PrintExpr
(
op
->
args
[
1
]);
std
::
string
m
=
this
->
PrintExpr
(
op
->
args
[
2
]);
std
::
string
n
=
this
->
PrintExpr
(
op
->
args
[
3
]);
std
::
string
offset
=
this
->
PrintExpr
(
op
->
args
[
4
]);
this
->
PrintIndent
();
this
->
stream
<<
"tl::ds_read_vector<"
<<
m
<<
", "
<<
n
<<
", "
<<
offset
<<
">"
<<
"(*reinterpret_cast<float4_*>("
<<
dst
<<
"), "
<<
"reinterpret_cast<uintptr_t>("
<<
lds_base_ptr
<<
"));
\n
"
;
}
else
if
(
op
->
op
.
same_as
(
tl
::
wait_wgmma
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: wait_wgmma
\n
"
);
this
->
PrintIndent
();
int
num_mma
=
Downcast
<
IntImm
>
(
op
->
args
[
0
])
->
value
;
this
->
stream
<<
"tl::wait_wgmma<"
<<
std
::
to_string
(
num_mma
)
<<
">();
\n
"
;
}
else
if
(
op
->
op
.
same_as
(
tl
::
pack_b16
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: pack_b16
\n
"
);
os
<<
"__pack_half2("
<<
this
->
PrintExpr
(
op
->
args
[
0
])
<<
", "
<<
this
->
PrintExpr
(
op
->
args
[
1
])
<<
")"
;
}
else
if
(
op
->
op
.
same_as
(
tl
::
__ldg
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: __ldg
\n
"
);
// HIP fallback: regular load
const
BufferLoadNode
*
bl
=
op
->
args
[
0
].
as
<
BufferLoadNode
>
();
ICHECK
(
bl
)
<<
"T.__ldg expects a BufferLoad as the first argument."
;
...
...
@@ -840,6 +870,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
auto
buffer_ref
=
this
->
GetBufferRef
(
op
->
dtype
,
buffer
,
base
);
os
<<
buffer_ref
;
}
else
if
(
op
->
op
.
same_as
(
builtin
::
tvm_fill_fragment
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: tvm_fill_fragment
\n
"
);
need_mma_h_
=
true
;
ICHECK_EQ
(
op
->
args
.
size
(),
6U
);
os
<<
"nvcuda::wmma::fill_fragment("
;
...
...
@@ -850,6 +881,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
this
->
PrintExpr
(
op
->
args
[
5
],
os
);
os
<<
")"
;
}
else
if
(
op
->
op
.
same_as
(
builtin
::
tvm_load_matrix_sync
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: tvm_load_matrix_sync
\n
"
);
need_mma_h_
=
true
;
ICHECK_EQ
(
op
->
args
.
size
(),
8U
);
os
<<
"nvcuda::wmma::load_matrix_sync("
;
...
...
@@ -862,6 +894,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
this
->
PrintExpr
(
op
->
args
[
6
],
os
);
os
<<
")"
;
}
else
if
(
op
->
op
.
same_as
(
builtin
::
tvm_store_matrix_sync
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: tvm_store_matrix_sync
\n
"
);
need_mma_h_
=
true
;
ICHECK_EQ
(
op
->
args
.
size
(),
8U
);
os
<<
"nvcuda::wmma::store_matrix_sync("
;
...
...
@@ -879,6 +912,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
}
os
<<
")"
;
}
else
if
(
op
->
op
.
same_as
(
builtin
::
tvm_mma_sync
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: tvm_mma_sync
\n
"
);
need_mma_h_
=
true
;
ICHECK_EQ
(
op
->
args
.
size
(),
8U
);
os
<<
"nvcuda::wmma::mma_sync("
;
...
...
@@ -889,6 +923,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
os
<<
"]"
<<
((
i
<
3
)
?
", "
:
")"
);
}
}
else
if
(
op
->
op
.
same_as
(
builtin
::
tvm_bmma_sync
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: tvm_bmma_sync
\n
"
);
need_mma_h_
=
true
;
ICHECK_EQ
(
op
->
args
.
size
(),
8U
);
os
<<
"nvcuda::wmma::bmma_sync("
;
...
...
@@ -899,6 +934,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
os
<<
"]"
<<
((
i
<
3
)
?
", "
:
")"
);
}
}
else
if
(
op
->
op
.
same_as
(
tl
::
tvm_mfma
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: tvm_mfma
\n
"
);
// arg 0: prefix: {otype}_{intrM}x{intrN}x{intrK}_{itype}
// arg 1: A layout: row/col
// arg 2: B layout: row/col
...
...
@@ -964,6 +1000,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
replacer
.
register_rule
(
"{c_bias}"
,
c_bias
);
os
<<
replacer
.
rewrite
(
call_mfma_code
);
}
else
if
(
op
->
op
.
same_as
(
tl
::
tvm_mmac
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: tvm_mmac
\n
"
);
// arg 0: prefix: {otype}_{intrM}x{intrN}x{intrK}_{itype}
// arg 1: A layout: row/col
// arg 2: B layout: row/col
...
...
@@ -1029,8 +1066,10 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
replacer
.
register_rule
(
"{c_bias}"
,
c_bias
);
os
<<
replacer
.
rewrite
(
call_mmac_code
);
}
else
if
(
op
->
op
.
same_as
(
builtin
::
thread_return
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: thread_return
\n
"
);
os
<<
"return"
;
}
else
if
(
op
->
op
.
same_as
(
tl
::
tl_gemm
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: tl_gemm
\n
"
);
ICHECK
(
op
->
args
.
size
()
==
4
)
<<
"tl_gemm expects 4 arguments <op_instance, "
"A_ptr, B_ptr, C_ptr>, but got "
<<
op
->
args
.
size
();
...
...
@@ -1038,15 +1077,19 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
this
->
PrintCallExtern
(
GetType
(
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
)),
op_instance
->
value
,
op
->
args
,
true
,
os
);
}
else
if
(
op
->
op
.
same_as
(
tl
::
tl_gemm_sp
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: tl_gemm_sp
\n
"
);
LOG
(
FATAL
)
<<
"tl_gemm_sp is not supported on HIP"
;
}
else
if
(
op
->
op
.
same_as
(
tl
::
loop_break
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: loop_break
\n
"
);
this
->
PrintIndent
();
this
->
stream
<<
"break;
\n
"
;
}
else
if
(
op
->
op
.
same_as
(
tl
::
no_set_max_nreg
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: no_set_max_nreg
\n
"
);
// HIP doesn't need explicit register management like CUDA
// This is a no-op for HIP
return
;
}
else
{
printf
(
"[DEBUG VisitExpr_] Branch: CodeGenC::VisitExpr_ (fallback)
\n
"
);
CodeGenC
::
VisitExpr_
(
op
,
os
);
}
}
...
...
src/target/utils.cc
View file @
3852d58b
...
...
@@ -96,6 +96,13 @@ bool TargetHasAsyncCopy(Target target) {
if
(
TargetIsCuda
(
target
))
{
int
arch
=
GetArchInt
(
target
);
return
arch
>=
80
;
}
else
if
(
TargetIsDCU
(
target
))
{
if
(
target
->
attrs
.
count
(
"mcpu"
))
{
std
::
string
mcpu
=
Downcast
<
tvm
::
ffi
::
String
>
(
target
->
attrs
.
at
(
"mcpu"
));
return
mcpu
.
find
(
"gfx936"
)
==
0
;
}
else
{
return
false
;
}
}
else
if
(
TargetIsCDNA
(
target
))
{
if
(
target
->
attrs
.
count
(
"mcpu"
))
{
std
::
string
mcpu
=
Downcast
<
tvm
::
ffi
::
String
>
(
target
->
attrs
.
at
(
"mcpu"
));
...
...
src/tl_templates/dcu_hip/common.h
View file @
3852d58b
...
...
@@ -138,3 +138,24 @@ template <typename T> TL_DEVICE void AtomicAddx4(T *ref, const T val[4]) {
atomicAdd
(
&
ref
[
2
],
val
[
2
]);
atomicAdd
(
&
ref
[
3
],
val
[
3
]);
}
// Clang extended vector types used as raw register carriers for ds_read:
// 4 and 2 packed floats (16 and 8 bytes respectively).
typedef float float4_ __attribute__((ext_vector_type(4)));
typedef float float2_ __attribute__((ext_vector_type(2)));

// Four packed half-precision values (16-byte float4_ reinterpreted as halves
// elsewhere; plain struct, no vector alignment guarantees).
struct half4 {
  __half x;
  __half y;
  __half z;
  __half w;
};

// Lets a 16-byte register block be addressed either as one float4_ or as
// two float2_ halves (front = low 8 bytes, rear = high 8 bytes) — used by
// the split ds_read_b64 pair in ds_read_vector.
// NOTE(review): relies on anonymous-struct union access; fine under HIP
// clang, but technically type punning.
union RegisterUnion {
  float4_ vector4;
  struct {
    float2_ vector_front;
    float2_ vector_rear;
  };
};
// NOTE(review): file is missing a trailing newline at EOF.
\ No newline at end of file
src/tl_templates/dcu_hip/copy.h
View file @
3852d58b
...
...
@@ -86,6 +86,36 @@ TL_DEVICE void cp_async_gs(void *lds_base_ptr, void *global_base_ptr) {
}
}
/*!
 * \brief Read a 16-byte fragment from LDS (shared memory) into \p dst using
 *        DCU ds_read instructions.
 *
 * \tparam M,N     Selects the instruction variant:
 *                 (16,32) -> single ds_read_m32x16_b16;
 *                 (32,16) -> two ds_read_b64 reads 4096 bytes apart.
 *                 Any other combination is silently a no-op — dst is left
 *                 unmodified.
 * \tparam offset  Element offset (in half_t units) folded into the
 *                 instruction's immediate offset field ("n" constraint, so
 *                 it must be a compile-time constant).
 * \param dst          Destination register block (float4_, 16 bytes).
 * \param lds_base_ptr LDS byte address of the source, passed as a 32-bit
 *                     value in a VGPR.
 */
template <int M, int N, int offset>
TL_DEVICE void ds_read_vector(float4_ &dst, uint32_t lds_base_ptr) {
  if constexpr (M == 16 && N == 32) {
    const int offset_in_bytes = offset * sizeof(half_t);
    // Single matrix-shaped 16-bit LDS read into the full 16-byte register.
    asm volatile("ds_read_m32x16_b16 %0, %1 offset:%2\n\t"
                 : "+v"(dst)
                 : "v"(lds_base_ptr), "n"(offset_in_bytes)
                 : "memory");
  } else if constexpr (M == 32 && N == 16) {
    const int offset_in_bytes0 = offset * sizeof(half_t);
    // NOTE(review): 4096 is a hard-coded second-half stride — presumably
    // half of an 8 KiB tile (e.g. 32x128 half elements); confirm against the
    // shared-buffer layout before reusing with other tile sizes.
    const int offset_in_bytes1 = offset_in_bytes0 + 4096;
    // View dst as two 8-byte halves so each ds_read_b64 fills one half.
    float2_ &front = *reinterpret_cast<float2_ *>(&dst);
    float2_ &rear = *(reinterpret_cast<float2_ *>(&dst) + 1);
    // %1 (front) gets the read at offset0, %0 (rear) the read at offset1.
    asm volatile("ds_read_b64 %1, %2 offset:%3\n\t"
                 "ds_read_b64 %0, %2 offset:%4\n\t"
                 : "+v"(rear), "+v"(front)
                 : "v"(lds_base_ptr), "n"(offset_in_bytes0),
                   "n"(offset_in_bytes1)
                 : "memory");
  }
}
template
<
int
N
>
TL_DEVICE
void
cp_async_gs_conditional
(
void
*
lds_base_ptr
,
void
*
global_base_ptr
,
bool
cond
)
{
...
...
@@ -107,4 +137,6 @@ TL_DEVICE void cp_async_gs_conditional(void *lds_base_ptr,
}
}
}
// namespace tl
src/tl_templates/dcu_hip/gemm.h
View file @
3852d58b
...
...
@@ -66,7 +66,8 @@ template <> struct MfmaTraits<fp8_e4_t> {
// ref to bitblas/tl/mfma_macro_generator.py::kPack
template
<
int
M
,
int
N
,
int
K
,
int
num_warp_n
,
int
num_warp_m
,
bool
TransposeA
,
bool
TransposeB
,
bool
clear_accum
,
int
kPack
,
typename
A_type
,
typename
B_type
,
typename
C_type
,
typename
AccDataType
=
float
>
typename
B_type
,
typename
C_type
,
typename
AccDataType
=
float
,
bool
use_swizzle
=
true
>
class
GemmTensorOp
{
public:
// static_assert(!clear_accum, "clear_accum=true is not supported yet");
...
...
@@ -147,13 +148,23 @@ public:
/*!
 * \brief Map a logical (row, col) shared-memory coordinate to a flattened
 *        element offset, optionally applying the MFMA bank-conflict swizzle.
 *
 * \tparam continuous   Elements per row in the flattened layout (default 32).
 * \tparam element_size Element size in bytes (default 2, i.e. half).
 * Behavior is selected by the class-level use_swizzle flag.
 */
template <int continuous = 32, int element_size = 2>
TL_DEVICE static constexpr auto make_swizzle_layout(const int row,
                                                    const int col) {
  if constexpr (use_swizzle) {
    // Swizzled path: permute (row, col) to avoid LDS bank conflicts, then
    // flatten row-major.
    auto [n_row, n_col] =
        make_mfma_swizzle_layout<continuous, element_size>(row, col);
    return n_row * continuous + n_col;
  } else {
    // Non-swizzled path: plain linear (padded) layout, flattened row-major.
    // NOTE(review): make_layout_padded is evaluated twice with identical
    // arguments; presumably constexpr-foldable, but could be hoisted.
    return make_layout_padded<continuous, element_size>(row, col).second +
           make_layout_padded<continuous, element_size>(row, col).first *
               continuous;
  }
}
static
TL_DEVICE
void
body
(
A_type
*
A_shared
,
B_type
*
B_shared
,
C_type
*
C_local
)
{
printf
(
"Executing GemmTensorOp dcu_hip body
\n
"
);
auto
tid
=
threadIdx
.
x
;
auto
warp_id
=
tid
/
warp_size
;
auto
warp_m
=
warp_id
/
block_col_warps
;
...
...
@@ -178,6 +189,10 @@ public:
B_type
B_local
[
warp_rows
*
kPack
*
local_size_b
];
A_type
A_local
[
warp_cols
*
kPack
*
local_size_a
];
// Get base pointers as byte pointers for ds_read
const
char
*
B_shared_bytes
=
reinterpret_cast
<
const
char
*>
(
B_shared
);
const
char
*
A_shared_bytes
=
reinterpret_cast
<
const
char
*>
(
A_shared
);
for
(
int
ki
=
0
;
ki
<
inner_k
;
ki
++
)
{
// Fetch B into register
for
(
int
i
=
0
;
i
<
warp_rows
;
i
++
)
{
...
...
@@ -257,6 +272,7 @@ public:
B_type
B_local
[
warp_rows
*
kPack
*
local_size_b
];
for
(
int
ki
=
0
;
ki
<
inner_k
;
ki
++
)
{
// Fetch B into register
for
(
int
i
=
0
;
i
<
warp_rows
;
i
++
)
{
...
...
@@ -302,21 +318,21 @@ namespace tl {
template
<
int
M
,
int
N
,
int
K
,
int
num_warp_m
,
int
num_warp_n
,
bool
trans_A
,
bool
trans_B
,
bool
clear_accum
,
int
kPack
,
typename
A_type
,
typename
B_type
,
typename
C_type
>
typename
B_type
,
typename
C_type
,
bool
use_swizzle
=
false
>
TL_DEVICE
void
gemm_ss
(
A_type
*
pA
,
B_type
*
pB
,
C_type
*
accum
)
{
using
Compute
=
GemmTensorOp
<
M
,
N
,
K
,
num_warp_m
,
num_warp_n
,
trans_A
,
trans_B
,
clear_accum
,
kPack
,
A_type
,
B_type
,
C_type
>
;
clear_accum
,
kPack
,
A_type
,
B_type
,
C_type
,
float
,
use_swizzle
>
;
Compute
::
body
(
pA
,
pB
,
accum
);
}
template
<
int
M
,
int
N
,
int
K
,
int
num_warp_m
,
int
num_warp_n
,
bool
trans_A
,
bool
trans_B
,
bool
clear_accum
,
int
kPack
,
typename
A_type
,
typename
B_type
,
typename
C_type
>
typename
B_type
,
typename
C_type
,
bool
use_swizzle
=
false
>
TL_DEVICE
void
gemm_rs
(
A_type
*
pA
,
B_type
*
pB
,
C_type
*
accum
)
{
using
Compute
=
GemmTensorOp
<
M
,
N
,
K
,
num_warp_m
,
num_warp_n
,
trans_A
,
trans_B
,
clear_accum
,
kPack
,
A_type
,
B_type
,
C_type
>
;
clear_accum
,
kPack
,
A_type
,
B_type
,
C_type
,
float
,
use_swizzle
>
;
Compute
::
body_rs
(
pA
,
pB
,
accum
);
}
...
...
src/transform/inject_ds_read.cc
0 → 100644
View file @
3852d58b
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief Replace shared memory BufferLoad with ds_read hardware instructions
* \file inject_ds_read.cc
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
#include "tir/ir/buffer_common.h"
#include "tvm/tir/stmt.h"
namespace
tvm
{
namespace
tl
{
using
namespace
tir
;
/*!
 * \brief Check if the module is compiled for an AMD DCU target.
 *
 * Scans the module's PrimFuncs for a "target" attribute carrying an "mcpu"
 * value; the first such value found decides the answer: true iff it starts
 * with "gfx936". Returns false when no function exposes an mcpu attribute.
 */
bool IsDCUTarget(const IRModule &module) {
  for (const auto &kv : module->functions) {
    const auto *func = kv.second.as<PrimFuncNode>();
    if (func == nullptr) {
      continue;
    }
    auto maybe_target = func->GetAttr<Target>("target");
    if (!maybe_target) {
      continue;
    }
    Target target = maybe_target.value();
    if (!target->attrs.count("mcpu")) {
      continue;
    }
    std::string mcpu = Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
    // An mcpu beginning with "gfx936" identifies DCU hardware.
    return mcpu.rfind("gfx936", 0) == 0;
  }
  return false;
}
class
DSReadInjector
:
public
StmtMutator
{
public:
  // Entry point of the mutation: rewrites eligible
  // "local <- shared" BufferStore statements into ds_read intrinsic calls,
  // leaving everything else untouched.
  Stmt VisitStmt_(const BufferStoreNode *store) final {
    // Only stores into registers (local / local.fragment) are candidates.
    bool is_local = store->buffer.scope() == "local" ||
                    store->buffer.scope() == "local.fragment";
    if (!is_local) {
      return StmtMutator::VisitStmt_(store);
    }
    // The stored value must itself be a plain load from shared memory.
    if (auto *load = store->value.as<BufferLoadNode>()) {
      bool is_shared_load = load->buffer.scope() == "shared" ||
                            load->buffer.scope() == "shared.dyn";
      if (!is_shared_load) {
        return StmtMutator::VisitStmt_(store);
      }
      // Skip if indices are vectorized (contain Ramp expressions):
      // ds_read is a scalar instruction and cannot handle vectorized indices.
      if (HasVectorizedIndices(store->indices) ||
          HasVectorizedIndices(load->indices)) {
        return StmtMutator::VisitStmt_(store);
      }
      // Check if the destination buffer is large enough for ds_read_vector.
      // ds_read_vector with half_t reads at least 16 bytes; for smaller 1-D
      // buffers, skip this transformation. (Multi-dim or symbolic shapes
      // fall through without the size check.)
      if (store->buffer.defined()) {
        const auto &buffer_shape = store->buffer->shape;
        if (buffer_shape.size() == 1) {
          if (auto *int_shape = buffer_shape[0].as<IntImmNode>()) {
            int extent = int_shape->value;
            int dtype_bytes = load->dtype.bytes();
            // Reject buffers smaller than one 16-byte ds_read_vector.
            if (extent * dtype_bytes < 16) {
              return StmtMutator::VisitStmt_(store);
            }
          }
        }
      }
      // Eligible: analyze the load pattern and emit the ds_read replacement.
      return InjectDSRead(store, load);
    }
    return StmtMutator::VisitStmt_(store);
  }
private:
// PrimExpr VisitExpr_(const CallNode *op) {
// Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
// if (call->op.same_as(builtin::tvm_access_ptr())) {
// return RewriteBufferAccess(call, {1});
// }
// return call;
// }
/*!
* \brief Check if any index expression contains a Ramp (vectorized) expression
*/
bool
HasVectorizedIndices
(
const
Array
<
PrimExpr
>&
indices
)
{
for
(
const
auto
&
idx
:
indices
)
{
if
(
idx
.
as
<
RampNode
>
())
{
return
true
;
}
}
return
false
;
}
Stmt
InjectDSRead
(
const
BufferStoreNode
*
store
,
const
BufferLoadNode
*
load
)
{
const
Buffer
&
shared_buf
=
load
->
buffer
;
const
Buffer
&
local_buf
=
store
->
buffer
;
// Analyze indices to determine the byte offset
// PrimExpr offset = load->indices.size() > 0 ? load->indices[0] : make_zero(DataType::UInt(0));
// Calculate buffer size in bytes
int
buffer_bytes
=
0
;
if
(
local_buf
.
defined
()
&&
local_buf
->
shape
.
size
()
==
1
)
{
if
(
auto
*
int_shape
=
local_buf
->
shape
[
0
].
as
<
IntImmNode
>
())
{
int
num_elements
=
int_shape
->
value
;
int
dtype_bytes
=
local_buf
->
dtype
.
bytes
();
buffer_bytes
=
num_elements
*
dtype_bytes
;
}
}
// Determine which ds_read to use based on buffer size
// ds_read_b64 loads 8 bytes (64 bits) = 1 element for half_t, 2 for float32
// ds_read_m32x16_b16 loads 32 bytes (256 bits)
int
dtype_bits
=
local_buf
->
dtype
.
bits
();
int
m
=
16
;
// For buffer < 16 bytes, use single ds_read_b64 (M=32, N=1)
// For buffer >= 16 bytes, use double ds_read_b64 (M=32, N=16)
// ds_read_b64 reads 8 bytes per call
int
n
=
(
buffer_bytes
>=
32
)
?
32
:
16
;
int
offset
=
0
;
return
EmitDSRead
(
local_buf
,
shared_buf
,
m
,
n
,
offset
);
}
Stmt
EmitDSRead
(
const
Buffer
&
local_buf
,
const
Buffer
&
shared_buf
,
int
m
,
int
n
,
int
offset
)
{
// ds_read_vector takes: (dst, shared_ptr, m, n, offset)
Array
<
PrimExpr
>
args
=
{
local_buf
->
data
,
// dst: local buffer data pointer
shared_buf
.
access_ptr
(
0
,
DataType
::
Handle
(),
1
,
0
),
// src: shared buffer data pointer
make_const
(
DataType
::
Int
(
32
),
m
),
make_const
(
DataType
::
Int
(
32
),
n
),
make_const
(
DataType
::
Int
(
32
),
offset
)
// byte_offset: offset into shared memory
};
Stmt
ds_read_stmt
=
Evaluate
(
Call
(
DataType
::
Handle
(),
ds_read_vector
(),
args
));
return
ds_read_stmt
;
}
};
using namespace tir::transform;

// Pass entry point: on AMD DCU targets, rewrite shared->register copies in
// each PrimFunc body into tl.ds_read_vector intrinsic calls.  Modules that
// are not DCU-targeted pass through unchanged.
tvm::transform::Pass InjectDSRead() {
  auto pass_func = [](PrimFunc f, const IRModule &mod, const PassContext &ctx) {
    // The rewrite only makes sense for DCU hardware.
    if (!IsDCUTarget(mod)) {
      return f;
    }
    auto *fn = f.CopyOnWrite();
    fn->body = DSReadInjector()(fn->body);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.InjectDSRead", {});
}
// Register the pass with the FFI so Python can invoke it as
// tilelang.transform.InjectDSRead().
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("tl.transform.InjectDSRead",
                                        InjectDSRead);
}
}
// namespace tl
}
// namespace tvm
src/transform/inject_pipeline.cc
View file @
3852d58b
...
...
@@ -1057,6 +1057,9 @@ private:
int
stage
=
static_cast
<
int
>
(
pipeline_stages
[
i
]
->
value
);
bool
is_async
=
pipeline_async_stages
.
find
(
stage
)
!=
pipeline_async_stages
.
end
();
printf
(
"Block %s assigned to stage %d with order %d%s
\n
"
,
original_order
[
i
]
->
name_hint
.
c_str
(),
stage
,
static_cast
<
int
>
(
pipeline_orders
[
i
]
->
value
),
is_async
?
" (async)"
:
" sync"
);
PipelineAnnotation
stage_order
{
stage
,
/*order=*/
static_cast
<
int
>
(
pipeline_orders
[
i
]
->
value
),
is_async
,
...
...
tilelang/engine/phase.py
View file @
3852d58b
...
...
@@ -262,6 +262,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# Inject PTX async copy must behind the thread sync pass
# as ptx async copy won't be recognized as a valid buffer load
mod
=
tilelang
.
transform
.
InjectPTXAsyncCopy
()(
mod
)
# Inject ds_read for shared to register memory copy on DCU
mod
=
tilelang
.
transform
.
InjectDSRead
()(
mod
)
if
allow_tma_and_warp_specialized
(
pass_ctx
=
pass_ctx
,
target
=
target
):
mod
=
tilelang
.
transform
.
AnnotateWarpGroupRegAlloc
()(
mod
)
mod
=
tilelang
.
transform
.
MakePackedAPI
()(
mod
)
...
...
tilelang/env.py
View file @
3852d58b
...
...
@@ -237,7 +237,8 @@ class Environment:
# Kernel selection options
# Default to GEMM v2; set to "1"/"true"/"yes"/"on" to force v1
TILELANG_USE_GEMM_V1
=
EnvVar
(
"TILELANG_USE_GEMM_V1"
,
"1"
)
# TILELANG_USE_GEMM_V1 = EnvVar("TILELANG_USE_GEMM_V1", "1")
TILELANG_USE_GEMM_V1
=
EnvVar
(
"TILELANG_USE_GEMM_V1"
,
"0"
)
# Auto-tuning settings
TILELANG_AUTO_TUNING_DISABLE_CACHE
=
EnvVar
(
"TILELANG_AUTO_TUNING_DISABLE_CACHE"
,
"0"
)
...
...
tilelang/intrinsics/mfma_macro_generator.py
View file @
3852d58b
...
...
@@ -769,18 +769,14 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
for
i
in
T
.
serial
(
warp_rows
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
warp_m
*
warp_rows
+
i
,
)
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_shared_buf
[
l
,
r
,
row
,
col
]
l
,
r
=
(
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
)
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
else
:
print
(
self
.
a_preshuffle
)
for
i
in
T
.
serial
(
warp_rows
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
warp_m
*
warp_row
s
+
i
,
rk
*
(
chunk
//
micro_size_k
)
+
ki
)
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_
shared_buf
[
l
,
r
,
row
,
col
]
l
,
r
=
(
warp_m
*
warp_row
_tiles
+
i
*
micro_size_x
,
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
))
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_
buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
return
(
_warp_ldmatrix_a_global
(
A_local_buf
,
A_buf
,
ki
,
thread_binding
,
rk
)
...
...
@@ -845,19 +841,19 @@ class MatrixCorePreshuffleIntrinEmitter(MatrixCoreIntrinEmitter):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
warp_n
*
warp_col
s
+
j
,
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
warp_n
*
warp_col
_tiles
+
j
*
micro_size_y
,
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
)
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_
shared_buf
[
l
,
r
,
row
,
col
]
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_
buf
[
B_base0
+
l
+
row
,
B_base1
+
r
+
col
]
else
:
for
j
in
T
.
serial
(
warp_cols
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_b
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
rk
*
(
chunk
//
micro_size_k
)
+
ki
,
warp_n
*
warp_col
s
+
j
,
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_n
*
warp_col
_tiles
+
j
*
micro_size_y
,
)
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_
shared_buf
[
l
,
r
,
row
,
col
]
B_local_buf
[
j
*
k_pack
*
local_size_b
+
local_id
]
=
B_
buf
[
B_base0
+
l
+
row
,
B_base1
+
r
+
col
]
return
(
_warp_ldmatrix_b_global
(
B_local_buf
,
B_buf
,
ki
,
thread_binding
,
rk
)
...
...
tilelang/language/builtin.py
View file @
3852d58b
...
...
@@ -89,6 +89,44 @@ def __ldg(load_or_buf: BufferLoad | tir.Buffer, index: PrimExpr | int | None = N
raise
TypeError
(
"T.__ldg expects a BufferLoad or a Buffer."
)
def ds_read_vector(dst: tir.Var, shared_ptr: tir.Var, m: int, n: int, offset: int) -> Call:
    """Emit a ``tl.ds_read_vector`` intrinsic call (AMD DCU shared-memory load).

    The intrinsic lowers to a vectorized LDS read on AMD DCU hardware
    (presumably ds_read_b64 / ds_read_m32x16_b16 depending on the chosen
    M/N parameters -- confirm against the codegen template, whose signature
    is ``ds_read_vector<M, N, offset>(dst, lds_base_ptr)``).

    Args:
        dst: Destination pointer (register/local buffer data var).
        shared_ptr: Source pointer into shared memory (LDS base).
        m: M template parameter forwarded to the intrinsic.
        n: N template parameter forwarded to the intrinsic.
        offset: Byte offset into shared memory.

    Returns:
        Call: A TIR call to the ``tl.ds_read_vector`` intrinsic.
    """
    return tir.call_intrin("handle", tir.op.Op.get("tl.ds_read_vector"), dst, shared_ptr,
                           m, n, offset)
def
get_mbarrier
(
*
args
):
"""Retrieve a memory barrier operation.
...
...
tilelang/language/gemm_op.py
View file @
3852d58b
...
...
@@ -139,6 +139,7 @@ def gemm_v1(
mbar
:
tir
.
Buffer
|
None
=
None
,
):
"""GEMM v1: use op tl.gemm."""
# print("Using GEMM v1")
return
_gemm_impl
(
"tl.tileop.gemm"
,
A
,
...
...
@@ -168,6 +169,7 @@ def gemm_v2(
mbar
:
tir
.
Buffer
|
None
=
None
,
):
"""GEMM v2: use op tl.gemm_py."""
print
(
"Using GEMM v2"
)
return
_gemm_impl
(
"tl.tileop.gemm_py"
,
A
,
...
...
tilelang/tileop/gemm/__init__.py
View file @
3852d58b
...
...
@@ -15,15 +15,17 @@ from .gemm_mmac import GemmMMAC
from
tilelang
import
_ffi_api
from
tilelang.utils.target
import
target_is_volta
print
(
"tileop gemm init..."
)
@
tvm_ffi
.
register_global_func
(
"tl.gemm_py.infer_layout"
)
def
gemm_py_infer_layout
(
gemm_py
:
GemmMMA
,
target
:
Target
,
thread_bounds
:
Range
):
print
(
"tileop gemm infer_layout"
)
thread_nums
=
thread_bounds
.
extent
return
gemm_py
.
infer_layout
(
target
,
thread_nums
)
@
tvm_ffi
.
register_global_func
(
"tl.gemm_py.lower"
)
def
gemm_py_lower
(
gemm_py
:
GemmMMA
,
layout_map
,
target
:
Target
,
thread_bounds
:
Range
,
thread_var
:
tir
.
Var
):
print
(
"tileop gemm lower"
)
thread_nums
=
thread_bounds
.
extent
stmt
=
gemm_py
.
lower
(
layout_map
,
target
,
thread_nums
,
thread_var
)
return
stmt
...
...
@@ -140,12 +142,14 @@ class GemmPy(Node, Scriptable):
def
infer_layout
(
self
,
target
:
Target
,
thread_nums
:
int
):
"""Infer the layout for the GEMM operation based on target architecture."""
print
(
f
"GemmPy infer_layout Target:
{
target
}
, thread_nums:
{
thread_nums
}
"
)
gemm_inst
=
self
.
_select_gemm_instruction
(
thread_nums
,
target
)
impl_class
=
self
.
_get_implementation_class
(
gemm_inst
,
target
)
return
impl_class
(
self
).
infer_layout
(
target
,
thread_nums
)
def
lower
(
self
,
layout_map
:
dict
,
target
:
Target
,
thread_nums
:
int
,
thread_var
:
tir
.
Var
):
"""Lower the GEMM operation to TIR statements based on target architecture."""
print
(
f
"GemmPy lower Target:
{
target
}
, thread_nums:
{
thread_nums
}
"
)
gemm_inst
=
self
.
_select_gemm_instruction
(
thread_nums
,
target
)
impl_class
=
self
.
_get_implementation_class
(
gemm_inst
,
target
)
return
impl_class
(
self
).
lower
(
layout_map
,
target
,
thread_nums
,
thread_var
)
...
...
@@ -181,6 +185,7 @@ class GemmPy(Node, Scriptable):
NotImplementedError: If the instruction type is not supported
ValueError: If the instruction type is unknown
"""
print
(
f
"_get_implementation_class Target:
{
target
}
"
)
if
gemm_inst
.
is_mma
():
if
target_is_volta
(
target
):
return
GemmMMASm70
...
...
tilelang/tileop/gemm/gemm_mfma.py
View file @
3852d58b
...
...
@@ -31,24 +31,28 @@ class GemmMFMA(GemmBase):
)
if
self
.
is_gemm_ss
():
print
(
"gemm_ss"
)
return
{
self
.
A
:
make_swizzled_layout
(
self
.
A
),
self
.
B
:
make_swizzled_layout
(
self
.
B
),
self
.
C
:
mfma_emitter
.
make_mfma_store_layout
(
self
.
C
),
}
elif
self
.
is_gemm_sr
():
print
(
"gemm_sr"
)
return
{
self
.
A
:
make_swizzled_layout
(
self
.
A
),
self
.
B
:
mfma_emitter
.
make_mfma_load_layout
(
self
.
B
,
matrix
=
"B"
),
self
.
C
:
mfma_emitter
.
make_mfma_store_layout
(
self
.
C
),
}
elif
self
.
is_gemm_rs
():
print
(
"gemm_rs"
)
return
{
self
.
A
:
mfma_emitter
.
make_mfma_load_layout
(
self
.
A
,
matrix
=
"A"
),
self
.
B
:
make_swizzled_layout
(
self
.
B
),
self
.
C
:
mfma_emitter
.
make_mfma_store_layout
(
self
.
C
),
}
elif
self
.
is_gemm_rr
():
print
(
"gemm_rr"
)
return
{
self
.
A
:
mfma_emitter
.
make_mfma_load_layout
(
self
.
A
,
matrix
=
"A"
),
self
.
B
:
mfma_emitter
.
make_mfma_load_layout
(
self
.
B
,
matrix
=
"B"
),
...
...
@@ -101,6 +105,7 @@ class GemmMFMA(GemmBase):
assert
is_full_region
(
C_region
),
"Fragment output C must be a full region"
if
self
.
is_gemm_ss
():
print
(
"lower is_gemm_ss"
)
@
T
.
prim_func
def
_gemm_ssr
()
->
None
:
...
...
@@ -136,6 +141,7 @@ class GemmMFMA(GemmBase):
return
_Simplify
(
_gemm_ssr
,
inline_let
=
True
)
elif
self
.
is_gemm_sr
():
assert
is_full_region
(
B_region
),
"Fragment input B must be a full region"
print
(
"lower is_gemm_sr"
)
@
T
.
prim_func
def
_gemm_srr
()
->
None
:
...
...
@@ -167,6 +173,7 @@ class GemmMFMA(GemmBase):
return
_Simplify
(
_gemm_srr
,
inline_let
=
True
)
elif
self
.
is_gemm_rs
():
assert
is_full_region
(
A_region
),
"Fragment input A must be a full region"
print
(
"lower is_gemm_rs"
)
@
T
.
prim_func
def
_gemm_rsr
()
->
None
:
...
...
@@ -195,6 +202,7 @@ class GemmMFMA(GemmBase):
elif
self
.
is_gemm_rr
():
assert
is_full_region
(
A_region
),
"Fragment input A must be a full region"
assert
is_full_region
(
B_region
),
"Fragment input B must be a full region"
print
(
"lower is_gemm_rr"
)
@
T
.
prim_func
def
_gemm_rsr
()
->
None
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment