Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
dd91b1e0
"git@developer.sourcefind.cn:OpenDAS/openfold.git" did not exist on "54ec5c4cbec31473935a899fa3c03e732d393866"
Commit
dd91b1e0
authored
Apr 27, 2026
by
qisan
Browse files
Feats: vectorize async copy
parent
41887aed
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
355 additions
and
119 deletions
+355
-119
examples/gemm/example_gemm.py
examples/gemm/example_gemm.py
+1
-1
src/op/builtin.cc
src/op/builtin.cc
+1
-1
src/target/codegen_hip.cc
src/target/codegen_hip.cc
+33
-24
src/transform/dcu_async_copy_pipeline.cc
src/transform/dcu_async_copy_pipeline.cc
+127
-0
src/transform/inject_pipeline.cc
src/transform/inject_pipeline.cc
+10
-0
src/transform/lower_dcu_resource.cc
src/transform/lower_dcu_resource.cc
+48
-21
src/transform/thread_storage_sync.cc
src/transform/thread_storage_sync.cc
+4
-3
src/transform/vectorize_dcu_async_copy.cc
src/transform/vectorize_dcu_async_copy.cc
+65
-41
tilelang/engine/phase.py
tilelang/engine/phase.py
+12
-11
tilelang/intrinsics/mmac_macro_generator.py
tilelang/intrinsics/mmac_macro_generator.py
+41
-14
tilelang/tileop/gemm/__init__.py
tilelang/tileop/gemm/__init__.py
+4
-0
tilelang/tileop/gemm/gemm_mmac.py
tilelang/tileop/gemm/gemm_mmac.py
+4
-2
tilelang/transform/__init__.py
tilelang/transform/__init__.py
+5
-1
No files found.
examples/gemm/example_gemm.py
View file @
dd91b1e0
...
@@ -10,7 +10,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.fl
...
@@ -10,7 +10,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.fl
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
B
:
T
.
Tensor
((
K
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
C
:
T
.
Tensor
((
M
,
N
),
dtype
),
):
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
256
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
512
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
A_shared
=
T
.
alloc_shared
((
block_M
,
block_K
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
B_shared
=
T
.
alloc_shared
((
block_K
,
block_N
),
dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
...
src/op/builtin.cc
View file @
dd91b1e0
...
@@ -389,7 +389,7 @@ TIR_DEFINE_TL_BUILTIN(__ldg).set_num_inputs(-1).set_attr<TCallEffectKind>(
...
@@ -389,7 +389,7 @@ TIR_DEFINE_TL_BUILTIN(__ldg).set_num_inputs(-1).set_attr<TCallEffectKind>(
//
//
TIR_DEFINE_TL_BUILTIN
(
dcu_async_copy
)
TIR_DEFINE_TL_BUILTIN
(
dcu_async_copy
)
.
set_num_inputs
(
6
)
.
set_num_inputs
(
4
)
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
Integer
(
CallEffectKind
::
kOpaque
));
...
...
src/target/codegen_hip.cc
View file @
dd91b1e0
...
@@ -793,8 +793,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
...
@@ -793,8 +793,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
<<
", "
<<
condition
<<
");
\n
"
;
<<
", "
<<
condition
<<
");
\n
"
;
}
}
}
else
if
(
op
->
op
.
same_as
(
builtin
::
ptx_commit_group
()))
{
}
else
if
(
op
->
op
.
same_as
(
builtin
::
ptx_commit_group
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: ptx_commit_group
\n
"
)
;
;
print_extern_call_stmt
(
"tl::cp_async_commit"
);
//
print_extern_call_stmt("tl::cp_async_commit");
}
else
if
(
op
->
op
.
same_as
(
builtin
::
ptx_wait_group
()))
{
}
else
if
(
op
->
op
.
same_as
(
builtin
::
ptx_wait_group
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: ptx_wait_group
\n
"
);
printf
(
"[DEBUG VisitExpr_] Branch: ptx_wait_group
\n
"
);
int
n
=
Downcast
<
IntImm
>
(
op
->
args
[
0
])
->
value
;
int
n
=
Downcast
<
IntImm
>
(
op
->
args
[
0
])
->
value
;
...
@@ -1103,42 +1103,51 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
...
@@ -1103,42 +1103,51 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
}
}
else
if
(
op
->
op
.
same_as
(
Op
::
Get
(
"tl.dcu_async_copy"
)))
{
else
if
(
op
->
op
.
same_as
(
Op
::
Get
(
"tl.dcu_async_copy"
)))
{
// 1. 提取模板参数 (IntImm 直接取值)
auto
get_base_expr
=
[
this
](
const
PrimExpr
&
e
)
->
std
::
string
{
if
(
const
auto
*
ramp
=
e
.
as
<
tvm
::
tir
::
RampNode
>
())
{
// 如果是 Ramp,只打印它的起始位置 (base)
return
this
->
PrintExpr
(
ramp
->
base
);
}
// 否则正常打印
return
this
->
PrintExpr
(
e
);
};
// 辅助函数:尝试获取整数常量
auto
get_int_const
=
[](
const
PrimExpr
&
e
)
->
int
{
auto
get_int_const
=
[](
const
PrimExpr
&
e
)
->
int
{
if
(
const
auto
*
val
=
e
.
as
<
IntImmNode
>
())
return
static_cast
<
int
>
(
val
->
value
);
if
(
const
auto
*
val
=
e
.
as
<
IntImmNode
>
())
return
static_cast
<
int
>
(
val
->
value
);
return
0
;
return
0
;
};
};
int
N
=
16
;
// 1. 静态模板参数 (按要求仅保留 N 和 smem_offset)
int
smem_offset
=
0
;
int
N
=
16
;
int
load_count
=
get_int_const
(
op
->
args
[
4
]);
int
i_sstride
=
get_int_const
(
op
->
args
[
5
]);
int
i_gstride
=
get_int_const
(
op
->
args
[
6
]);
int
k_gstride
=
get_int_const
(
op
->
args
[
7
]);
// 2. 将运行时参数打印到字符串中 (防止直接操作 stream 导致冲突)
// 2. 解析 IR 参数
std
::
string
dst_ptr
=
this
->
PrintExpr
(
op
->
args
[
0
]);
// args[0]: dst_ptr (buf_dyn_shmem)
std
::
string
dst_off
=
this
->
PrintExpr
(
op
->
args
[
1
]);
// args[1]: dst_ramp (T.Ramp...)
std
::
string
src_res
=
this
->
PrintExpr
(
op
->
args
[
2
]);
// args[2]: src_res (A_dcu_res)
std
::
string
src_off
=
this
->
PrintExpr
(
op
->
args
[
3
]);
// args[3]: src_ramp (T.Ramp...)
// args[4]: load_count (1)
std
::
string
dst_ptr
=
this
->
PrintExpr
(
op
->
args
[
0
]);
// 使用新定义的 get_base_expr 避开 lanes > 4 的检查
std
::
string
dst_off
=
get_base_expr
(
op
->
args
[
1
]);
std
::
string
src_res
=
this
->
PrintExpr
(
op
->
args
[
2
]);
std
::
string
src_off
=
get_base_expr
(
op
->
args
[
3
]);
// 3.
仿照范例进行流
输出
// 3.
生成
输出
流
this
->
PrintIndent
();
this
->
PrintIndent
();
// 模板参数仅保留 N, smem_offset 和动态提取的 load_count
this
->
stream
<<
"tl::cp_async_gs<"
this
->
stream
<<
"tl::cp_async_gs<"
<<
N
<<
", "
<<
N
<<
">("
;
<<
smem_offset
<<
", "
<<
load_count
<<
", "
<<
i_sstride
<<
", "
<<
i_gstride
<<
", "
<<
k_gstride
<<
">("
;
// 拼接第一个参数:(char*)dst + dst_off
// 打印函数参数
// 处理目标地址: ((char*)ptr + offset)
this
->
stream
<<
"((char*)"
<<
dst_ptr
<<
" + "
<<
dst_off
<<
"), "
;
this
->
stream
<<
"((char*)"
<<
dst_ptr
<<
" + "
<<
dst_off
<<
"), "
;
//
拼接第二个参数:src_res
//
打印源资源指针
this
->
stream
<<
src_res
<<
", "
;
this
->
stream
<<
src_res
<<
", "
;
//
拼接第三个参数:src_off
//
打印源偏移
this
->
stream
<<
src_off
<<
");
\n
"
;
this
->
stream
<<
src_off
<<
");
\n
"
;
}
}
else
{
else
{
...
...
src/transform/dcu_async_copy_pipeline.cc
0 → 100644
View file @
dd91b1e0
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/analysis.h>
using
namespace
tvm
::
tir
;
using
tvm
::
ffi
::
GetRef
;
using
tvm
::
ffi
::
make_object
;
namespace
tvm
{
namespace
tl
{
using
namespace
tir
;
using
ffi
::
Array
;
using
ffi
::
String
;
class
ROCmWaitCountRewriter
:
public
StmtMutator
{
public:
static
Stmt
Substitute
(
Stmt
stmt
)
{
return
ROCmWaitCountRewriter
()(
stmt
);
}
private:
// 辅助函数:统计一个代码块内 async 指令的总数
int
CountAsyncOps
(
const
Stmt
&
stmt
)
{
int
total_count
=
0
;
struct
Visitor
:
public
StmtExprVisitor
{
int
count
=
0
;
void
VisitStmt_
(
const
ForNode
*
op
)
override
{
// 如果内部还有循环(比如 T.unroll),需要乘上循环次数
int
current_count
=
count
;
count
=
0
;
StmtExprVisitor
::
VisitStmt_
(
op
);
int
loop_count
=
0
;
if
(
const
auto
*
extent
=
op
->
extent
.
as
<
IntImmNode
>
())
{
loop_count
=
static_cast
<
int
>
(
extent
->
value
);
}
else
{
// 如果是非固定长度循环,这在流水线中很少见,默认按1处理或报警
loop_count
=
1
;
}
int
body_count
=
count
;
count
=
current_count
+
(
body_count
*
loop_count
);
}
void
VisitExpr_
(
const
CallNode
*
op
)
override
{
// 识别 ptx_cp_async 或对应的异步访存 Op
if
(
op
->
op
.
same_as
(
builtin
::
ptx_cp_async
())
||
op
->
op
.
same_as
(
Op
::
Get
(
"tl.dcu_async_copy"
)))
{
LOG
(
INFO
)
<<
"Found async copy: "
<<
GetRef
<
Call
>
(
op
);
count
++
;
}
StmtExprVisitor
::
VisitExpr_
(
op
);
}
// 兼容某些实现中把 cp_async 放在 Evaluate 里的情况
void
VisitStmt_
(
const
EvaluateNode
*
op
)
override
{
StmtExprVisitor
::
VisitStmt_
(
op
);
}
}
visitor
;
visitor
(
stmt
);
return
visitor
.
count
;
}
Stmt
VisitStmt_
(
const
ForNode
*
op
)
override
{
// 1. 我们假设流水线的主循环是核心作用域
// 先扫描该循环体内部每一轮会发出多少个 async 操作
int
ops_per_iter
=
CountAsyncOps
(
op
->
body
);
// 如果没有异步操作,直接跳过
if
(
ops_per_iter
==
0
)
return
StmtMutator
::
VisitStmt_
(
op
);
// 2. 进入循环内部进行修改,记录当前的倍数
int
old_multiplier
=
multiplier_
;
multiplier_
=
ops_per_iter
;
Stmt
new_body
=
this
->
VisitStmt
(
op
->
body
);
multiplier_
=
old_multiplier
;
if
(
new_body
.
same_as
(
op
->
body
))
return
GetRef
<
Stmt
>
(
op
);
auto
n
=
CopyOnWrite
(
op
);
n
->
body
=
std
::
move
(
new_body
);
return
Stmt
(
n
);
}
Stmt
VisitStmt_
(
const
AttrStmtNode
*
op
)
override
{
if
(
op
->
attr_key
==
"async_wait_inflight_count"
&&
multiplier_
>
0
)
{
// 获取原有的 wait 组数 (比如 1)
if
(
auto
int_imm
=
op
->
value
.
as
<
IntImmNode
>
())
{
// 计算 ROCm 的指令数: N_groups * Ops_per_group
int64_t
new_cont
=
int_imm
->
value
*
multiplier_
;
LOG
(
INFO
)
<<
"Original wait count: "
<<
new_cont
<<
", async ops per iter: "
<<
multiplier_
;
// 返回修改后的节点
return
AttrStmt
(
op
->
node
,
op
->
attr_key
,
make_const
(
DataType
::
Int
(
32
),
new_cont
),
op
->
body
);
}
}
return
StmtMutator
::
VisitStmt_
(
op
);
}
int
multiplier_
=
0
;
// 当前作用域下的指令倍率
};
// 包装成标准的 TVM Pass
namespace
transform
{
using
namespace
tir
::
transform
;
tvm
::
transform
::
Pass
FixDCUWaitCount
()
{
auto
pass_func
=
[
=
](
PrimFunc
f
,
IRModule
m
,
PassContext
ctx
)
{
auto
*
n
=
f
.
CopyOnWrite
();
n
->
body
=
ROCmWaitCountRewriter
::
Substitute
(
std
::
move
(
n
->
body
));
return
f
;
};
return
CreatePrimFuncPass
(
pass_func
,
0
,
"FixDCUWaitCount"
,
{});
}
TVM_FFI_STATIC_INIT_BLOCK
()
{
tvm
::
ffi
::
reflection
::
GlobalDef
().
def
(
"tl.transform.FixDCUWaitCount"
,
FixDCUWaitCount
);
}
}
// namespace transform
}
// namespace tl
}
// namespace tvm
\ No newline at end of file
src/transform/inject_pipeline.cc
View file @
dd91b1e0
...
@@ -148,6 +148,7 @@ private:
...
@@ -148,6 +148,7 @@ private:
new_args
.
Set
(
i
+
1
,
new_index
);
new_args
.
Set
(
i
+
1
,
new_index
);
}
}
}
}
LOG
(
INFO
)
<<
"Rewriting buffer access "
<<
call
<<
" to "
<<
Call
(
call
->
dtype
,
call
->
op
,
new_args
,
call
->
span
);
return
Call
(
call
->
dtype
,
call
->
op
,
new_args
,
call
->
span
);
return
Call
(
call
->
dtype
,
call
->
op
,
new_args
,
call
->
span
);
}
}
...
@@ -166,6 +167,7 @@ private:
...
@@ -166,6 +167,7 @@ private:
for
(
const
Buffer
&
alloc_buffer
:
op
->
alloc_buffers
)
{
for
(
const
Buffer
&
alloc_buffer
:
op
->
alloc_buffers
)
{
buffer_data_to_buffer_
.
erase
(
alloc_buffer
->
data
);
buffer_data_to_buffer_
.
erase
(
alloc_buffer
->
data
);
}
}
LOG
(
INFO
)
<<
"Rewriting block "
<<
GetRef
<
Block
>
(
op
)
<<
" to "
<<
GetRef
<
Block
>
(
n
);
return
block
;
return
block
;
}
}
...
@@ -309,6 +311,7 @@ public:
...
@@ -309,6 +311,7 @@ public:
}
}
Block
block
=
MakeBlock
(
stmt
,
buffer_data_to_buffer_
);
Block
block
=
MakeBlock
(
stmt
,
buffer_data_to_buffer_
);
block
.
CopyOnWrite
()
->
alloc_buffers
=
std
::
move
(
alloc_buffers
);
block
.
CopyOnWrite
()
->
alloc_buffers
=
std
::
move
(
alloc_buffers
);
LOG
(
INFO
)
<<
"Final rewritten pipeline block: "
<<
block
;
return
BlockRealize
({},
Bool
(
true
),
block
);
return
BlockRealize
({},
Bool
(
true
),
block
);
}
}
...
@@ -631,6 +634,9 @@ private:
...
@@ -631,6 +634,9 @@ private:
n
->
body
=
AttrStmt
(
zero
,
tir
::
attr
::
async_wait_queue_scope
,
stage_id
,
n
->
body
=
AttrStmt
(
zero
,
tir
::
attr
::
async_wait_queue_scope
,
stage_id
,
AttrStmt
(
zero
,
tir
::
attr
::
async_wait_inflight_count
,
AttrStmt
(
zero
,
tir
::
attr
::
async_wait_inflight_count
,
pw
.
wait_count
,
n
->
body
));
pw
.
wait_count
,
n
->
body
));
LOG
(
INFO
)
<<
"Inserting async_wait with count "
<<
pw
.
wait_count
<<
" before block with order "
<<
new_blocks
[
pw
.
insert_before
].
order
<<
" for async stage "
<<
stage_id
;
}
}
}
}
...
@@ -1102,6 +1108,7 @@ private:
...
@@ -1102,6 +1108,7 @@ private:
buffer_data_to_buffer_
.
erase
(
buffer
->
data
);
buffer_data_to_buffer_
.
erase
(
buffer
->
data
);
}
}
}
}
LOG
(
INFO
)
<<
"Finished rewriting the pipeline loop with body:
\n
"
<<
pipeline
;
return
pipeline
;
return
pipeline
;
}
}
...
@@ -1121,6 +1128,7 @@ private:
...
@@ -1121,6 +1128,7 @@ private:
for
(
const
auto
&
buffer
:
op
->
alloc_buffers
)
{
for
(
const
auto
&
buffer
:
op
->
alloc_buffers
)
{
buffer_data_to_buffer_
.
erase
(
buffer
->
data
);
buffer_data_to_buffer_
.
erase
(
buffer
->
data
);
}
}
LOG
(
INFO
)
<<
"Rewriting blockddd "
<<
block
;
return
block
;
return
block
;
}
}
...
@@ -1158,6 +1166,8 @@ tir::transform::Pass InjectSoftwarePipeline() {
...
@@ -1158,6 +1166,8 @@ tir::transform::Pass InjectSoftwarePipeline() {
auto
*
fptr
=
f
.
CopyOnWrite
();
auto
*
fptr
=
f
.
CopyOnWrite
();
fptr
->
body
=
software_pipeline
::
PipelineInjector
::
Inject
(
f
);
fptr
->
body
=
software_pipeline
::
PipelineInjector
::
Inject
(
f
);
fptr
->
body
=
ConvertSSA
(
std
::
move
(
fptr
->
body
));
fptr
->
body
=
ConvertSSA
(
std
::
move
(
fptr
->
body
));
LOG
(
INFO
)
<<
"Finished injecting software pipeline for PrimFunc "
<<
f
->
GetAttr
<
String
>
(
tvm
::
attr
::
kGlobalSymbol
).
value_or
(
"<unknown>"
)
<<
", the transformed body is:
\n
"
<<
fptr
->
body
;
return
f
;
return
f
;
};
};
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.InjectSoftwarePipeline"
,
{});
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.InjectSoftwarePipeline"
,
{});
...
...
src/transform/lower_dcu_resource.cc
View file @
dd91b1e0
...
@@ -94,6 +94,7 @@ CollectResult CollectResources(const Stmt& body) {
...
@@ -94,6 +94,7 @@ CollectResult CollectResources(const Stmt& body) {
CollectResult
result
;
CollectResult
result
;
private:
private:
bool
in_async
{
false
};
std
::
unordered_set
<
const
tvm
::
tir
::
VarNode
*>
loop_vars_
;
std
::
unordered_set
<
const
tvm
::
tir
::
VarNode
*>
loop_vars_
;
std
::
vector
<
const
tvm
::
tir
::
StmtNode
*>
scope_stack_
;
// 追踪当前遍历的 AST 路径
std
::
vector
<
const
tvm
::
tir
::
StmtNode
*>
scope_stack_
;
// 追踪当前遍历的 AST 路径
bool
IsSharedScope
(
const
Buffer
&
buf
)
{
bool
IsSharedScope
(
const
Buffer
&
buf
)
{
...
@@ -105,27 +106,36 @@ CollectResult CollectResources(const Stmt& body) {
...
@@ -105,27 +106,36 @@ CollectResult CollectResources(const Stmt& body) {
return
s
==
"global"
||
s
==
""
;
return
s
==
"global"
||
s
==
""
;
}
}
void
VisitStmt_
(
const
AttrStmtNode
*
op
)
override
{
void
VisitStmt_
(
const
AttrStmtNode
*
attr
)
override
{
scope_stack_
.
push_back
(
op
);
scope_stack_
.
push_back
(
attr
);
if
(
op
->
attr_key
==
tvm
::
tir
::
attr
::
thread_extent
)
{
if
(
attr
->
attr_key
==
tir
::
attr
::
thread_extent
)
{
// 1. 获取 IterVar
// 1. 获取 IterVar
auto
iv
=
op
->
node
.
as
<
tvm
::
tir
::
IterVarNode
>
();
auto
iv
=
attr
->
node
.
as
<
tvm
::
tir
::
IterVarNode
>
();
const
std
::
string
&
tag
=
iv
->
thread_tag
;
const
std
::
string
&
tag
=
iv
->
thread_tag
;
// 2. 只有当 tag 包含 "threadIdx" 时才加入 (过滤掉 blockIdx)
// 比如: "threadIdx.x", "threadIdx.y", "threadIdx.z"
if
(
tag
.
find
(
"threadIdx"
)
!=
std
::
string
::
npos
)
{
if
(
tag
.
find
(
"threadIdx"
)
!=
std
::
string
::
npos
)
{
tvm
::
tir
::
Var
thread_var
=
iv
->
var
;
tvm
::
tir
::
Var
thread_var
=
iv
->
var
;
LOG
(
INFO
)
<<
"Entering thread scope: "
<<
tag
<<
" with var "
<<
thread_var
->
name_hint
;
loop_vars_
.
insert
(
thread_var
.
get
());
loop_vars_
.
insert
(
thread_var
.
get
());
StmtExprVisitor
::
VisitStmt_
(
op
);
StmtExprVisitor
::
VisitStmt_
(
attr
);
loop_vars_
.
erase
(
thread_var
.
get
());
loop_vars_
.
erase
(
thread_var
.
get
());
}
else
{
}
else
{
// 如果是 blockIdx 或其他,直接跳过当前层继续往下走
// 如果是 blockIdx 或其他,直接跳过当前层继续往下走
StmtExprVisitor
::
VisitStmt_
(
op
);
StmtExprVisitor
::
VisitStmt_
(
attr
);
}
}
}
else
if
(
attr
->
attr_key
==
tir
::
attr
::
async_scope
)
{
ICHECK
(
in_async
==
false
)
<<
"Nested async scopes not supported"
;
in_async
=
true
;
StmtExprVisitor
::
VisitStmt_
(
attr
);
in_async
=
false
;
}
else
{
StmtExprVisitor
::
VisitStmt_
(
attr
);
}
}
scope_stack_
.
pop_back
();
scope_stack_
.
pop_back
();
}
}
...
@@ -145,14 +155,15 @@ CollectResult CollectResources(const Stmt& body) {
...
@@ -145,14 +155,15 @@ CollectResult CollectResources(const Stmt& body) {
}
}
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
LOG
(
INFO
)
<<
"Visiting BufferStore: "
<<
op
->
buffer
->
name
;
Buffer
dst
=
op
->
buffer
;
Buffer
dst
=
op
->
buffer
;
if
(
IsSharedScope
(
dst
)
&&
op
->
value
.
defined
())
{
if
(
IsSharedScope
(
dst
)
&&
op
->
value
.
defined
()
&&
in_async
)
{
if
(
const
auto
*
load
=
op
->
value
.
as
<
BufferLoadNode
>
())
{
if
(
const
auto
*
load
=
op
->
value
.
as
<
BufferLoadNode
>
())
{
Buffer
src
=
load
->
buffer
;
Buffer
src
=
load
->
buffer
;
if
(
IsGlobalScope
(
src
))
{
if
(
IsGlobalScope
(
src
))
{
const
StmtNode
*
target
=
op
;
if
(
result
.
inject_target
==
nullptr
)
{
if
(
result
.
inject_target
==
nullptr
)
{
// 从下往上回溯栈,寻找最内层的 thread_extent
for
(
int
i
=
scope_stack_
.
size
()
-
1
;
i
>=
0
;
--
i
)
{
for
(
int
i
=
scope_stack_
.
size
()
-
1
;
i
>=
0
;
--
i
)
{
if
(
scope_stack_
[
i
]
->
IsInstance
<
AttrStmtNode
>
())
{
if
(
scope_stack_
[
i
]
->
IsInstance
<
AttrStmtNode
>
())
{
auto
attr
=
static_cast
<
const
AttrStmtNode
*>
(
scope_stack_
[
i
]);
auto
attr
=
static_cast
<
const
AttrStmtNode
*>
(
scope_stack_
[
i
]);
...
@@ -198,10 +209,10 @@ CollectResult CollectResources(const Stmt& body) {
...
@@ -198,10 +209,10 @@ CollectResult CollectResources(const Stmt& body) {
VariableEliminator
eliminator
(
loop_vars_
);
VariableEliminator
eliminator
(
loop_vars_
);
tvm
::
arith
::
Analyzer
analyzer
;
tvm
::
arith
::
Analyzer
analyzer
;
Array
<
PrimExpr
>
base_indices
;
Array
<
PrimExpr
>
base_indices
;
LOG
(
INFO
)
<<
loop_vars_
.
size
()
<<
" loop vars in context."
;
LOG
(
INFO
)
<<
loop_vars_
.
size
()
<<
" loop vars in context."
;
for
(
const
auto
*
var
:
loop_vars_
)
{
for
(
const
auto
*
var
:
loop_vars_
)
{
LOG
(
INFO
)
<<
"Loop Var: "
<<
var
->
name_hint
;
LOG
(
INFO
)
<<
"Loop Var: "
<<
var
->
name_hint
;
}
}
for
(
const
auto
&
idx
:
load
->
indices
)
{
for
(
const
auto
&
idx
:
load
->
indices
)
{
// 将所有外层循环变量 (k, i 等) 全部替换为 0
// 将所有外层循环变量 (k, i 等) 全部替换为 0
PrimExpr
no_loops
=
eliminator
(
idx
);
PrimExpr
no_loops
=
eliminator
(
idx
);
...
@@ -225,15 +236,18 @@ CollectResult CollectResources(const Stmt& body) {
...
@@ -225,15 +236,18 @@ CollectResult CollectResources(const Stmt& body) {
// 将这个绑定关系和 destination 的 shared buffer 绑死
// 将这个绑定关系和 destination 的 shared buffer 绑死
result
.
shared_alloc_to_binding
[
src
->
name
]
=
{
var
,
val
};
result
.
shared_alloc_to_binding
[
src
->
name
]
=
{
var
,
val
};
}
}
LOG
(
INFO
)
<<
"result.copies.size() = "
<<
result
.
copies
.
size
();
}
}
}
}
}
}
StmtExprVisitor
::
VisitStmt_
(
op
);
StmtExprVisitor
::
VisitStmt_
(
op
);
}
}
};
};
LOG
(
INFO
)
<<
"Starting resource collection..."
;
Collector
col
;
Collector
col
;
col
(
body
);
col
(
body
);
LOG
(
INFO
)
<<
"Finished resource collection. Found "
<<
col
.
result
.
copies
.
size
()
<<
" copy(s)."
;
return
col
.
result
;
return
col
.
result
;
}
}
...
@@ -253,6 +267,15 @@ private:
...
@@ -253,6 +267,15 @@ private:
const
std
::
unordered_map
<
String
,
Var
>&
global_to_var
)
const
std
::
unordered_map
<
String
,
Var
>&
global_to_var
)
:
copies_
(
copies
),
global_to_var_
(
global_to_var
)
{}
:
copies_
(
copies
),
global_to_var_
(
global_to_var
)
{}
Stmt
VisitStmt_
(
const
AttrStmtNode
*
attr
)
{
if
(
attr
->
attr_key
==
tir
::
attr
::
async_scope
)
{
auto
body
=
this
->
VisitStmt
(
attr
->
body
);
return
body
;
}
return
StmtMutator
::
VisitStmt_
(
attr
);
// ③ 其他属性:默认保留
}
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
for
(
const
auto
&
copy
:
copies_
)
{
for
(
const
auto
&
copy
:
copies_
)
{
if
(
copy
.
store_stmt
.
same_as
(
GetRef
<
Stmt
>
(
op
)))
{
if
(
copy
.
store_stmt
.
same_as
(
GetRef
<
Stmt
>
(
op
)))
{
...
@@ -331,19 +354,23 @@ private:
...
@@ -331,19 +354,23 @@ private:
PrimFunc
LowerSharedGlobalCopy
(
PrimFunc
f
)
{
PrimFunc
LowerSharedGlobalCopy
(
PrimFunc
f
)
{
auto
*
n
=
f
.
CopyOnWrite
();
auto
*
n
=
f
.
CopyOnWrite
();
// 1. 收集信息并定位目标注入点
// 收集信息
LOG
(
INFO
)
<<
"Starting LowerSharedGlobalCopy transformation..."
;
auto
res
=
CollectResources
(
n
->
body
);
auto
res
=
CollectResources
(
n
->
body
);
if
(
res
.
copies
.
empty
())
return
f
;
if
(
res
.
copies
.
empty
()){
LOG
(
INFO
)
<<
"No shared-global copy patterns detected. Skipping transformation."
;
return
f
;
}
// 【核心修改】:2. 先注入 LetStmt!
LOG
(
INFO
)
<<
"Replaced "
<<
res
.
copies
.
size
()
<<
" copy(s) with dcu_async_copy."
;
//
此时使用的 n->body 是原始 AST,res.inject_target 指针百分之百匹配。
//
注入res声明
Stmt
injected
=
ResourceInjector
::
Run
(
n
->
body
,
res
.
shared_alloc_to_binding
,
res
.
inject_target
);
Stmt
injected
=
ResourceInjector
::
Run
(
n
->
body
,
res
.
shared_alloc_to_binding
,
res
.
inject_target
);
// 3. 替换拷贝语句
// 替换拷贝语句
// injected 是套了 LetStmt 的新 AST,但底层的 BufferStore 还是原来的,可以被正常替换。
Stmt
replaced
=
StoreReplacer
::
Run
(
injected
,
res
.
copies
,
res
.
global_to_res_var
);
Stmt
replaced
=
StoreReplacer
::
Run
(
injected
,
res
.
copies
,
res
.
global_to_res_var
);
// 4. 写回 PrimFunc
// 写回
n
->
body
=
std
::
move
(
replaced
);
n
->
body
=
std
::
move
(
replaced
);
return
GetRef
<
PrimFunc
>
(
n
);
return
GetRef
<
PrimFunc
>
(
n
);
...
...
src/transform/thread_storage_sync.cc
View file @
dd91b1e0
...
@@ -52,6 +52,7 @@ public:
...
@@ -52,6 +52,7 @@ public:
// The syncs inserted before each statement
// The syncs inserted before each statement
std
::
unordered_set
<
const
Object
*>
syncs_inserted_
;
std
::
unordered_set
<
const
Object
*>
syncs_inserted_
;
std
::
unordered_set
<
const
Object
*>
barrier_inserted_
;
protected:
protected:
bool
Enabled
(
const
VarNode
*
buf
,
const
StorageScope
&
scope
)
const
final
{
bool
Enabled
(
const
VarNode
*
buf
,
const
StorageScope
&
scope
)
const
final
{
...
@@ -815,9 +816,9 @@ PrimFunc TileLangThreadSync(PrimFunc func, const std::string &storage_scope) {
...
@@ -815,9 +816,9 @@ PrimFunc TileLangThreadSync(PrimFunc func, const std::string &storage_scope) {
StorageScope
sync_scope
=
StorageScope
::
Create
(
storage_scope
);
StorageScope
sync_scope
=
StorageScope
::
Create
(
storage_scope
);
auto
*
n
=
func
.
CopyOnWrite
();
auto
*
n
=
func
.
CopyOnWrite
();
auto
stmt
=
n
->
body
;
auto
stmt
=
n
->
body
;
if
(
sync_scope
.
rank
==
StorageRank
::
kShared
&&
sync_scope
.
tag
.
empty
())
{
//
if (sync_scope.rank == StorageRank::kShared && sync_scope.tag.empty()) {
stmt
=
ThreadSyncAfterWaitQueueInserter
(
sync_scope
)(
stmt
);
//
stmt = ThreadSyncAfterWaitQueueInserter(sync_scope)(stmt);
}
//
}
TileLangThreadSyncPlanner
planner
(
sync_scope
);
TileLangThreadSyncPlanner
planner
(
sync_scope
);
for
(
const
auto
&
[
_
,
buffer
]
:
func
->
buffer_map
)
{
for
(
const
auto
&
[
_
,
buffer
]
:
func
->
buffer_map
)
{
planner
.
SetBufferDataToBuffer
(
buffer
->
data
,
buffer
);
planner
.
SetBufferDataToBuffer
(
buffer
->
data
,
buffer
);
...
...
src/transform/vectorize_dcu_async_copy.cc
View file @
dd91b1e0
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include <tvm/tir/op.h>
#include <tvm/tir/op.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/builtin.h>
#include <tvm/arith/analyzer.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <vector>
#include <vector>
...
@@ -27,72 +28,95 @@ public:
...
@@ -27,72 +28,95 @@ public:
private:
private:
arith
::
Analyzer
analyzer_
;
arith
::
Analyzer
analyzer_
;
Var
k_var_
;
Var
k_var_
;
PrimExpr
k_extent_
;
// 新增:记录 k 循环的次数
PrimExpr
k_extent_
;
bool
in_unrolled_i_
=
false
;
// 通用的步长提取函数:从 expr 中提取指定 var 的步长,并返回剩余的 base
std
::
pair
<
PrimExpr
,
PrimExpr
>
ExtractStride
(
PrimExpr
expr
,
Var
var
)
{
std
::
pair
<
PrimExpr
,
PrimExpr
>
ExtractStride
(
PrimExpr
expr
,
Var
var
)
{
if
(
!
var
.
defined
())
return
{
expr
,
make_zero
(
expr
.
dtype
())};
if
(
!
var
.
defined
())
return
{
expr
,
make_zero
(
expr
->
dtype
)};
PrimExpr
base
=
tvm
::
tir
::
Substitute
(
expr
,
{{
var
,
make_zero
(
var
.
dtype
())}});
PrimExpr
base
=
tvm
::
tir
::
Substitute
(
expr
,
{{
var
,
make_zero
(
var
.
dtype
())}});
PrimExpr
plus_one
=
tvm
::
tir
::
Substitute
(
expr
,
{{
var
,
make_const
(
var
.
dtype
(),
1
)}});
PrimExpr
plus_one
=
tvm
::
tir
::
Substitute
(
expr
,
{{
var
,
make_const
(
var
.
dtype
(),
1
)}});
PrimExpr
stride
=
analyzer_
.
Simplify
(
plus_one
-
base
);
PrimExpr
stride
=
analyzer_
.
Simplify
(
plus_one
-
base
);
return
{
analyzer_
.
Simplify
(
base
),
stride
};
return
{
analyzer_
.
Simplify
(
base
),
stride
};
}
}
// 核心重写逻辑
Stmt
RewriteAsyncCopy
(
const
CallNode
*
call
,
Var
i_var
,
PrimExpr
i_extent
)
{
// 1. 预处理:剥离 RampNode 获得基础偏移
PrimExpr
raw_dst_off
=
call
->
args
[
1
];
PrimExpr
raw_src_off
=
call
->
args
[
3
];
if
(
const
RampNode
*
r
=
raw_dst_off
.
as
<
RampNode
>
())
raw_dst_off
=
r
->
base
;
if
(
const
RampNode
*
r
=
raw_src_off
.
as
<
RampNode
>
())
raw_src_off
=
r
->
base
;
// 2. 提取 i 的步长
auto
[
base_i_dst
,
i_stride_dst
]
=
ExtractStride
(
raw_dst_off
,
i_var
);
auto
[
base_i_src
,
i_stride_src
]
=
ExtractStride
(
raw_src_off
,
i_var
);
// 3. 提取 k 的步长 (始终尝试从 base_i_src 中提取,无论是否存在 i 循环)
// 只要 k_var_ 在外层循环中定义了,这里就能提取出非 0 的步长
auto
[
final_src_offset
,
k_stride_src
]
=
ExtractStride
(
base_i_src
,
k_var_
);
// 构造最初要求的 8 个参数:
// [dst, dst_off, src, src_off, i_extent, i_stride_dst, i_stride_src, k_stride_src]
Array
<
PrimExpr
>
new_args
=
{
call
->
args
[
0
],
// dst_buf
base_i_dst
,
// 最终 dst 偏移
call
->
args
[
2
],
// src_buf
final_src_offset
,
// 最终 src 偏移
i_extent
,
// i 循环次数 (无循环时为 0)
i_stride_dst
,
// i 的 dst 步长 (无循环时为 0)
i_stride_src
,
// i 的 src 步长 (无循环时为 0)
k_stride_src
// k 的 src 步长 (即便无 i 循环,这里也能拿到 k 的步长)
};
return
Evaluate
(
Call
(
call
->
dtype
,
call
->
op
,
new_args
));
}
// 处理无循环包裹的情况
Stmt
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
if
(
!
in_unrolled_i_
)
{
if
(
const
CallNode
*
call
=
op
->
value
.
as
<
CallNode
>
())
{
static
const
Op
&
dcu_copy_op
=
Op
::
Get
(
"tl.dcu_async_copy"
);
// 只要参数个数不是 8 (我们重写后的目标个数),就进行处理
if
(
call
->
op
.
same_as
(
dcu_copy_op
)
&&
call
->
args
.
size
()
!=
8
)
{
return
RewriteAsyncCopy
(
call
,
Var
(),
make_zero
(
DataType
::
Int
(
32
)));
}
}
}
return
StmtExprMutator
::
VisitStmt_
(
op
);
}
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
//
1.
记录 k
的
信息
// 记录 k 信息
(假设 k 在外层)
bool
is_k
=
(
op
->
loop_var
->
name_hint
==
"k"
);
bool
is_k
=
(
op
->
loop_var
->
name_hint
==
"k"
);
if
(
is_k
)
{
if
(
is_k
)
{
k_var_
=
op
->
loop_var
;
k_var_
=
op
->
loop_var
;
k_extent_
=
op
->
extent
;
// 获取 k 的循环次数 (如 64)
k_extent_
=
op
->
extent
;
}
}
// 2. 递归访问子节点
bool
is_unrolled
=
(
op
->
kind
==
ForKind
::
kUnrolled
);
bool
prev_in_unrolled
=
in_unrolled_i_
;
if
(
is_unrolled
)
in_unrolled_i_
=
true
;
Stmt
body
=
this
->
VisitStmt
(
op
->
body
);
Stmt
body
=
this
->
VisitStmt
(
op
->
body
);
in_unrolled_i_
=
prev_in_unrolled
;
// 3. 处理 Async Copy 简化
if
(
is_unrolled
)
{
if
(
op
->
kind
==
ForKind
::
kUnrolled
)
{
if
(
const
EvaluateNode
*
eval
=
body
.
as
<
EvaluateNode
>
())
{
if
(
const
EvaluateNode
*
eval
=
body
.
as
<
EvaluateNode
>
())
{
if
(
const
CallNode
*
call
=
eval
->
value
.
as
<
CallNode
>
())
{
if
(
const
CallNode
*
call
=
eval
->
value
.
as
<
CallNode
>
())
{
static
const
Op
&
dcu_copy_op
=
Op
::
Get
(
"tl.dcu_async_copy"
);
static
const
Op
&
dcu_copy_op
=
Op
::
Get
(
"tl.dcu_async_copy"
);
if
(
call
->
op
.
same_as
(
dcu_copy_op
))
{
if
(
call
->
op
.
same_as
(
dcu_copy_op
))
{
// 还原 k 并在返回前处理重写
Var
i_var
=
op
->
loop_var
;
Stmt
result
=
RewriteAsyncCopy
(
call
,
op
->
loop_var
,
op
->
extent
);
PrimExpr
i_extent
=
op
->
extent
;
// 获取 i 的循环次数 (如 2)
return
result
;
auto
get_i_info
=
[
&
](
PrimExpr
offset
)
{
if
(
const
RampNode
*
ramp
=
offset
.
as
<
RampNode
>
())
{
auto
[
base
,
stride
]
=
ExtractStride
(
ramp
->
base
,
i_var
);
return
std
::
make_pair
(
base
,
stride
);
}
return
ExtractStride
(
offset
,
i_var
);
};
// 提取 i 的步长
auto
[
base_dst
,
i_stride_dst
]
=
get_i_info
(
call
->
args
[
1
]);
auto
[
base_src
,
i_stride_src
]
=
get_i_info
(
call
->
args
[
3
]);
// 提取 k 的步长 (从 base_src 继续解构)
auto
[
final_src_offset
,
k_stride_src
]
=
ExtractStride
(
base_src
,
k_var_
);
// 构造新的参数列表,包含循环次数
// 建议参数顺序:[dst, dst_off, src, src_off, size, i_extent, i_stride_dst, i_stride_src, k_stride_src]
// 这里的 size 保持原样 (如 8),i_extent 传入 2
Array
<
PrimExpr
>
new_args
=
{
call
->
args
[
0
],
// dst_buf
base_dst
,
// 基础 dst 偏移
call
->
args
[
2
],
// src_buf
final_src_offset
,
// 基础 src 偏移
i_extent
,
// i 循环次数
i_stride_dst
,
// i 的 dst 步长
i_stride_src
,
// i 的 src 步长
k_stride_src
// k 的 src 步长
};
return
Evaluate
(
Call
(
call
->
dtype
,
call
->
op
,
new_args
));
}
}
}
}
}
}
}
}
// 退出循环时清理 k 信息
if
(
is_k
)
{
if
(
is_k
)
{
k_var_
=
Var
();
k_var_
=
Var
();
k_extent_
=
PrimExpr
();
k_extent_
=
PrimExpr
();
...
...
tilelang/engine/phase.py
View file @
dd91b1e0
...
@@ -213,8 +213,11 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
...
@@ -213,8 +213,11 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod
=
tilelang
.
transform
.
IfStmtBinding
()(
mod
)
mod
=
tilelang
.
transform
.
IfStmtBinding
()(
mod
)
mod
=
tilelang
.
transform
.
PlanAndUpdateBufferAllocationLocation
()(
mod
)
mod
=
tilelang
.
transform
.
PlanAndUpdateBufferAllocationLocation
()(
mod
)
print
(
"OptimizeForTarget"
)
print
(
mod
)
mod
=
tilelang
.
transform
.
PipelinePlanning
()(
mod
)
mod
=
tilelang
.
transform
.
PipelinePlanning
()(
mod
)
mod
=
tilelang
.
transform
.
InjectSoftwarePipeline
()(
mod
)
mod
=
tilelang
.
transform
.
InjectSoftwarePipeline
()(
mod
)
mod
=
tilelang
.
transform
.
MergeIfStmt
()(
mod
)
mod
=
tilelang
.
transform
.
MergeIfStmt
()(
mod
)
if
allow_fence_proxy
(
target
=
target
):
if
allow_fence_proxy
(
target
=
target
):
# in hopper device, wgmma is an async proxy
# in hopper device, wgmma is an async proxy
...
@@ -270,15 +273,15 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
...
@@ -270,15 +273,15 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod
=
tilelang
.
transform
.
MergeSharedMemoryAllocations
(
enable_aggressive_merge
=
enable_aggressive_merge
)(
mod
)
mod
=
tilelang
.
transform
.
MergeSharedMemoryAllocations
(
enable_aggressive_merge
=
enable_aggressive_merge
)(
mod
)
mod
=
tilelang
.
transform
.
ThreadSync
(
"shared"
)(
mod
)
mod
=
tilelang
.
transform
.
ThreadSync
(
"shared"
)(
mod
)
mod
=
tilelang
.
transform
.
ThreadSync
(
"shared.dyn"
)(
mod
)
mod
=
tilelang
.
transform
.
ThreadSync
(
"shared.dyn"
)(
mod
)
print
(
"OptimizeForTarget2"
)
print
(
mod
)
# Inject PTX async copy must behind the thread sync pass
# Inject PTX async copy must behind the thread sync pass
# as ptx async copy won't be recognized as a valid buffer load
# as ptx async copy won't be recognized as a valid buffer load
mod
=
tilelang
.
transform
.
InjectPTXAsyncCopy
()(
mod
)
if
not
dcu_async_copy_supported
(
target
):
mod
=
tilelang
.
transform
.
InjectPTXAsyncCopy
()(
mod
)
# Inject ds_read for shared to register memory copy on DCU
# Inject ds_read for shared to register memory copy on DCU
mod
=
tilelang
.
transform
.
InjectDSRead
()(
mod
)
# mod = tilelang.transform.InjectDSRead()(mod)
print
(
"222222222"
)
print
(
mod
)
if
allow_tma_and_warp_specialized
(
pass_ctx
=
pass_ctx
,
target
=
target
):
if
allow_tma_and_warp_specialized
(
pass_ctx
=
pass_ctx
,
target
=
target
):
mod
=
tilelang
.
transform
.
AnnotateWarpGroupRegAlloc
()(
mod
)
mod
=
tilelang
.
transform
.
AnnotateWarpGroupRegAlloc
()(
mod
)
mod
=
tilelang
.
transform
.
MakePackedAPI
()(
mod
)
mod
=
tilelang
.
transform
.
MakePackedAPI
()(
mod
)
...
@@ -287,13 +290,11 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
...
@@ -287,13 +290,11 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# Transform threadblock to persistent threadblock
# Transform threadblock to persistent threadblock
mod
=
tilelang
.
transform
.
PersistThreadblock
()(
mod
)
mod
=
tilelang
.
transform
.
PersistThreadblock
()(
mod
)
print
(
"OptimizeForTarget"
)
print
(
mod
)
if
dcu_async_copy_supported
(
target
):
if
dcu_async_copy_supported
(
target
):
mod
=
tilelang
.
transform
.
LowerSharedGlobalCopy
()(
mod
)
mod
=
tilelang
.
transform
.
LowerSharedGlobalCopy
()(
mod
)
print
(
"OptimizeForTarget2"
)
mod
=
tilelang
.
transform
.
FixDCUWaitCount
()(
mod
)
print
(
mod
)
#
mod
=
tilelang
.
transform
.
SimplifyDCUAsyncCopy
()(
mod
)
#
mod = tilelang.transform.SimplifyDCUAsyncCopy()(mod)
print
(
"OptimizeForTarget3"
)
print
(
"OptimizeForTarget3"
)
print
(
mod
)
print
(
mod
)
return
mod
return
mod
tilelang/intrinsics/mmac_macro_generator.py
View file @
dd91b1e0
...
@@ -11,7 +11,7 @@ from tilelang.utils import is_fragment
...
@@ -11,7 +11,7 @@ from tilelang.utils import is_fragment
from
.mfma_layout
import
(
from
.mfma_layout
import
(
shared_16x4_to_local_64x1_layout_A
,
shared_16x4_to_local_64x1_layout_A
,
shared_4x16_to_local_64x1_layout_B
,
shared_4x16_to_local_64x1_layout_B
,
shared_16x16_to_local_64x4_layout_A
,
#
shared_16x16_to_local_64x4_layout_A,
shared_16x16_to_local_64x4_layout_B
,
shared_16x16_to_local_64x4_layout_B
,
shared_16x32_to_local_64x8_layout_A
,
shared_16x32_to_local_64x8_layout_A
,
shared_16x32_to_local_64x8_layout_B
,
shared_16x32_to_local_64x8_layout_B
,
...
@@ -19,7 +19,7 @@ from .mfma_layout import (
...
@@ -19,7 +19,7 @@ from .mfma_layout import (
shared_16x64_to_local_64x16_layout_B
,
shared_16x64_to_local_64x16_layout_B
,
thread_id_shared_access_64x1_to_16x4_layout_A
,
thread_id_shared_access_64x1_to_16x4_layout_A
,
thread_id_shared_access_64x1_to_4x16_layout_B
,
thread_id_shared_access_64x1_to_4x16_layout_B
,
thread_id_shared_access_64x4_to_16x16_layout_A
,
#
thread_id_shared_access_64x4_to_16x16_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_B
,
thread_id_shared_access_64x4_to_16x16_layout_B
,
thread_id_shared_access_64x8_to_16x32_layout_A
,
thread_id_shared_access_64x8_to_16x32_layout_A
,
thread_id_shared_access_64x8_to_16x32_layout_B
,
thread_id_shared_access_64x8_to_16x32_layout_B
,
...
@@ -27,10 +27,10 @@ from .mfma_layout import (
...
@@ -27,10 +27,10 @@ from .mfma_layout import (
thread_id_shared_access_64x16_to_16x64_layout_B
,
thread_id_shared_access_64x16_to_16x64_layout_B
,
)
)
#
from .mmac_layout import (
from
.mmac_layout
import
(
#
shared_16x16_to_local_64x4_layout_A,
shared_16x16_to_local_64x4_layout_A
,
#
thread_id_shared_access_64x4_to_16x16_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_A
,
#
)
)
lift
=
convert
lift
=
convert
...
@@ -251,6 +251,21 @@ class MatrixCoreIntrinEmitter:
...
@@ -251,6 +251,21 @@ class MatrixCoreIntrinEmitter:
)
)
return
lane_id
,
warp_n
,
warp_m
return
lane_id
,
warp_n
,
warp_m
def map_64x16(self, row, col, idx, warp_rows, tx):
    """Remap a (row, col) shared-memory coordinate for the 64x16 access pattern.

    Fix vs. previous revision: removed the debug ``print("paddings:", ...)``
    left in library code (it fired on every trace of the macro).

    Parameters
    ----------
    row, col : int expressions
        Base coordinates produced by the reverse index map.
    idx : int
        Per-warp round index; bit 0 and the remaining bits contribute
        separate row offsets (see below).
    warp_rows : int
        Number of row rounds per warp; >1 doubles the inter-round padding.
    tx : int expression
        Lane/thread id; only ``tx & 15`` participates in the remap.

    Returns
    -------
    (new_row, new_col)
        The column is passed through unchanged; the row is offset by a
        padding stride derived from ``block_row_warps``.
        NOTE(review): the exact bank-conflict/padding rationale for the
        constants (4, 16, 2) is not visible here — confirm against the
        MMAC shared-memory layout before changing them.
    """
    new_col = col
    # Padding between rounds doubles when a warp covers more than one row block.
    inter_idx_padding = 2 if warp_rows > 1 else 1
    paddings = inter_idx_padding * self.block_row_warps * 4
    # Groups of 4 lanes (within tx & 15) land on padded row strides.
    new_row = row + paddings * ((tx & 15) // 4)
    # Odd idx shifts by half a padding stride; the higher idx bits step
    # whole 16x2 row blocks scaled by the number of row warps.
    new_row += (idx & 1) * (paddings // 2) + (idx // 2) * 16 * 2 * self.block_row_warps
    return new_row, new_col
def
ldmatrix_a
(
self
,
A_local_buf
,
A_shared_buf
:
Buffer
|
BufferRegion
,
ki
,
rk
=
0
):
def
ldmatrix_a
(
self
,
A_local_buf
,
A_shared_buf
:
Buffer
|
BufferRegion
,
ki
,
rk
=
0
):
# share mem a needs warp number
# share mem a needs warp number
warp_num
=
self
.
block_row_warps
warp_num
=
self
.
block_row_warps
...
@@ -272,6 +287,7 @@ class MatrixCoreIntrinEmitter:
...
@@ -272,6 +287,7 @@ class MatrixCoreIntrinEmitter:
A_buf
=
A_region
.
buffer
A_buf
=
A_region
.
buffer
A_base0
=
A_region
.
region
[
-
2
].
min
A_base0
=
A_region
.
region
[
-
2
].
min
A_base1
=
A_region
.
region
[
-
1
].
min
A_base1
=
A_region
.
region
[
-
1
].
min
print
(
"A_base0, A_base1:"
,
A_base0
,
A_base1
)
@
T
.
macro
@
T
.
macro
def
_warp_ldmatrix_a
(
def
_warp_ldmatrix_a
(
...
@@ -281,31 +297,42 @@ class MatrixCoreIntrinEmitter:
...
@@ -281,31 +297,42 @@ class MatrixCoreIntrinEmitter:
thread_binding
,
thread_binding
,
rk
=
0
,
rk
=
0
,
):
):
tx
,
_
,
warp_m
=
self
.
extract_thread_binding
(
thread_binding
)
# warp_n[0-256] -> {0,1,2,3}
tx
,
warp_n
,
warp_m
=
self
.
extract_thread_binding
(
thread_binding
)
# {0..3,16..19,32..35,48..51} -> 0
# {0..3,16..19,32..35,48..51} -> 0
# {4..7,20..23,36..39,52..55} -> 1
# {4..7,20..23,36..39,52..55} -> 1
# {8..11,24..27,40..43,56..59} -> 2
# {8..11,24..27,40..43,56..59} -> 2
# {12..15,28..31,44..47,60..63} -> 3
# {12..15,28..31,44..47,60..63} -> 3
warp_interval_idx
=
(
tx
&
15
)
>>
2
warp_interval_idx
=
(
tx
&
15
)
>>
2
warp_group_idx
=
(
tx
//
32
)
warp_group_idx
=
warp_n
# warp_rows 轮次 warp需要拆分成多轮来访问完整的行块
# warp_rows 轮次 warp需要拆分成多轮来访问完整的行块
if
is_transposed
:
if
is_transposed
:
for
i
in
T
.
serial
(
warp_rows
):
for
i
in
T
.
serial
(
warp_rows
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
# 每轮初始位置行偏移
# 每轮初始位置行偏移
row
+=
i
*
warp_row_init
# row += i * warp_row_init
# warp 组行间隔
# # warp 组行间隔
row
+=
warp_group_idx
*
4
# row += warp_group_idx * 4
# warp 内行间隔
# # warp 内行间隔
row
+=
warp_interval_idx
*
warp_row_interval
# row += warp_interval_idx * warp_row_interval
raise
NotImplementedError
(
"Transposed A with preshuffle is not implemented yet"
)
row
,
col
=
self
.
map_64x16
(
row
,
col
,
i
,
warp_rows
,
tx
)
l
,
r
=
(
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
)
l
,
r
=
(
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
),
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
)
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
else
:
else
:
for
i
in
T
.
serial
(
warp_rows
):
for
i
in
T
.
serial
(
warp_rows
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
for
local_id
in
T
.
vectorized
(
k_pack
*
local_size_a
):
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
row
,
col
=
T
.
meta_var
(
reverse_index_map
(
tx
,
local_id
))
l
,
r
=
(
warp_m
*
warp_row_tiles
+
i
*
micro_size_x
,
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
))
# # 每轮初始位置行偏移
# row += i * warp_row_init
# # warp 组行间隔
# row += warp_group_idx * 4
# # warp 内行间隔
# row += warp_interval_idx * warp_row_interval
row
,
col
=
self
.
map_64x16
(
row
,
col
,
i
,
warp_rows
,
tx
)
print
(
"row, col:"
,
row
,
col
)
l
,
r
=
(
warp_m
*
4
,
rk
*
chunk
+
ki
*
(
k_pack
*
micro_size_k
))
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
A_local_buf
[
i
*
k_pack
*
local_size_a
+
local_id
]
=
A_buf
[
A_base0
+
l
+
row
,
A_base1
+
r
+
col
]
return
_warp_ldmatrix_a
(
A_local_buf
,
A_shared_buf
,
ki
,
thread_binding
,
rk
)
return
_warp_ldmatrix_a
(
A_local_buf
,
A_shared_buf
,
ki
,
thread_binding
,
rk
)
...
...
tilelang/tileop/gemm/__init__.py
View file @
dd91b1e0
...
@@ -20,6 +20,10 @@ print("tileop gemm init...")
...
@@ -20,6 +20,10 @@ print("tileop gemm init...")
def gemm_py_infer_layout(gemm_py: "GemmMMA", target: "Target", thread_bounds: "Range"):
    """Infer the buffer layouts for a python-defined GEMM tile operator.

    Fixes vs. previous revision:
    - ``infer_layout`` was invoked twice (once into a discarded local ``t``
      and again on the ``return`` line); it is now called exactly once.
    - Removed the four debug ``print`` calls left in library code.

    Parameters
    ----------
    gemm_py : GemmMMA
        The GEMM tile operator whose layout is being inferred.
    target : Target
        Compilation target, forwarded to ``infer_layout``.
    thread_bounds : Range
        Thread-binding range; only its ``extent`` (the thread count) is used.

    Returns
    -------
    The layout mapping produced by ``gemm_py.infer_layout``.
    """
    thread_nums = thread_bounds.extent
    return gemm_py.infer_layout(target, thread_nums)
...
...
tilelang/tileop/gemm/gemm_mmac.py
View file @
dd91b1e0
...
@@ -32,8 +32,10 @@ class GemmMMAC(GemmBase):
...
@@ -32,8 +32,10 @@ class GemmMMAC(GemmBase):
if
self
.
is_gemm_ss
():
if
self
.
is_gemm_ss
():
return
{
return
{
self
.
A
:
make_swizzled_layout
(
self
.
A
),
# self.A: make_swizzled_layout(self.A, allow_pad=False),
self
.
B
:
make_swizzled_layout
(
self
.
B
),
# self.B: make_swizzled_layout(self.B, allow_pad=False),
self
.
A
:
make_linear_layout
(
self
.
A
),
self
.
B
:
make_linear_layout
(
self
.
B
),
self
.
C
:
mmac_emitter
.
make_mmac_store_layout
(
self
.
C
),
self
.
C
:
mmac_emitter
.
make_mmac_store_layout
(
self
.
C
),
}
}
elif
self
.
is_gemm_sr
():
elif
self
.
is_gemm_sr
():
...
...
tilelang/transform/__init__.py
View file @
dd91b1e0
...
@@ -559,4 +559,8 @@ def LowerSharedGlobalCopy():
...
@@ -559,4 +559,8 @@ def LowerSharedGlobalCopy():
def SimplifyDCUAsyncCopy():
    """Build the SimplifyDCUAsyncCopy TIR pass.

    Thin wrapper that forwards to the C++ FFI registration; the pass
    itself is implemented natively (presumably simplifying DCU async-copy
    intrinsics emitted earlier in the pipeline — confirm in the C++ source).

    Returns
    -------
    The pass object created by ``_ffi_api.SimplifyDCUAsyncCopy``.
    """
    return _ffi_api.SimplifyDCUAsyncCopy()  # type: ignore
\ No newline at end of file
def FixDCUWaitCount():
    """Build the FixDCUWaitCount TIR pass.

    Thin wrapper that forwards to the C++ FFI registration; the pass
    itself is implemented natively (presumably adjusting async-copy wait
    counts for DCU targets — confirm in the C++ source).

    Returns
    -------
    The pass object created by ``_ffi_api.FixDCUWaitCount``.
    """
    return _ffi_api.FixDCUWaitCount()  # type: ignore
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment