Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
b14f201e
Commit
b14f201e
authored
May 07, 2026
by
qisan
Browse files
Feats: Add async, pipeline and ds_read
parent
44cc93c7
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
201 additions
and
210 deletions
+201
-210
src/target/codegen_hip.cc
src/target/codegen_hip.cc
+2
-4
src/tl_templates/dcu_hip/copy.h
src/tl_templates/dcu_hip/copy.h
+2
-3
src/transform/inject_blocal_layout_transform.cc
src/transform/inject_blocal_layout_transform.cc
+38
-95
src/transform/inject_mmac_fence.cc
src/transform/inject_mmac_fence.cc
+145
-25
src/transform/lower_dcu_resource.cc
src/transform/lower_dcu_resource.cc
+4
-47
src/transform/vectorize_dcu_async_copy.cc
src/transform/vectorize_dcu_async_copy.cc
+9
-22
tilelang/engine/phase.py
tilelang/engine/phase.py
+1
-14
No files found.
src/target/codegen_hip.cc
View file @
b14f201e
...
...
@@ -827,14 +827,12 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
func_name
+=
"_trans"
;
print_extern_call_stmt
(
func_name
,
2
);
}
else
if
(
op
->
op
.
same_as
(
tl
::
ds_read_vector
())){
//ds_read_b64 %1, %2 offset:%3
// ds_read_m32x16_b16 %0, %1 offset:%2
std
::
string
dst
=
this
->
PrintExpr
(
op
->
args
[
0
]);
std
::
string
local_offset
=
this
->
PrintExpr
(
op
->
args
[
1
]);
std
::
string
lds_offset
=
this
->
PrintExpr
(
op
->
args
[
2
]);
os
<<
"tl::ds_read_vector("
os
<<
"tl::ds_read_vector(
*(float4_ *)(
"
<<
dst
<<
" + "
<<
local_offset
<<
", "
<<
"
)
, "
<<
lds_offset
<<
")"
;
}
else
if
(
op
->
op
.
same_as
(
tl
::
wait_wgmma
()))
{
...
...
src/tl_templates/dcu_hip/copy.h
View file @
b14f201e
...
...
@@ -133,12 +133,11 @@ TL_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr,
// }
// }
TL_DEVICE
void
ds_read_vector
(
void
*
dst
,
uint32_t
lds_base_ptr
)
TL_DEVICE
void
ds_read_vector
(
float4_
&
dst
,
uint32_t
lds_base_ptr
)
{
asm
volatile
(
"ds_read_m32x16_b16 %0, %1 offset:0
\n\t
"
:
"+v"
(
dst
)
:
"v"
(
lds_base_ptr
),
:
"memory"
);
:
"v"
(
lds_base_ptr
));
}
// template <int M, int N, int offset>
...
...
src/transform/inject_blocal_layout_transform.cc
View file @
b14f201e
...
...
@@ -57,13 +57,15 @@ class BLocalLayoutTransformer : public StmtExprMutator {
int
expand_
;
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
// 只处理 serial 外层循环
// 1. 先递归处理子节点(重要:确保处理了嵌套的 For 或 Attr)
Stmt
new_body
=
this
->
VisitStmt
(
op
->
body
);
// 2. 检查当前循环是否是目标循环
// 即使 body 变了,我们也尝试看看能不能在这个 loop 层级做变换
auto
store
=
new_body
.
as
<
BufferStoreNode
>
();
if
(
op
->
kind
!=
ForKind
::
kSerial
)
{
return
StmtExprMutator
::
VisitStmt_
(
op
);
}
// 判断是否是 B_local 写循环
auto
store
=
op
->
body
.
as
<
BufferStoreNode
>
();
if
(
!
store
)
{
return
StmtExprMutator
::
VisitStmt_
(
op
);
}
...
...
@@ -79,7 +81,6 @@ class BLocalLayoutTransformer : public StmtExprMutator {
int64_t
new_extent
=
old_extent
/
expand_
;
// 修改循环范围
For
new_for
=
For
(
op
->
loop_var
,
op
->
min
,
...
...
@@ -94,100 +95,42 @@ class BLocalLayoutTransformer : public StmtExprMutator {
std
::
string
name
=
buffer
->
name
;
return
name
.
find
(
"B_local"
)
!=
std
::
string
::
npos
;
}
Stmt
MutateStore
(
const
BufferStoreNode
*
store
,
const
Var
&
loop_var
)
{
Array
<
PrimExpr
>
new_indices
=
store
->
indices
;
PrimExpr
new_value
=
store
->
value
;
// 修改切片跨度:
// 原来 j*vec : j*vec+vec
// 改为 j*vec : j*vec*expand + vec
PrimExpr
idx
=
store
->
indices
[
0
];
//T.Ramp(j * 4, 1, 4) -> Ramp(j*8, 1, 4)
std
::
cout
<<
idx
<<
std
::
endl
;
// 解析 j*vec 结构
// 假设结构为 j * vec + const
// 不改 RHS
// PrimExpr value = store->value;
// 修改写入向量宽度
// 原 value 是 Ramp(base=j*4, stride=1, lanes=4)
// 匹配 j * stride
// Ramp(base=j*8, stride=1, lanes=8)
if
(
const
auto
*
ramp
=
idx
.
as
<
RampNode
>
())
{
PrimExpr
base
=
ramp
->
base
;
PrimExpr
stride
=
ramp
->
stride
;
int
old_lanes
=
ramp
->
lanes
.
as
<
IntImmNode
>
()
->
value
;
int
new_lanes
=
old_lanes
*
expand_
;
// 匹配 base = j * stride_val
if
(
const
auto
*
mul
=
base
.
as
<
MulNode
>
())
{
if
(
mul
->
a
.
same_as
(
loop_var
))
{
int64_t
old_stride
=
mul
->
b
.
as
<
IntImmNode
>
()
->
value
;
int64_t
new_stride
=
old_stride
*
expand_
;
PrimExpr
new_base
=
loop_var
*
make_const
(
DataType
::
Int
(
32
),
new_stride
);
new_indices
.
Set
(
0
,
Ramp
(
new_base
,
stride
,
new_lanes
));
}
else
if
(
mul
->
b
.
same_as
(
loop_var
))
{
int64_t
old_stride
=
mul
->
a
.
as
<
IntImmNode
>
()
->
value
;
int64_t
new_stride
=
old_stride
*
expand_
;
PrimExpr
new_base
=
make_const
(
DataType
::
Int
(
32
),
new_stride
)
*
loop_var
;
new_indices
.
Set
(
0
,
Ramp
(
new_base
,
stride
,
new_lanes
));
}
PrimExpr
UpdateIndexBase
(
PrimExpr
base
,
const
Var
&
loop_var
,
int
expand
)
{
if
(
const
auto
*
add
=
base
.
as
<
AddNode
>
())
{
return
UpdateIndexBase
(
add
->
a
,
loop_var
,
expand
)
+
UpdateIndexBase
(
add
->
b
,
loop_var
,
expand
);
}
else
if
(
const
auto
*
mul
=
base
.
as
<
MulNode
>
())
{
if
(
mul
->
a
.
same_as
(
loop_var
))
{
return
mul
->
a
*
(
mul
->
b
*
expand
);
}
else
if
(
mul
->
b
.
same_as
(
loop_var
))
{
return
(
mul
->
a
*
expand
)
*
mul
->
b
;
}
}
return
base
;
}
Stmt
MutateStore
(
const
BufferStoreNode
*
store
,
const
Var
&
loop_var
)
{
auto
n
=
tvm
::
ffi
::
make_object
<
BufferStoreNode
>
(
*
store
);
Array
<
PrimExpr
>
new_indices
=
store
->
indices
;
if
(
const
auto
*
ramp
=
store
->
indices
[
0
].
as
<
RampNode
>
())
{
PrimExpr
new_base
=
UpdateIndexBase
(
ramp
->
base
,
loop_var
,
expand_
);
int
new_lanes
=
ramp
->
lanes
.
as
<
IntImmNode
>
()
->
value
*
expand_
;
new_indices
.
Set
(
0
,
Ramp
(
new_base
,
ramp
->
stride
,
new_lanes
));
}
}
if
(
auto
*
load
=
new_value
.
as
<
BufferLoadNode
>
())
{
// BufferLoad with region access: B_shared[start : end]
// end - start = lanes,需要同步扩展
Array
<
PrimExpr
>
value_indices
=
load
->
indices
;
if
(
auto
*
old_ramp
=
load
->
indices
[
0
].
as
<
RampNode
>
())
{
PrimExpr
scalar_base
=
old_ramp
->
base
;
// 必须是 scalar
PrimExpr
stride
=
old_ramp
->
stride
;
//RHS 4 lane
int
old_lanes
=
old_ramp
->
lanes
.
as
<
IntImmNode
>
()
->
value
;
//RHS 8 lane
int
new_lanes
=
old_lanes
*
expand_
;
value_indices
.
Set
(
0
,
Ramp
(
scalar_base
,
stride
,
new_lanes
)
);
new_value
=
BufferLoad
(
load
->
buffer
,
value_indices
);
PrimExpr
new_value
=
store
->
value
;
if
(
const
auto
*
load
=
store
->
value
.
as
<
BufferLoadNode
>
())
{
if
(
const
auto
*
l_ramp
=
load
->
indices
[
0
].
as
<
RampNode
>
())
{
Array
<
PrimExpr
>
v_indices
=
load
->
indices
;
int
v_new_lanes
=
l_ramp
->
lanes
.
as
<
IntImmNode
>
()
->
value
*
expand_
;
v_indices
.
Set
(
0
,
Ramp
(
l_ramp
->
base
,
l_ramp
->
stride
,
v_new_lanes
));
new_value
=
BufferLoad
(
load
->
buffer
,
v_indices
);
}
}
}
return
BufferStore
(
store
->
buffer
,
new_value
,
new_indices
);
return
BufferStore
(
store
->
buffer
,
new_value
,
new_indices
);
}
};
...
...
src/transform/inject_mmac_fence.cc
View file @
b14f201e
...
...
@@ -13,7 +13,6 @@ namespace tl {
using
ffi
::
Array
;
using
namespace
tir
;
// 1. 辅助类:统计 Shared -> Register 的加载量
class
LoadCounter
:
public
StmtExprVisitor
{
public:
int
total_loads
=
0
;
...
...
@@ -39,11 +38,147 @@ public:
}
ExprVisitor
::
VisitExpr_
(
op
);
}
void
VisitExpr_
(
const
CallNode
*
op
)
override
{
std
::
string
func_name
=
""
;
if
(
auto
opt_op
=
op
->
op
.
as
<
OpNode
>
())
{
func_name
=
opt_op
->
name
;
}
else
if
(
auto
global_var
=
op
->
op
.
as
<
GlobalVarNode
>
())
{
func_name
=
global_var
->
name_hint
;
}
if
(
func_name
.
find
(
"ds_read"
)
!=
std
::
string
::
npos
)
{
total_loads
+=
current_multiplier
;
}
ExprVisitor
::
VisitExpr_
(
op
);
}
private:
bool
IsSharedMem
(
const
Buffer
&
buf
)
{
std
::
string
scope
=
buf
.
scope
();
std
::
string
name
=
buf
->
name
;
return
(
scope
==
"shared"
||
name
.
find
(
"shared"
)
!=
std
::
string
::
npos
||
name
.
find
(
"shmem"
)
!=
std
::
string
::
npos
||
name
.
find
(
"LDS"
)
!=
std
::
string
::
npos
);
}
};
// 2. 核心 Mutator
namespace
{
bool
StmtContainsMMA
(
const
Stmt
&
stmt
)
{
bool
found
=
false
;
PostOrderVisit
(
stmt
,
[
&
found
](
const
ObjectRef
&
node
)
{
if
(
const
CallNode
*
call
=
node
.
as
<
CallNode
>
())
{
std
::
string
op_name
=
""
;
if
(
const
OpNode
*
op
=
call
->
op
.
as
<
OpNode
>
())
{
op_name
=
op
->
name
;
}
else
if
(
const
GlobalVarNode
*
gv
=
call
->
op
.
as
<
GlobalVarNode
>
())
{
op_name
=
gv
->
name_hint
;
}
if
(
op_name
.
find
(
"mmac"
)
!=
std
::
string
::
npos
||
op_name
.
find
(
"mma"
)
!=
std
::
string
::
npos
)
{
found
=
true
;
}
}
});
return
found
;
}
void
ScanStmtDefault
(
const
Stmt
&
s
,
std
::
vector
<
Stmt
>*
fence_targets
);
void
ScanSeqStmt
(
const
SeqStmtNode
*
op
,
std
::
vector
<
Stmt
>*
fence_targets
)
{
int
pending
=
0
;
for
(
size_t
i
=
0
;
i
<
op
->
seq
.
size
();
++
i
)
{
const
Stmt
&
stmt
=
op
->
seq
[
i
];
if
(
StmtContainsMMA
(
stmt
))
{
if
(
pending
>
0
)
{
fence_targets
->
push_back
(
stmt
);
pending
=
0
;
}
ScanStmtDefault
(
stmt
,
fence_targets
);
}
else
{
LoadCounter
counter
;
counter
(
stmt
);
pending
+=
counter
.
total_loads
;
ScanStmtDefault
(
stmt
,
fence_targets
);
}
}
}
void
ScanStmtDefault
(
const
Stmt
&
s
,
std
::
vector
<
Stmt
>*
fence_targets
)
{
if
(
const
auto
*
seq
=
s
.
as
<
SeqStmtNode
>
())
{
ScanSeqStmt
(
seq
,
fence_targets
);
return
;
}
if
(
const
auto
*
op
=
s
.
as
<
AttrStmtNode
>
())
{
ScanStmtDefault
(
op
->
body
,
fence_targets
);
return
;
}
if
(
const
auto
*
op
=
s
.
as
<
LetStmtNode
>
())
{
ScanStmtDefault
(
op
->
body
,
fence_targets
);
return
;
}
if
(
const
auto
*
op
=
s
.
as
<
IfThenElseNode
>
())
{
ScanStmtDefault
(
op
->
then_case
,
fence_targets
);
if
(
op
->
else_case
)
{
ScanStmtDefault
(
op
->
else_case
.
value
(),
fence_targets
);
}
return
;
}
if
(
const
auto
*
op
=
s
.
as
<
ForNode
>
())
{
ScanStmtDefault
(
op
->
body
,
fence_targets
);
return
;
}
if
(
const
auto
*
op
=
s
.
as
<
WhileNode
>
())
{
ScanStmtDefault
(
op
->
body
,
fence_targets
);
return
;
}
if
(
const
auto
*
op
=
s
.
as
<
AllocateNode
>
())
{
ScanStmtDefault
(
op
->
body
,
fence_targets
);
return
;
}
if
(
const
auto
*
op
=
s
.
as
<
AllocateConstNode
>
())
{
ScanStmtDefault
(
op
->
body
,
fence_targets
);
return
;
}
if
(
const
auto
*
op
=
s
.
as
<
DeclBufferNode
>
())
{
ScanStmtDefault
(
op
->
body
,
fence_targets
);
return
;
}
if
(
const
auto
*
op
=
s
.
as
<
BufferRealizeNode
>
())
{
ScanStmtDefault
(
op
->
body
,
fence_targets
);
return
;
}
if
(
const
auto
*
op
=
s
.
as
<
AssertStmtNode
>
())
{
ScanStmtDefault
(
op
->
body
,
fence_targets
);
return
;
}
if
(
const
auto
*
op
=
s
.
as
<
BlockNode
>
())
{
if
(
op
->
init
.
defined
())
{
ScanStmtDefault
(
op
->
init
.
value
(),
fence_targets
);
}
ScanStmtDefault
(
op
->
body
,
fence_targets
);
return
;
}
if
(
const
auto
*
op
=
s
.
as
<
BlockRealizeNode
>
())
{
ScanStmtDefault
(
op
->
block
,
fence_targets
);
return
;
}
}
Stmt
ComputeGlobalLastFenceMMAStmt
(
const
Stmt
&
root
)
{
std
::
vector
<
Stmt
>
fence_targets
;
ScanStmtDefault
(
root
,
&
fence_targets
);
if
(
fence_targets
.
empty
())
{
return
Stmt
();
}
return
fence_targets
.
back
();
}
}
class
MMABarrierMutator
:
public
StmtExprMutator
{
public:
explicit
MMABarrierMutator
(
const
Stmt
&
root_body
)
:
global_last_fence_mma_
(
ComputeGlobalLastFenceMMAStmt
(
root_body
))
{}
bool
ContainsMMA
(
const
Stmt
&
stmt
)
{
bool
found
=
false
;
PostOrderVisit
(
stmt
,
[
&
found
](
const
ObjectRef
&
node
)
{
...
...
@@ -64,23 +199,6 @@ public:
}
Stmt
VisitStmt_
(
const
SeqStmtNode
*
op
)
override
{
// --- 步骤 1: 预扫描,确定最后一个需要插入 Fence 的位置 ---
int
last_fence_idx
=
-
1
;
int
temp_pending_count
=
0
;
for
(
size_t
i
=
0
;
i
<
op
->
seq
.
size
();
++
i
)
{
if
(
ContainsMMA
(
op
->
seq
[
i
]))
{
if
(
temp_pending_count
>
0
)
{
last_fence_idx
=
static_cast
<
int
>
(
i
);
temp_pending_count
=
0
;
// 模拟重置
}
}
else
{
LoadCounter
counter
;
counter
(
op
->
seq
[
i
]);
temp_pending_count
+=
counter
.
total_loads
;
}
}
// --- 步骤 2: 实际构造新的 Sequence ---
Array
<
Stmt
>
new_seq
;
int
pending_load_count
=
0
;
...
...
@@ -89,16 +207,16 @@ public:
if
(
ContainsMMA
(
stmt
))
{
if
(
pending_load_count
>
0
)
{
// 判断是否是该序列中最后一个 Fence
int
fence_val
=
(
static_cast
<
int
>
(
i
)
==
last_fence_idx
)
?
0
:
pending_load_count
;
int
fence_val
=
(
global_last_fence_mma_
.
defined
()
&&
stmt
.
same_as
(
global_last_fence_mma_
))
?
0
:
pending_load_count
;
Array
<
PrimExpr
>
args
=
{
Integer
(
fence_val
)};
// 构造 Fence
auto
fence_call
=
Call
(
DataType
::
Void
(),
Op
::
Get
(
"tl.async_gld_fence"
),
args
);
new_seq
.
push_back
(
Evaluate
(
fence_call
));
// 构造 Barrier
auto
barrier_call
=
Call
(
DataType
::
Void
(),
Op
::
Get
(
"tl.wave_barrier"
),
{});
new_seq
.
push_back
(
Evaluate
(
barrier_call
));
...
...
@@ -114,15 +232,17 @@ public:
}
return
SeqStmt
(
new_seq
);
}
private:
Stmt
global_last_fence_mma_
;
};
// 3. Pass 包装
namespace
transform
{
using
namespace
tir
::
transform
;
Pass
InsertAsyncMMAFence
()
{
auto
pass_func
=
[](
PrimFunc
f
,
IRModule
m
,
PassContext
ctx
)
{
auto
*
n
=
f
.
CopyOnWrite
();
MMABarrierMutator
mutator
;
MMABarrierMutator
mutator
(
n
->
body
)
;
n
->
body
=
mutator
(
n
->
body
);
return
f
;
};
...
...
src/transform/lower_dcu_resource.cc
View file @
b14f201e
...
...
@@ -20,9 +20,6 @@ using namespace tir;
using
ffi
::
Array
;
using
ffi
::
String
;
// ============================================================================
// 数据结构
// ============================================================================
struct
CopyInfo
{
Buffer
dst_buffer
;
Buffer
src_buffer
;
...
...
@@ -33,10 +30,7 @@ struct CopyInfo {
struct
CollectResult
{
std
::
vector
<
CopyInfo
>
copies
;
// 映射: Global Buffer Name -> DCU Resource Var (用于替换Store)
std
::
unordered_map
<
String
,
Var
>
global_to_res_var
;
// 映射: Shared Buffer Name -> 要注入的LetStmt绑定 (Var, PrimExpr)
// 这样我们就可以根据 shared buffer 的位置来决定注入点
std
::
unordered_map
<
String
,
std
::
pair
<
Var
,
PrimExpr
>>
shared_alloc_to_binding
;
const
StmtNode
*
inject_target
=
nullptr
;
...
...
@@ -64,7 +58,6 @@ class VariableKeeper : public tvm::tir::ExprMutator {
:
keep_vars_
(
keep_vars
)
{}
PrimExpr
VisitExpr_
(
const
tvm
::
tir
::
VarNode
*
op
)
override
{
// 关键调试:打印每一个遇到的变量及其地址
if
(
keep_vars_
.
count
(
op
))
{
return
GetRef
<
PrimExpr
>
(
op
);
}
else
{
...
...
@@ -72,10 +65,7 @@ class VariableKeeper : public tvm::tir::ExprMutator {
}
}
// 额外处理:防止 Load 节点中的变量丢失
PrimExpr
VisitExpr_
(
const
tvm
::
tir
::
BufferLoadNode
*
op
)
override
{
// 如果你的索引里嵌套了 BufferLoad,Load 本身不是 Var,
// 但它里面可能含有 Var。Mutator 默认会递归,但我们可以显式打印。
return
ExprMutator
::
VisitExpr_
(
op
);
}
...
...
@@ -83,9 +73,6 @@ class VariableKeeper : public tvm::tir::ExprMutator {
const
std
::
unordered_set
<
const
tvm
::
tir
::
VarNode
*>&
keep_vars_
;
};
// ============================================================================
// Phase 1: 收集拷贝信息 & 生成资源绑定
// ============================================================================
CollectResult
CollectResources
(
const
Stmt
&
body
)
{
class
Collector
:
public
StmtExprVisitor
{
public:
...
...
@@ -94,7 +81,7 @@ CollectResult CollectResources(const Stmt& body) {
private:
bool
in_async
{
false
};
std
::
unordered_set
<
const
tvm
::
tir
::
VarNode
*>
loop_vars_
;
std
::
vector
<
const
tvm
::
tir
::
StmtNode
*>
scope_stack_
;
// 追踪当前遍历的 AST 路径
std
::
vector
<
const
tvm
::
tir
::
StmtNode
*>
scope_stack_
;
bool
IsSharedScope
(
const
Buffer
&
buf
)
{
auto
s
=
buf
.
scope
();
return
s
==
"shared"
||
s
==
"shared.dyn"
;
...
...
@@ -107,7 +94,6 @@ CollectResult CollectResources(const Stmt& body) {
void
VisitStmt_
(
const
AttrStmtNode
*
attr
)
override
{
scope_stack_
.
push_back
(
attr
);
if
(
attr
->
attr_key
==
tir
::
attr
::
thread_extent
)
{
// 1. 获取 IterVar
auto
iv
=
attr
->
node
.
as
<
tvm
::
tir
::
IterVarNode
>
();
const
std
::
string
&
tag
=
iv
->
thread_tag
;
...
...
@@ -119,7 +105,6 @@ CollectResult CollectResources(const Stmt& body) {
loop_vars_
.
erase
(
thread_var
.
get
());
}
else
{
// 如果是 blockIdx 或其他,直接跳过当前层继续往下走
StmtExprVisitor
::
VisitStmt_
(
attr
);
}
...
...
@@ -173,7 +158,6 @@ CollectResult CollectResources(const Stmt& body) {
if
(
scope_stack_
[
i
]
->
IsInstance
<
AttrStmtNode
>
())
{
auto
attr
=
static_cast
<
const
AttrStmtNode
*>
(
scope_stack_
[
i
]);
if
(
attr
->
attr_key
==
tvm
::
tir
::
attr
::
thread_extent
)
{
// 找到了最内层的线程绑定。它里面的下一个节点(i+1)就是我们应该包裹的节点
if
(
i
+
1
<
scope_stack_
.
size
())
{
result
.
inject_target
=
scope_stack_
[
i
+
1
];
}
...
...
@@ -189,12 +173,10 @@ CollectResult CollectResources(const Stmt& body) {
}
}
}
// 如果还是空,直接 fallback 到当前操作
if
(
result
.
inject_target
==
nullptr
)
result
.
inject_target
=
op
;
}
// 1. 记录拷贝
VariableKeeper
keeper
(
loop_vars_
);
tvm
::
arith
::
Analyzer
analyzer
;
Array
<
PrimExpr
>
for_var_only_indices
;
...
...
@@ -206,7 +188,6 @@ CollectResult CollectResources(const Stmt& body) {
CopyInfo
info
{
dst
,
src
,
op
->
indices
,
for_var_only_indices
,
GetRef
<
Stmt
>
(
op
)};
result
.
copies
.
push_back
(
info
);
// 2. 只有当没处理过这个 Global Buffer 时才生成 Binding
if
(
result
.
global_to_res_var
.
find
(
src
->
name
)
==
result
.
global_to_res_var
.
end
())
{
Var
var
(
src
->
name
+
"_dcu_res"
,
DataType
::
Int
(
32
,
4
));
...
...
@@ -214,17 +195,13 @@ CollectResult CollectResources(const Stmt& body) {
tvm
::
arith
::
Analyzer
analyzer
;
Array
<
PrimExpr
>
base_indices
;
for
(
const
auto
&
idx
:
load
->
indices
)
{
// 将所有外层循环变量 (k, i 等) 全部替换为 0
PrimExpr
no_loops
=
eliminator
(
idx
);
// 化简出最终的基地址表达式
base_indices
.
push_back
(
analyzer
.
Simplify
(
no_loops
));
}
// ✅ 关键点:填充真实的地址信息 src->data (即 A.data)
Array
<
PrimExpr
>
args
;
args
.
push_back
(
src
->
data
);
// 先加 data
args
.
push_back
(
src
->
data
);
// 如果需要把 indices 的每个元素作为独立参数展开:
for
(
const
auto
&
idx
:
base_indices
)
{
args
.
push_back
(
idx
);
}
...
...
@@ -232,7 +209,6 @@ CollectResult CollectResources(const Stmt& body) {
Op
::
Get
(
"tl.make_dcu_resource"
),
args
);
result
.
global_to_res_var
[
src
->
name
]
=
var
;
// 将这个绑定关系和 destination 的 shared buffer 绑死
result
.
shared_alloc_to_binding
[
src
->
name
]
=
{
var
,
val
};
}
}
...
...
@@ -247,9 +223,6 @@ CollectResult CollectResources(const Stmt& body) {
return
col
.
result
;
}
// ============================================================================
// Phase 2: 替换 BufferStore -> dcu_async_copy
// ============================================================================
class
StoreReplacer
:
public
StmtExprMutator
{
public:
static
Stmt
Run
(
Stmt
body
,
const
std
::
vector
<
CopyInfo
>&
copies
,
...
...
@@ -268,16 +241,14 @@ private:
auto
body
=
this
->
VisitStmt
(
attr
->
body
);
return
body
;
}
return
StmtMutator
::
VisitStmt_
(
attr
);
// ③ 其他属性:默认保留
return
StmtMutator
::
VisitStmt_
(
attr
);
}
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
for
(
const
auto
&
copy
:
copies_
)
{
if
(
copy
.
store_stmt
.
same_as
(
GetRef
<
Stmt
>
(
op
)))
{
// Global 取 resource var (A_dcu_res)
Var
src_res
=
global_to_var_
.
at
(
copy
.
src_buffer
->
name
);
// Shared 取 data pointer (A_shared.data)
PrimExpr
dst_res
=
copy
.
dst_buffer
->
data
;
PrimExpr
copy_size
=
IntImm
(
DataType
::
Int
(
32
),
1
);
...
...
@@ -305,9 +276,6 @@ private:
const
std
::
unordered_map
<
String
,
Var
>&
global_to_var_
;
};
// ============================================================================
// Phase 3: 根据 Shared Alloc 位置进行精准注入
// ============================================================================
class
ResourceInjector
:
public
tvm
::
tir
::
StmtExprMutator
{
public:
static
Stmt
Run
(
Stmt
body
,
...
...
@@ -324,18 +292,15 @@ private:
:
bindings_
(
bindings
),
target_
(
target
)
{}
Stmt
VisitStmt
(
const
Stmt
&
stmt
)
override
{
// 当我们遍历到刚才标记的那个 AST 节点时
if
(
stmt
.
get
()
==
target_
)
{
// 先向下遍历(保持 TVM Mutator 的习惯)
Stmt
new_stmt
=
StmtExprMutator
::
VisitStmt
(
stmt
);
// 在这个节点的外面套上所有的 LetStmt
for
(
const
auto
&
item
:
bindings_
)
{
Var
res_var
=
item
.
second
.
first
;
PrimExpr
init_expr
=
item
.
second
.
second
;
new_stmt
=
tvm
::
tir
::
LetStmt
(
res_var
,
init_expr
,
new_stmt
);
}
return
new_stmt
;
// 返回包裹好的新节点
return
new_stmt
;
}
return
StmtExprMutator
::
VisitStmt
(
stmt
);
}
...
...
@@ -344,26 +309,18 @@ private:
const
tvm
::
tir
::
StmtNode
*
target_
;
};
// ============================================================================
// Pass 入口
// ============================================================================
PrimFunc
LowerSharedGlobalCopy
(
PrimFunc
f
)
{
auto
*
n
=
f
.
CopyOnWrite
();
// 收集信息
auto
res
=
CollectResources
(
n
->
body
);
if
(
res
.
copies
.
empty
()){
return
f
;
}
// 注入res声明
Stmt
injected
=
ResourceInjector
::
Run
(
n
->
body
,
res
.
shared_alloc_to_binding
,
res
.
inject_target
);
// 替换拷贝语句
Stmt
replaced
=
StoreReplacer
::
Run
(
injected
,
res
.
copies
,
res
.
global_to_res_var
);
// 写回
n
->
body
=
std
::
move
(
replaced
);
return
GetRef
<
PrimFunc
>
(
n
);
...
...
src/transform/vectorize_dcu_async_copy.cc
View file @
b14f201e
...
...
@@ -31,7 +31,6 @@ private:
PrimExpr
k_extent_
;
bool
in_unrolled_i_
=
false
;
// 通用的步长提取函数:从 expr 中提取指定 var 的步长,并返回剩余的 base
std
::
pair
<
PrimExpr
,
PrimExpr
>
ExtractStride
(
PrimExpr
expr
,
Var
var
)
{
if
(
!
var
.
defined
())
return
{
expr
,
make_zero
(
expr
->
dtype
)};
...
...
@@ -41,33 +40,26 @@ private:
return
{
analyzer_
.
Simplify
(
base
),
stride
};
}
// 核心重写逻辑
Stmt
RewriteAsyncCopy
(
const
CallNode
*
call
,
Var
i_var
,
PrimExpr
i_extent
)
{
// 1. 预处理:剥离 RampNode 获得基础偏移
PrimExpr
raw_dst_off
=
call
->
args
[
1
];
PrimExpr
raw_src_off
=
call
->
args
[
3
];
if
(
const
RampNode
*
r
=
raw_dst_off
.
as
<
RampNode
>
())
raw_dst_off
=
r
->
base
;
if
(
const
RampNode
*
r
=
raw_src_off
.
as
<
RampNode
>
())
raw_src_off
=
r
->
base
;
// 2. 提取 i 的步长
auto
[
base_i_dst
,
i_stride_dst
]
=
ExtractStride
(
raw_dst_off
,
i_var
);
auto
[
base_i_src
,
i_stride_src
]
=
ExtractStride
(
raw_src_off
,
i_var
);
// 3. 提取 k 的步长 (始终尝试从 base_i_src 中提取,无论是否存在 i 循环)
// 只要 k_var_ 在外层循环中定义了,这里就能提取出非 0 的步长
auto
[
final_src_offset
,
k_stride_src
]
=
ExtractStride
(
base_i_src
,
k_var_
);
// 构造最初要求的 8 个参数:
// [dst, dst_off, src, src_off, i_extent, i_stride_dst, i_stride_src, k_stride_src]
Array
<
PrimExpr
>
new_args
=
{
call
->
args
[
0
],
// dst_buf
base_i_dst
,
// 最终 dst 偏移
call
->
args
[
2
],
// src_buf
final_src_offset
,
// 最终 src 偏移
i_extent
,
// i 循环次数 (无循环时为 0)
i_stride_dst
,
// i 的 dst 步长 (无循环时为 0)
i_stride_src
,
// i 的 src 步长 (无循环时为 0)
k_stride_src
// k 的 src 步长 (即便无 i 循环,这里也能拿到 k 的步长)
call
->
args
[
0
],
base_i_dst
,
call
->
args
[
2
],
final_src_offset
,
i_extent
,
i_stride_dst
,
i_stride_src
,
k_stride_src
};
return
Evaluate
(
Call
(
call
->
dtype
,
call
->
op
,
new_args
));
...
...
@@ -78,7 +70,6 @@ private:
if
(
!
in_unrolled_i_
)
{
if
(
const
CallNode
*
call
=
op
->
value
.
as
<
CallNode
>
())
{
static
const
Op
&
dcu_copy_op
=
Op
::
Get
(
"tl.dcu_async_copy"
);
// 只要参数个数不是 8 (我们重写后的目标个数),就进行处理
if
(
call
->
op
.
same_as
(
dcu_copy_op
)
&&
call
->
args
.
size
()
!=
8
)
{
return
RewriteAsyncCopy
(
call
,
Var
(),
make_zero
(
DataType
::
Int
(
32
)));
}
...
...
@@ -108,7 +99,6 @@ private:
if
(
const
CallNode
*
call
=
eval
->
value
.
as
<
CallNode
>
())
{
static
const
Op
&
dcu_copy_op
=
Op
::
Get
(
"tl.dcu_async_copy"
);
if
(
call
->
op
.
same_as
(
dcu_copy_op
))
{
// 还原 k 并在返回前处理重写
Stmt
result
=
RewriteAsyncCopy
(
call
,
op
->
loop_var
,
op
->
extent
);
return
result
;
}
...
...
@@ -116,7 +106,6 @@ private:
}
}
// 退出循环时清理 k 信息
if
(
is_k
)
{
k_var_
=
Var
();
k_extent_
=
PrimExpr
();
...
...
@@ -128,9 +117,7 @@ private:
return
Stmt
(
n
);
}
};
// ============================================================================
// Pass 入口
// ============================================================================
PrimFunc
SimplifyDCUAsyncCopy
(
PrimFunc
f
)
{
auto
*
n
=
f
.
CopyOnWrite
();
n
->
body
=
AsyncCopySimplifier
::
Run
(
std
::
move
(
n
->
body
));
...
...
tilelang/engine/phase.py
View file @
b14f201e
...
...
@@ -222,25 +222,16 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod
=
tilelang
.
transform
.
PipelinePlanning
()(
mod
)
mod
=
tilelang
.
transform
.
RegisterPipelinePlanning
()(
mod
)
print
(
"OptimizeForTarget"
)
print
(
mod
)
mod
=
tilelang
.
transform
.
InjectRegisterSoftwarePipeline
()(
mod
)
print
(
"OptimizeForTarget2"
)
print
(
mod
)
mod
=
tilelang
.
transform
.
InjectSoftwarePipeline
()(
mod
)
print
(
"OptimizeForTarget2"
)
print
(
mod
)
mod
=
tilelang
.
transform
.
MergeIfStmt
()(
mod
)
if
allow_fence_proxy
(
target
=
target
):
# in hopper device, wgmma is an async proxy
# so we need to inject a fence proxy before it
mod
=
tilelang
.
transform
.
InjectFenceProxy
()(
mod
)
print
(
"OptimizeForTarget2.5"
)
print
(
mod
)
mod
=
tilelang
.
transform
.
LowerOpaqueBlock
()(
mod
)
mod
=
tilelang
.
transform
.
Simplify
()(
mod
)
mod
=
tir
.
transform
.
NarrowDataType
(
32
)(
mod
)
...
...
@@ -311,8 +302,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
if
dcu_async_copy_supported
(
target
):
print
(
"--------------support dcu async copy------------------"
)
mod
=
tilelang
.
transform
.
LowerSharedGlobalCopy
()(
mod
)
print
(
"222222222"
)
print
(
mod
)
mod
=
tilelang
.
transform
.
FixDCUWaitCount
()(
mod
)
mod
=
tilelang
.
transform
.
InjectBLocalLayoutTransform
()(
mod
)
print
(
"InjectBLocalLayoutTransform ............"
)
...
...
@@ -321,8 +310,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
print
(
"InjectDSRead ............"
)
print
(
mod
)
mod
=
tilelang
.
transform
.
InsertAsyncMMAFence
()(
mod
)
print
(
"333333333"
)
print
(
mod
)
# Register pipeline planning only writes software_pipeline annotations.
# We must inject after planning so prologue/body/epilogue are materialized.
# mod = tilelang.transform.RegisterPipelinePlanning()(mod)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment