Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
32d0b3cb
Commit
32d0b3cb
authored
Apr 21, 2026
by
qisan
Browse files
Feats: support async_copy pass!
parent
a0ec0f57
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
718 additions
and
23 deletions
+718
-23
examples/gemm/example_gemm.py
examples/gemm/example_gemm.py
+3
-3
src/op/builtin.cc
src/op/builtin.cc
+11
-0
src/op/copy.cc
src/op/copy.cc
+2
-1
src/op/gemm_py.cc
src/op/gemm_py.cc
+3
-1
src/op/parallel.cc
src/op/parallel.cc
+3
-2
src/target/codegen_hip.cc
src/target/codegen_hip.cc
+55
-2
src/tl_templates/dcu_hip/copy.h
src/tl_templates/dcu_hip/copy.h
+100
-10
src/transform/lower_dcu_resource.cc
src/transform/lower_dcu_resource.cc
+369
-0
src/transform/vectorize_dcu_async_copy.cc
src/transform/vectorize_dcu_async_copy.cc
+132
-0
tilelang/contrib/rocm.py
tilelang/contrib/rocm.py
+10
-0
tilelang/engine/lower.py
tilelang/engine/lower.py
+4
-2
tilelang/engine/phase.py
tilelang/engine/phase.py
+15
-1
tilelang/language/ast/ir.py
tilelang/language/ast/ir.py
+2
-0
tilelang/tileop/gemm/gemm_mmac.py
tilelang/tileop/gemm/gemm_mmac.py
+1
-1
tilelang/transform/__init__.py
tilelang/transform/__init__.py
+8
-0
No files found.
examples/gemm/example_gemm.py
View file @
32d0b3cb
...
...
@@ -10,13 +10,13 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.fl
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
128
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
T
.
clear
(
C_local
)
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
3
):
for
k
in
T
.
Pipelined
(
T
.
ceildiv
(
K
,
block_K
),
num_stages
=
0
):
T
.
copy
(
A
[
by
*
block_M
,
k
*
block_K
],
A_shared
)
T
.
copy
(
B
[
k
*
block_K
,
bx
*
block_N
],
B_shared
)
T
.
gemm
(
A_shared
,
B_shared
,
C_local
)
...
...
@@ -27,7 +27,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.fl
def
main
():
kernel
=
matmul
(
1024
,
1024
,
1024
,
128
,
128
,
32
)
kernel
=
matmul
(
1024
,
1024
,
1024
,
256
,
256
,
16
)
import
torch
...
...
src/op/builtin.cc
View file @
32d0b3cb
...
...
@@ -387,5 +387,16 @@ TIR_DEFINE_TL_BUILTIN(warp_reduce_bitor)
TIR_DEFINE_TL_BUILTIN
(
__ldg
).
set_num_inputs
(
-
1
).
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kPure
));
//
TIR_DEFINE_TL_BUILTIN
(
dcu_async_copy
)
.
set_num_inputs
(
6
)
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
TIR_DEFINE_TL_BUILTIN
(
make_dcu_resource
)
.
set_num_inputs
(
2
)
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
}
// namespace tl
}
// namespace tvm
src/op/copy.cc
View file @
32d0b3cb
...
...
@@ -888,7 +888,7 @@ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower,
*/
Stmt
CopyNode
::
Lower
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
{
Target
target
=
T
.
target
;
printf
(
"Lowering CopyNode with target: %s
\n
"
,
target
->
str
().
c_str
());
using
namespace
tvm
::
transform
;
PassContext
pass_ctx
=
PassContext
::
Current
();
bool
disable_tma_lower
=
...
...
@@ -940,6 +940,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
*/
Stmt
CopyNode
::
LowerNormalCopy
(
const
LowerArgs
&
T
,
arith
::
Analyzer
*
analyzer
)
const
{
printf
(
"Lowering normal copy for target: %s
\n
"
,
T
.
target
->
str
().
c_str
());
bool
is_cpu_target
=
T
.
target
->
GetTargetDeviceType
()
==
kDLCPU
;
auto
simt_loop
=
MakeSIMTLoop
(
analyzer
);
auto
fused_loop
=
Downcast
<
For
>
(
ParallelLoopFuser
::
Fuse
(
simt_loop
));
...
...
src/op/gemm_py.cc
View file @
32d0b3cb
...
...
@@ -290,6 +290,7 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
LayoutMap
results
;
if
(
const
auto
f
=
ffi
::
Function
::
GetGlobal
(
"tl.gemm_py.infer_layout"
))
{
printf
(
"GemmPyNode::InferLayout: calling tl.gemm_py.infer_layout
\n
"
);
results
=
Downcast
<
LayoutMap
>
(
(
*
f
)(
tvm
::
ffi
::
GetRef
<
GemmPy
>
(
this
),
T
.
target
,
T
.
thread_bounds
));
// Bind all fragment layouts with the provided thread range
...
...
@@ -303,7 +304,8 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
}
else
{
LOG
(
FATAL
)
<<
"No infer layout function found for gemm_py"
;
}
LOG
(
INFO
)
<<
"GemmPyNode::InferLayout results:"
;
LOG
(
INFO
)
<<
results
;
completed_
=
true
;
return
results
;
}
...
...
src/op/parallel.cc
View file @
32d0b3cb
...
...
@@ -242,7 +242,8 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel
level
)
const
{
if
(
loop_layout_
.
defined
())
return
{};
LOG
(
INFO
)
<<
"Inferring layout for T.Parallel loop with inference level "
<<
static_cast
<
int
>
(
level
)
<<
"...
\n
"
;
// Expand let bindings to find fragment buffer accesses
if
(
!
T
.
let_var_to_expr
.
empty
())
{
const_cast
<
ParallelOpNode
*>
(
this
)
->
ExpandLetBindings
(
T
.
let_var_to_expr
);
...
...
@@ -424,7 +425,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
LOG
(
FATAL
)
<<
msg
.
str
();
}
}
D
LOG
(
INFO
)
<<
"[compute_loop_layout_from_buffer] ... and get "
LOG
(
INFO
)
<<
"[compute_loop_layout_from_buffer] ... and get "
<<
result
->
DebugOutput
()
<<
'\n'
;
return
result
;
};
...
...
src/target/codegen_hip.cc
View file @
32d0b3cb
...
...
@@ -381,7 +381,9 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
case
32
:
{
if
(
t
.
is_scalar
())
{
os
<<
"int"
;
}
else
if
(
t
.
lanes
()
<=
4
)
{
}
else
if
(
t
.
lanes
()
==
4
)
{
os
<<
"int32x4_t"
;
}
else
if
(
t
.
lanes
()
<
4
)
{
os
<<
"int"
<<
t
.
lanes
();
}
else
if
(
t
.
lanes
()
<=
8
)
{
// Emit CUDA code to access int32 vector elements for 4 < lanes <= 8.
...
...
@@ -1088,7 +1090,58 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
// HIP doesn't need explicit register management like CUDA
// This is a no-op for HIP
return
;
}
else
{
}
else
if
(
op
->
op
.
same_as
(
Op
::
Get
(
"tl.make_dcu_resource"
)))
{
CHECK_EQ
(
op
->
args
.
size
(),
2
)
<<
"make_dcu_resource expects 2 arguments"
;
std
::
string
base_ptr
=
this
->
PrintExpr
(
op
->
args
[
0
]);
std
::
string
offset
;
if
(
const
RampNode
*
ramp
=
op
->
args
[
1
].
as
<
RampNode
>
())
{
offset
=
this
->
PrintExpr
(
ramp
->
base
);
}
else
{
offset
=
this
->
PrintExpr
(
op
->
args
[
1
]);
}
os
<<
"make_wave_buffer_resource("
<<
base_ptr
<<
" + ("
<<
offset
<<
"))"
;
}
else
if
(
op
->
op
.
same_as
(
Op
::
Get
(
"tl.dcu_async_copy"
)))
{
// 1. 提取模板参数 (IntImm 直接取值)
auto
get_int_const
=
[](
const
PrimExpr
&
e
)
->
int
{
if
(
const
auto
*
val
=
e
.
as
<
IntImmNode
>
())
return
static_cast
<
int
>
(
val
->
value
);
return
0
;
};
int
N
=
16
;
int
smem_offset
=
0
;
int
load_count
=
get_int_const
(
op
->
args
[
4
]);
int
i_sstride
=
get_int_const
(
op
->
args
[
5
]);
int
i_gstride
=
get_int_const
(
op
->
args
[
6
]);
int
k_gstride
=
get_int_const
(
op
->
args
[
7
]);
// 2. 将运行时参数打印到字符串中 (防止直接操作 stream 导致冲突)
std
::
string
dst_ptr
=
this
->
PrintExpr
(
op
->
args
[
0
]);
std
::
string
dst_off
=
this
->
PrintExpr
(
op
->
args
[
1
]);
std
::
string
src_res
=
this
->
PrintExpr
(
op
->
args
[
2
]);
std
::
string
src_off
=
this
->
PrintExpr
(
op
->
args
[
3
]);
// 3. 仿照范例进行流输出
this
->
PrintIndent
();
this
->
stream
<<
"cp_async_gs<"
<<
N
<<
", "
<<
smem_offset
<<
", "
<<
load_count
<<
", "
<<
i_sstride
<<
", "
<<
i_gstride
<<
", "
<<
k_gstride
<<
">("
;
// 拼接第一个参数:(char*)dst + dst_off
this
->
stream
<<
"((char*)"
<<
dst_ptr
<<
" + "
<<
dst_off
<<
"), "
;
// 拼接第二个参数:src_res
this
->
stream
<<
src_res
<<
", "
;
// 拼接第三个参数:src_off
this
->
stream
<<
src_off
<<
");
\n
"
;
}
else
{
printf
(
"[DEBUG VisitExpr_] Branch: CodeGenC::VisitExpr_ (fallback)
\n
"
);
CodeGenC
::
VisitExpr_
(
op
,
os
);
}
...
...
src/tl_templates/dcu_hip/copy.h
View file @
32d0b3cb
...
...
@@ -38,6 +38,20 @@ __device__ void inc_m0(uint32_t m0_inc) {
asm
volatile
(
"s_add_u32 m0, %0, m0"
:
:
"n"
(
m0_inc
)
:
"memory"
);
}
// Advance the 64-bit address stored in the first two 32-bit words of a
// buffer resource by `stride`, in place.
//
//   res    - lvalue exposing 32-bit integer members `.x` (low address word)
//            and `.y` (high address word); both are overwritten.
//   stride - integer byte offset added to the address. It is converted to
//            uint64_t by the compound assignment, so a negative value wraps
//            and effectively subtracts.
//
// NOTE(review): the carry out of the low word is propagated into `.y` as a
// plain 32-bit value; if `.y` also packs non-address bits, confirm that the
// address can never overflow into them.
#define UPDATE_WAVE_BUFFER_RESOURCE(res, stride)                               \
  do {                                                                         \
    /* 1. Rebuild the 64-bit address. The low word goes through uint32_t   */  \
    /*    so that it is zero-extended instead of sign-extended.            */  \
    uint64_t wave_res_addr_ = (static_cast<uint64_t>((res).y) << 32) |         \
                              (static_cast<uint32_t>((res).x));                \
    /* 2. Add the stride (usual arithmetic conversions apply).             */  \
    wave_res_addr_ += (stride);                                                \
    /* 3. Write both 32-bit halves back into the resource.                 */  \
    (res).x = static_cast<int32_t>(wave_res_addr_);                            \
    (res).y = static_cast<int32_t>(wave_res_addr_ >> 32);                      \
  } while (0)
namespace
tl
{
// AMDGPU automatically commit memory fence
...
...
@@ -72,20 +86,96 @@ CK_TILE_DEVICE void async_buffer_load_dword_v(void *smem, int32x4_t rsrc,
:
"memory"
);
}
template
<
int
N
>
TL_DEVICE
void
cp_async_gs
(
void
*
lds_base_ptr
,
void
*
global_base_ptr
)
{
template
<
int
N
,
int
smem_offset
,
int
load_count
,
int
i_sstride
,
int
i_gstride
,
int
k_gstride
>
TL_DEVICE
void
cp_async_gs
(
void
*
lds_base_ptr
,
int32x4_t
res
,
int
offset
)
{
if
constexpr
(
N
==
16
)
{
*
(
uint4
*
)
lds_base_ptr
=
*
(
uint4
*
)
global_base_ptr
;
}
else
if
constexpr
(
N
==
8
)
{
*
(
uint2
*
)
lds_base_ptr
=
*
(
uint2
*
)
global_base_ptr
;
}
else
if
constexpr
(
N
==
4
)
{
async_buffer_load_dword_v
(
lds_base_ptr
,
make_wave_buffer_resource
(((
int32_t
*
)
global_base_ptr
)
-
threadIdx
.
x
),
threadIdx
.
x
*
N
/*assume 4 bytes*/
);
if
constexpr
(
load_count
==
1
){
async_buffer_load_dwordx4_v
<
smem_offset
>
(
lds_base_ptr
,
res
,
offset
);
UPDATE_WAVE_BUFFER_RESOURCE
(
res
,
k_gstride
);
}
else
if
constexpr
(
load_count
==
2
){
async_buffer_load_dwordx4_v
<
smem_offset
>
(
lds_base_ptr
,
res
,
current_offset
);
UPDATE_WAVE_BUFFER_RESOURCE
(
res
,
i_gstride
);
async_buffer_load_dwordx4_v
<
smem_offset
+
i_sstride
>
(
lds_base_ptr
,
res
,
current_offset
);
UPDATE_WAVE_BUFFER_RESOURCE
(
res
,
k_gstride
-
i_gstride
);
}
else
if
constexpr
(
load_count
==
4
){
async_buffer_load_dwordx4_v
<
smem_offset
>
(
lds_base_ptr
,
res
,
current_offset
);
UPDATE_WAVE_BUFFER_RESOURCE
(
res
,
i_gstride
);
async_buffer_load_dwordx4_v
<
smem_offset
+
i_sstride
>
(
lds_base_ptr
,
res
,
current_offset
);
UPDATE_WAVE_BUFFER_RESOURCE
(
res
,
i_gstride
);
async_buffer_load_dwordx4_v
<
smem_offset
+
2
*
i_sstride
>
(
lds_base_ptr
,
res
,
current_offset
);
UPDATE_WAVE_BUFFER_RESOURCE
(
res
,
i_gstride
);
async_buffer_load_dwordx4_v
<
smem_offset
+
3
*
i_sstride
>
(
lds_base_ptr
,
res
,
current_offset
);
UPDATE_WAVE_BUFFER_RESOURCE
(
res
,
k_gstride
-
3
*
i_gstride
);
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
load_count
-
1
;
++
i
)
{
async_buffer_load_dwordx4_v
<
smem_offset
>
(
lds_base_ptr
+
i
*
i_sstride
,
res
,
current_offset
);
UPDATE_WAVE_BUFFER_RESOURCE
(
res
,
i_gstride
);
}
async_buffer_load_dwordx4_v
<
smem_offset
>
(
lds_base_ptr
+
(
load_count
-
1
)
*
i_sstride
,
res
,
current_offset
);
UPDATE_WAVE_BUFFER_RESOURCE
(
res
,
k_gstride
-
(
load_count
-
1
)
*
i_gstride
);
}
}
else
{
not
implemented
;
}
}
template
<
int
N
>
// TL_DEVICE void cp_async_gs(void *lds_base_ptr, void *global_base_ptr) {
// if constexpr (N == 16) {
// *(uint4 *)lds_base_ptr = *(uint4 *)global_base_ptr;
// } else if constexpr (N == 8) {
// *(uint2 *)lds_base_ptr = *(uint2 *)global_base_ptr;
// } else if constexpr (N == 4) {
// async_buffer_load_dword_v(
// lds_base_ptr,
// make_wave_buffer_resource(((int32_t *)global_base_ptr) - threadIdx.x),
// threadIdx.x * N /*assume 4 bytes*/);
// }
// }
template
<
int
M
,
int
N
,
int
offset
>
TL_DEVICE
void
ds_read_vector
(
float4_
&
dst
,
uint32_t
lds_base_ptr
)
{
...
...
src/transform/lower_dcu_resource.cc
0 → 100644
View file @
32d0b3cb
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/op.h>
#include <tvm/tir/builtin.h>
#include <tvm/ffi/cast.h>
#include <tvm/ffi/memory.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/arith/analyzer.h>
#include <vector>
#include <unordered_map>
#include <unordered_set>
using
tvm
::
ffi
::
GetRef
;
using
tvm
::
ffi
::
make_object
;
namespace
tvm
{
namespace
tl
{
using
namespace
tir
;
using
ffi
::
Array
;
using
ffi
::
String
;
// ============================================================================
// 数据结构
// ============================================================================
/// One candidate copy recorded by the collector: a BufferStore whose value is
/// a BufferLoad. Field order matters because instances are brace-initialized.
struct CopyInfo {
  Buffer dst_buffer;            ///< Buffer written by the store.
  Buffer src_buffer;            ///< Buffer read by the load on the right-hand side.
  Array<PrimExpr> dst_indices;  ///< Indices of the store.
  Array<PrimExpr> src_indices;  ///< Indices recorded for the load.
  Stmt store_stmt;              ///< The store itself; matched later by identity.
};
/// Everything the collection phase hands to the rewrite phases.
struct CollectResult {
  /// All copies found, in visiting order.
  std::vector<CopyInfo> copies;

  /// Global buffer name -> DCU resource variable (used when replacing stores).
  std::unordered_map<String, Var> global_to_res_var;

  /// Buffer name -> (variable, initial value) pair that is to be emitted as a
  /// LetStmt binding around the injection point.
  std::unordered_map<String, std::pair<Var, PrimExpr>> shared_alloc_to_binding;

  /// Statement that the LetStmt bindings should wrap; null until a copy is
  /// found.
  const StmtNode *inject_target = nullptr;
};
/// Expression mutator that replaces every variable contained in a given set
/// with a zero constant of the variable's dtype; all other variables are
/// returned unchanged.
///
/// The set is held by reference, so it must outlive the mutator.
class VariableEliminator : public tvm::tir::ExprMutator {
public:
  explicit VariableEliminator(
      const std::unordered_set<const tvm::tir::VarNode *> &vars)
      : vars_to_remove_(vars) {}

  PrimExpr VisitExpr_(const tvm::tir::VarNode *op) override {
    const bool listed = vars_to_remove_.find(op) != vars_to_remove_.end();
    if (!listed) {
      return GetRef<PrimExpr>(op);
    }
    return tvm::tir::make_zero(op->dtype);
  }

private:
  const std::unordered_set<const tvm::tir::VarNode *> &vars_to_remove_;
};
/// Expression mutator that keeps only the variables contained in a given set:
/// every variable *not* in the set is replaced with a zero constant of its
/// dtype. Sub-expressions (including the indices of nested buffer loads) are
/// handled by the default ExprMutator recursion.
///
/// The set is held by reference, so it must outlive the mutator.
class VariableKeeper : public tvm::tir::ExprMutator {
public:
  explicit VariableKeeper(
      const std::unordered_set<const tvm::tir::VarNode *> &keep_vars)
      : keep_vars_(keep_vars) {}

  PrimExpr VisitExpr_(const tvm::tir::VarNode *op) override {
    if (keep_vars_.count(op)) {
      return GetRef<PrimExpr>(op);
    }
    return tvm::tir::make_zero(op->dtype);
  }

private:
  const std::unordered_set<const tvm::tir::VarNode *> &keep_vars_;
};
// ============================================================================
// Phase 1: 收集拷贝信息 & 生成资源绑定
// ============================================================================
/// Phase 1: walk `body`, record every `shared[...] = global[...]` store and
/// prepare one resource binding per distinct global source buffer name.
///
/// For each matching store the collector records a CopyInfo whose source
/// indices keep only the enclosing loop / threadIdx variables (everything else
/// is zeroed and the result simplified). The first time a global buffer name
/// is seen it also builds a `tl.make_dcu_resource(data, base_indices...)` call
/// in which those same loop / threadIdx variables are zeroed instead.
///
/// \param body Function body to scan; it is not modified.
/// \return The copies, the name -> resource variable map, the bindings to
///         inject and the statement they should wrap.
CollectResult CollectResources(const Stmt &body) {
  class Collector : public StmtExprVisitor {
  public:
    CollectResult result;

  private:
    // Loop variables and threadIdx variables of the scopes currently open.
    std::unordered_set<const tvm::tir::VarNode *> loop_vars_;
    // Enclosing AttrStmt / SeqStmt / For nodes, outermost first. Other
    // statement kinds are not tracked.
    std::vector<const tvm::tir::StmtNode *> scope_stack_;

    bool IsSharedScope(const Buffer &buf) {
      auto s = buf.scope();
      return s == "shared" || s == "shared.dyn";
    }

    bool IsGlobalScope(const Buffer &buf) {
      auto s = buf.scope();
      return s == "global" || s == "";
    }

    void VisitStmt_(const AttrStmtNode *op) override {
      scope_stack_.push_back(op);

      // Only threadIdx.* bindings contribute a variable; blockIdx and every
      // other attribute are transparent.
      const tvm::tir::VarNode *thread_var = nullptr;
      if (op->attr_key == tvm::tir::attr::thread_extent) {
        if (const auto *iv = op->node.as<tvm::tir::IterVarNode>()) {
          const std::string &tag = iv->thread_tag;
          if (tag.find("threadIdx") != std::string::npos) {
            thread_var = iv->var.get();
          }
        }
      }

      // Remember whether this scope added the variable so that a nested
      // re-binding does not remove it on behalf of the outer scope.
      const bool inserted =
          thread_var != nullptr && loop_vars_.insert(thread_var).second;

      // Always descend: attributes other than thread_extent may still contain
      // copies in their body.
      StmtExprVisitor::VisitStmt_(op);

      if (inserted) {
        loop_vars_.erase(thread_var);
      }
      scope_stack_.pop_back();
    }

    void VisitStmt_(const SeqStmtNode *op) override {
      scope_stack_.push_back(op);
      StmtExprVisitor::VisitStmt_(op);
      scope_stack_.pop_back();
    }

    void VisitStmt_(const ForNode *op) override {
      scope_stack_.push_back(op);
      const tvm::tir::VarNode *loop_var = op->loop_var.get();
      const bool inserted = loop_vars_.insert(loop_var).second;
      StmtExprVisitor::VisitStmt_(op);
      if (inserted) {
        loop_vars_.erase(loop_var);
      }
      scope_stack_.pop_back();
    }

    void VisitStmt_(const BufferStoreNode *op) final {
      if (op->value.defined() && IsSharedScope(op->buffer)) {
        if (const auto *load = op->value.as<BufferLoadNode>()) {
          if (IsGlobalScope(load->buffer)) {
            RecordCopy(op, load);
          }
        }
      }
      StmtExprVisitor::VisitStmt_(op);
    }

    // Choose the statement the resource bindings should wrap:
    //   1. the tracked scope directly inside the innermost thread_extent;
    //   2. otherwise the outermost tracked For / SeqStmt;
    //   3. otherwise the store itself.
    const tvm::tir::StmtNode *FindInjectTarget(const BufferStoreNode *op) {
      for (size_t i = scope_stack_.size(); i-- > 0;) {
        if (!scope_stack_[i]->IsInstance<AttrStmtNode>()) {
          continue;
        }
        const auto *attr = static_cast<const AttrStmtNode *>(scope_stack_[i]);
        if (attr->attr_key == tvm::tir::attr::thread_extent) {
          if (i + 1 < scope_stack_.size()) {
            return scope_stack_[i + 1];
          }
          break;
        }
      }
      for (const auto *node : scope_stack_) {
        if (node->IsInstance<ForNode>() || node->IsInstance<SeqStmtNode>()) {
          return node;
        }
      }
      return op;
    }

    // Record one shared <- global copy and, for a global buffer name not seen
    // before, the resource variable and its initial value.
    void RecordCopy(const BufferStoreNode *op, const BufferLoadNode *load) {
      const Buffer &dst = op->buffer;
      const Buffer &src = load->buffer;

      // The injection point is fixed by the first copy encountered.
      if (result.inject_target == nullptr) {
        result.inject_target = FindInjectTarget(op);
      }

      tvm::arith::Analyzer analyzer;

      // Source indices reduced to their loop / threadIdx dependent part.
      VariableKeeper keeper(loop_vars_);
      Array<PrimExpr> loop_only_indices;
      for (const auto &idx : load->indices) {
        loop_only_indices.push_back(analyzer.Simplify(keeper(idx)));
      }
      result.copies.push_back(
          CopyInfo{dst, src, op->indices, loop_only_indices, GetRef<Stmt>(op)});

      if (result.global_to_res_var.find(src->name) !=
          result.global_to_res_var.end()) {
        return; // a binding for this global buffer name already exists
      }

      // Base address: the same indices with every loop / threadIdx variable
      // replaced by zero.
      // NOTE(review): this emits one call argument per index in addition to
      // the data pointer, while the op is registered with two inputs —
      // presumably the pass runs on already-flattened buffers; confirm.
      Var res_var(src->name + "_dcu_res", DataType::Int(32, 4));
      VariableEliminator eliminator(loop_vars_);
      Array<PrimExpr> call_args;
      call_args.push_back(src->data);
      for (const auto &idx : load->indices) {
        call_args.push_back(analyzer.Simplify(eliminator(idx)));
      }
      PrimExpr init_value = Call(DataType::Int(32, 4),
                                 Op::Get("tl.make_dcu_resource"), call_args);

      result.global_to_res_var[src->name] = res_var;
      // Keyed by the source buffer name, like global_to_res_var.
      result.shared_alloc_to_binding[src->name] = {res_var, init_value};
    }
  };

  Collector col;
  col(body);
  return col.result;
}
// ============================================================================
// Phase 2: 替换 BufferStore -> dcu_async_copy
// ============================================================================
/// Phase 2: replace every recorded BufferStore with an evaluated
/// `tl.dcu_async_copy` call.
///
/// The call receives, in order: the destination buffer's data pointer, the
/// combined destination indices, the resource variable of the source buffer,
/// the combined source indices, the constant 1 and the constant `true`.
/// Stores that were not recorded are left to the default mutator.
class StoreReplacer : public StmtExprMutator {
public:
  static Stmt Run(Stmt body, const std::vector<CopyInfo> &copies,
                  const std::unordered_map<String, Var> &global_to_var) {
    StoreReplacer replacer(copies, global_to_var);
    return replacer(std::move(body));
  }

private:
  StoreReplacer(const std::vector<CopyInfo> &copies,
                const std::unordered_map<String, Var> &global_to_var)
      : copies_(copies), global_to_var_(global_to_var) {}

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    const Stmt current = GetRef<Stmt>(op);
    for (const CopyInfo &info : copies_) {
      // Recorded stores are recognised by node identity, not by structure.
      if (!info.store_stmt.same_as(current)) {
        continue;
      }
      // Source side: the resource variable created for the global buffer.
      Var resource = global_to_var_.at(info.src_buffer->name);
      // Destination side: the raw data pointer of the shared buffer.
      PrimExpr dst_ptr = info.dst_buffer->data;
      PrimExpr one = IntImm(DataType::Int(32), 1);
      PrimExpr always = Bool(true);

      Array<PrimExpr> call_args = {dst_ptr,
                                   Flatten(info.dst_indices),
                                   resource,
                                   Flatten(info.src_indices),
                                   one,
                                   always};
      return Evaluate(
          Call(DataType::Int(32), Op::Get("tl.dcu_async_copy"), call_args));
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  // Combine an index list into a single expression by adding the entries;
  // an empty list yields the int32 constant 0.
  // NOTE(review): no strides are applied, so this presumably expects indices
  // that are already linear offsets — confirm.
  PrimExpr Flatten(const Array<PrimExpr> &idx) {
    if (idx.empty()) {
      return IntImm(DataType::Int(32), 0);
    }
    PrimExpr total = idx[0];
    for (size_t pos = 1; pos < idx.size(); ++pos) {
      total = total + idx[pos];
    }
    return total;
  }

  const std::vector<CopyInfo> &copies_;
  const std::unordered_map<String, Var> &global_to_var_;
};
// ============================================================================
// Phase 3: 根据 Shared Alloc 位置进行精准注入
// ============================================================================
/// Phase 3: wrap one specific statement in a LetStmt per resource binding.
///
/// The target is identified by node address, so `Run` has to be given a tree
/// that still contains that exact node. When there is no target or no
/// binding, the body is returned untouched.
class ResourceInjector : public tvm::tir::StmtExprMutator {
public:
  static Stmt
  Run(Stmt body,
      const std::unordered_map<String, std::pair<Var, PrimExpr>> &bindings,
      const tvm::tir::StmtNode *target) {
    if (target == nullptr || bindings.empty()) {
      return body;
    }
    ResourceInjector mutator(bindings, target);
    return mutator(std::move(body));
  }

private:
  ResourceInjector(
      const std::unordered_map<String, std::pair<Var, PrimExpr>> &bindings,
      const tvm::tir::StmtNode *target)
      : bindings_(bindings), target_(target) {}

  Stmt VisitStmt(const Stmt &stmt) override {
    // Children are rewritten first in every case.
    Stmt rewritten = StmtExprMutator::VisitStmt(stmt);
    if (stmt.get() != target_) {
      return rewritten;
    }
    // This is the marked node: nest it inside one LetStmt per binding.
    for (const auto &entry : bindings_) {
      const Var &res_var = entry.second.first;
      const PrimExpr &init_expr = entry.second.second;
      rewritten = tvm::tir::LetStmt(res_var, init_expr, rewritten);
    }
    return rewritten;
  }

  std::unordered_map<String, std::pair<Var, PrimExpr>> bindings_;
  const tvm::tir::StmtNode *target_;
};
// ============================================================================
// Pass 入口
// ============================================================================
/// Pass body: rewrite shared <- global stores of `f` into
/// `tl.dcu_async_copy` calls and bind the resources they use.
///
/// Returns `f` unchanged when no such store is found.
PrimFunc LowerSharedGlobalCopy(PrimFunc f) {
  auto *fn_node = f.CopyOnWrite();

  // Step 1: gather the copies and decide where the bindings go.
  CollectResult collected = CollectResources(fn_node->body);
  if (collected.copies.empty()) {
    return f;
  }

  // Step 2: inject the LetStmts first. The injector matches the target by
  // node address, so it must run on the tree the collector walked.
  Stmt with_bindings = ResourceInjector::Run(fn_node->body,
                                             collected.shared_alloc_to_binding,
                                             collected.inject_target);

  // Step 3: swap the recorded stores for async-copy calls. The stores inside
  // the wrapped tree are still the original nodes, so identity matching works.
  Stmt with_copies = StoreReplacer::Run(with_bindings, collected.copies,
                                        collected.global_to_res_var);

  // Step 4: write the new body back.
  fn_node->body = std::move(with_copies);
  return GetRef<PrimFunc>(fn_node);
}
namespace transform {
using namespace tir::transform;

/// Build the PrimFunc pass named "tl.LowerSharedGlobalCopy" (opt level 0, no
/// required passes); the module and pass context arguments are not used.
tvm::transform::Pass LowerSharedGlobalCopy() {
  auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
    return tl::LowerSharedGlobalCopy(std::move(f));
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedGlobalCopy", {});
}

// Expose the pass factory through the FFI registry.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.LowerSharedGlobalCopy",
                        LowerSharedGlobalCopy);
}

} // namespace transform
}
// namespace tl
}
// namespace tvm
\ No newline at end of file
src/transform/vectorize_dcu_async_copy.cc
0 → 100644
View file @
32d0b3cb
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/op.h>
#include <tvm/tir/builtin.h>
#include <tvm/arith/analyzer.h>
#include <vector>
using
namespace
tvm
::
tir
;
using
tvm
::
ffi
::
GetRef
;
using
tvm
::
ffi
::
make_object
;
namespace
tvm
{
namespace
tl
{
using
namespace
tir
;
using
ffi
::
Array
;
using
ffi
::
String
;
/// Folds an unrolled loop whose body is a single `tl.dcu_async_copy` call
/// into one call that carries the loop's trip count and strides.
///
/// The folded call's arguments are, in order:
///   [dst, dst_base_offset, src, src_base_offset,
///    i_extent, i_stride_dst, i_stride_src, k_stride_src]
/// where `i` is the unrolled loop variable and `k` is the innermost enclosing
/// loop whose variable is literally named "k".
///
/// NOTE(review): the folded call has eight arguments; confirm this agrees
/// with how the op is registered and consumed elsewhere.
class AsyncCopySimplifier : public StmtExprMutator {
public:
  static Stmt Run(Stmt stmt) {
    AsyncCopySimplifier mutator;
    return mutator(std::move(stmt));
  }

private:
  arith::Analyzer analyzer_;
  // Variable of the innermost enclosing loop named "k".
  Var k_var_;

  // Split `expr` into (value at var == 0, value at var == 1 minus value at
  // var == 0). The second component is the per-iteration stride only when
  // `expr` is affine in `var`.
  // NOTE(review): `Var()` may construct a fresh variable rather than a null
  // reference, in which case the guard below never fires and an unrelated
  // variable simply yields a zero stride — confirm.
  std::pair<PrimExpr, PrimExpr> ExtractStride(PrimExpr expr, Var var) {
    if (!var.defined())
      return {expr, make_zero(expr.dtype())};
    PrimExpr base =
        tvm::tir::Substitute(expr, {{var, make_zero(var.dtype())}});
    PrimExpr plus_one =
        tvm::tir::Substitute(expr, {{var, make_const(var.dtype(), 1)}});
    PrimExpr stride = analyzer_.Simplify(plus_one - base);
    return {analyzer_.Simplify(base), stride};
  }

  Stmt VisitStmt_(const ForNode *op) final {
    // Track the "k" loop for the duration of its body only. The previous
    // value is restored on every exit path, so nested or sibling loops never
    // observe a stale variable.
    const bool is_k = (op->loop_var->name_hint == "k");
    const Var enclosing_k = k_var_;
    if (is_k) {
      k_var_ = op->loop_var;
    }

    Stmt body = this->VisitStmt(op->body);
    Stmt result = FoldOrRebuild(op, std::move(body));

    if (is_k) {
      k_var_ = enclosing_k;
    }
    return result;
  }

  // Either fold the loop into a single async-copy call or rebuild the loop
  // around its (possibly rewritten) body.
  Stmt FoldOrRebuild(const ForNode *op, Stmt body) {
    if (op->kind == ForKind::kUnrolled) {
      if (const EvaluateNode *eval = body.as<EvaluateNode>()) {
        if (const CallNode *call = eval->value.as<CallNode>()) {
          static const Op &dcu_copy_op = Op::Get("tl.dcu_async_copy");
          // args[0..3] are read below; leave shorter calls untouched.
          if (call->op.same_as(dcu_copy_op) && call->args.size() >= 4) {
            return FoldCopy(op, call);
          }
        }
      }
    }
    if (body.same_as(op->body))
      return GetRef<Stmt>(op);
    auto n = CopyOnWrite(op);
    n->body = std::move(body);
    return Stmt(n);
  }

  // Build the folded call for an unrolled loop `op` around `call`.
  Stmt FoldCopy(const ForNode *op, const CallNode *call) {
    const Var &i_var = op->loop_var;

    // For a vectorized (Ramp) offset only the scalar base is analysed.
    auto split_on_i = [&](const PrimExpr &offset) {
      const RampNode *ramp = offset.as<RampNode>();
      return ExtractStride(ramp != nullptr ? ramp->base : offset, i_var);
    };

    auto [base_dst, i_stride_dst] = split_on_i(call->args[1]);
    auto [base_src, i_stride_src] = split_on_i(call->args[3]);
    // Peel the k contribution off the source base.
    auto [final_src_offset, k_stride_src] = ExtractStride(base_src, k_var_);

    Array<PrimExpr> new_args = {
        call->args[0],    // destination buffer
        base_dst,         // destination base offset
        call->args[2],    // source resource
        final_src_offset, // source base offset
        op->extent,       // trip count of the unrolled loop
        i_stride_dst,     // destination stride per i iteration
        i_stride_src,     // source stride per i iteration
        k_stride_src      // source stride per k iteration
    };
    return Evaluate(Call(call->dtype, call->op, new_args));
  }
};
// ============================================================================
// Pass 入口
// ============================================================================
/// Pass body: run AsyncCopySimplifier over the body of `f` and return the
/// updated function.
PrimFunc SimplifyDCUAsyncCopy(PrimFunc f) {
  auto *fn_node = f.CopyOnWrite();
  Stmt previous_body = std::move(fn_node->body);
  fn_node->body = AsyncCopySimplifier::Run(std::move(previous_body));
  return GetRef<PrimFunc>(fn_node);
}
namespace transform {
using namespace tir::transform;

/// Build the PrimFunc pass named "tl.SimplifyDCUAsyncCopy" (opt level 0, no
/// required passes); the module and pass context arguments are not used.
tvm::transform::Pass SimplifyDCUAsyncCopy() {
  auto pass_func = [=](PrimFunc f, const IRModule &m,
                       tvm::transform::PassContext ctx) {
    return tl::SimplifyDCUAsyncCopy(std::move(f));
  };
  return tvm::tir::transform::CreatePrimFuncPass(
      pass_func, 0, "tl.SimplifyDCUAsyncCopy", {});
}

// Expose the pass factory through the FFI registry.
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("tl.transform.SimplifyDCUAsyncCopy",
                                        SimplifyDCUAsyncCopy);
}

} // namespace transform
}
// namespace tl
}
// namespace tvm
\ No newline at end of file
tilelang/contrib/rocm.py
View file @
32d0b3cb
...
...
@@ -26,6 +26,7 @@ import tvm_ffi
from
tvm.base
import
py_str
import
tvm.runtime
import
tvm.target
from
tvm.target
import
Target
from
tvm.contrib
import
utils
...
...
@@ -286,3 +287,12 @@ def find_rocm_path():
if
os
.
path
.
exists
(
os
.
path
.
join
(
rocm_path
,
"bin/hipcc"
)):
return
rocm_path
raise
RuntimeError
(
"Cannot find ROCm path"
)
def is_dcu(target: Target) -> bool:
    """Return True when ``target`` is a HIP/ROCm target whose ``mcpu`` starts with ``gfx936``.

    Targets of any other kind, and HIP/ROCm targets without an ``mcpu``
    attribute, yield False.
    """
    if target.kind.name not in ("hip", "rocm"):
        return False
    if "mcpu" not in target.attrs:
        return False
    return str(target.attrs["mcpu"]).startswith("gfx936")
\ No newline at end of file
tilelang/engine/lower.py
View file @
32d0b3cb
...
...
@@ -252,6 +252,7 @@ def lower(
func
=
func_or_mod
params
=
extrac_params
(
func
)
if
not
runtime_only
else
None
mod
=
tvm
.
IRModule
({
func
.
attrs
[
"global_symbol"
]:
func
})
print
(
mod
)
if
isinstance
(
target
,
str
):
target
=
determine_target
(
target
)
...
...
@@ -266,10 +267,11 @@ def lower(
# Before lowering, do semantic check
PreLowerSemanticCheck
(
mod
)
print
(
"1111111"
)
print
(
mod
)
# Phase 1: Lower and legalize the IR
mod
=
LowerAndLegalize
(
mod
,
target
)
# print(mod)
# Phase 2: Optimize the IR for the target
mod
=
OptimizeForTarget
(
mod
,
target
)
...
...
tilelang/engine/phase.py
View file @
32d0b3cb
...
...
@@ -4,6 +4,7 @@ from tvm.target import Target
import
tilelang
from
tilelang.transform
import
PassContext
from
tilelang.contrib.nvcc
import
have_tma
,
is_hopper
from
tilelang.contrib.rocm
import
is_dcu
def
allow_warp_specialized
(
pass_ctx
:
PassContext
|
None
=
None
,
target
:
Target
|
None
=
None
)
->
bool
:
...
...
@@ -69,6 +70,10 @@ def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool:
enabled
=
pass_ctx
.
config
.
get
(
tilelang
.
PassConfigKey
.
TL_LAYOUT_VISUALIZATION_ENABLE
,
False
)
return
enabled
def dcu_async_copy_supported(target: Target | None = None) -> bool:
    """Return whether the DCU async-copy passes should run for ``target``.

    ``target`` defaults to None; in that case the answer is False, because
    ``is_dcu`` reads attributes of the target and cannot be given None.
    """
    if target is None:
        return False
    return is_dcu(target)
def
get_layout_visual_formats
(
pass_ctx
:
PassContext
|
None
=
None
)
->
list
[
str
]:
if
pass_ctx
is
None
:
...
...
@@ -271,6 +276,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod
=
tilelang
.
transform
.
InjectPTXAsyncCopy
()(
mod
)
# Inject ds_read for shared to register memory copy on DCU
mod
=
tilelang
.
transform
.
InjectDSRead
()(
mod
)
print
(
"222222222"
)
print
(
mod
)
if
allow_tma_and_warp_specialized
(
pass_ctx
=
pass_ctx
,
target
=
target
):
...
...
@@ -281,5 +287,13 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# Transform threadblock to persistent threadblock
mod
=
tilelang
.
transform
.
PersistThreadblock
()(
mod
)
print
(
"OptimizeForTarget"
)
print
(
mod
)
if
dcu_async_copy_supported
(
target
):
mod
=
tilelang
.
transform
.
LowerSharedGlobalCopy
()(
mod
)
print
(
"OptimizeForTarget2"
)
print
(
mod
)
mod
=
tilelang
.
transform
.
SimplifyDCUAsyncCopy
()(
mod
)
print
(
"OptimizeForTarget3"
)
print
(
mod
)
return
mod
tilelang/language/ast/ir.py
View file @
32d0b3cb
...
...
@@ -1900,6 +1900,7 @@ tvm_mmac = _dtype_forward(_tir_op.tvm_mmac)
tvm_mfma_store
=
_dtype_forward
(
_tir_op
.
tvm_mfma_store
)
tvm_rdna_wmma
=
_dtype_forward
(
_tir_op
.
tvm_rdna_wmma
)
tvm_rdna_wmma_store
=
_dtype_forward
(
_tir_op
.
tvm_rdna_wmma_store
)
make_dcu_resource
=
_dtype_forward
(
_tir_op
.
make_dcu_resource
)
broadcast
=
Broadcast
ramp
=
Ramp
...
...
@@ -2222,4 +2223,5 @@ __all__ = [
"CommReducer"
,
"Range"
,
"vscale"
,
"make_dcu_resource"
,
]
tilelang/tileop/gemm/gemm_mmac.py
View file @
32d0b3cb
from
.gemm_base
import
GemmBase
from
tilelang.layout
import
make_swizzled_layout
from
tilelang.layout
import
make_swizzled_layout
,
make_linear_layout
from
tilelang.intrinsics.mmac_macro_generator
import
(
MatrixCoreIntrinEmitter
,
)
...
...
tilelang/transform/__init__.py
View file @
32d0b3cb
...
...
@@ -552,3 +552,11 @@ def LayoutReducer():
The transform pass object produced by the FFI backend.
"""
return
_ffi_api
.
LayoutReducer
()
# type: ignore
def LowerSharedGlobalCopy():
    """Create the LowerSharedGlobalCopy pass (DCU resource rewriter).

    Returns
    -------
    The transform pass object produced by the FFI backend.
    """
    return _ffi_api.LowerSharedGlobalCopy()  # type: ignore
def SimplifyDCUAsyncCopy():
    """Create the SimplifyDCUAsyncCopy pass.

    Returns
    -------
    The transform pass object produced by the FFI backend.
    """
    return _ffi_api.SimplifyDCUAsyncCopy()  # type: ignore
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment