Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
dd95e41b
"runtime/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "910751c64bafac204b2820d0b7836eb7f27a6751"
Commit
dd95e41b
authored
May 06, 2026
by
wangziyang
Browse files
add B_local layout transformation with loop optimization
parent
bba13746
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
135 additions
and
168 deletions
+135
-168
src/transform/inject_blocal_layout_transform.cc
src/transform/inject_blocal_layout_transform.cc
+133
-151
tilelang/engine/phase.py
tilelang/engine/phase.py
+2
-3
tilelang/transform/__init__.py
tilelang/transform/__init__.py
+0
-14
No files found.
src/transform/inject_blocal_layout_transform.cc
View file @
dd95e41b
...
...
@@ -41,162 +41,161 @@ namespace tl {
using
namespace
tir
;
/*!
* \brief Check if a statement contains B_local stores
* \brief Transformer that handles B_local layout transformation with loop optimization
*
* This transformer handles two cases:
* 1. B_local store with outer loop: halve the loop extent and double the offset
* 2. B_local store without outer loop: just double the offset
*/
bool
ContainsBLocalStore
(
const
Stmt
&
stmt
)
{
bool
found
=
false
;
tir
::
PreOrderVisit
(
stmt
,
[
&
](
const
ObjectRef
&
node
)
->
bool
{
if
(
found
)
{
return
false
;
class
BLocalLayoutTransformer
:
public
StmtExprMutator
{
public:
explicit
BLocalLayoutTransformer
(
int
expand
)
:
expand_
(
expand
)
{}
private:
int
expand_
;
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
// 只处理 serial 外层循环
if
(
op
->
kind
!=
ForKind
::
kSerial
)
{
return
StmtExprMutator
::
VisitStmt_
(
op
);
}
if
(
const
auto
*
store
=
node
.
as
<
BufferStoreNode
>
())
{
std
::
string
name
=
store
->
buffer
->
name
;
if
(
name
.
find
(
"B_local"
)
!=
std
::
string
::
npos
)
{
found
=
true
;
return
false
;
}
// 判断是否是 B_local 写循环
auto
store
=
op
->
body
.
as
<
BufferStoreNode
>
();
if
(
!
store
)
{
return
StmtExprMutator
::
VisitStmt_
(
op
);
}
return
true
;
});
return
found
;
}
/*!
* \brief Check if this is a B_local store pattern
*
* Pattern to match:
* B_local[index] = B_shared[index_expr]
*
* Where B_shared[index_expr] is a complex expression involving:
* - thread_binding (threadIdx.x, threadIdx.y, etc.)
* - ki (iteration variable)
* - j and local_id (loop variables)
*/
bool
IsBLocalStorePattern
(
const
BufferStoreNode
*
op
,
Var
*
local_var
,
Var
*
shared_var
,
PrimExpr
*
shared_offset
)
{
// Check if store is to a local buffer named B_local
std
::
string
buffer_name
=
op
->
buffer
->
name
;
if
(
buffer_name
.
find
(
"B_local"
)
==
std
::
string
::
npos
)
{
return
false
;
}
if
(
!
IsBLocal
(
store
->
buffer
))
{
return
StmtExprMutator
::
VisitStmt_
(
op
);
}
// Must have exactly one index: B_local[index]
if
(
op
->
indices
.
size
()
!=
1
)
{
return
false
;
}
int64_t
old_extent
=
op
->
extent
.
as
<
IntImmNode
>
()
->
value
;
ICHECK
(
old_extent
%
expand_
==
0
)
<<
"Loop extent must be divisible by expand factor."
;
int64_t
new_extent
=
old_extent
/
expand_
;
// 修改循环范围
For
new_for
=
For
(
op
->
loop_var
,
op
->
min
,
Integer
(
new_extent
),
op
->
kind
,
MutateStore
(
store
,
op
->
loop_var
));
// Check if value is a BufferLoad from shared memory
const
BufferLoadNode
*
load
=
op
->
value
.
as
<
BufferLoadNode
>
();
if
(
load
==
nullptr
)
{
return
false
;
return
new_for
;
}
// Check if load is from shared memory
std
::
string
load_buffer_name
=
load
->
buffer
->
name
;
std
::
cout
<<
"[DEBUG IsBLocalStorePattern] load buffer name: "
<<
load_buffer_name
<<
std
::
endl
;
if
(
load_buffer_name
.
find
(
"B_shared"
)
==
std
::
string
::
npos
)
{
return
false
;
bool
IsBLocal
(
const
Buffer
&
buffer
)
{
std
::
string
name
=
buffer
->
name
;
return
name
.
find
(
"B_local"
)
!=
std
::
string
::
npos
;
}
// Get buffer variables
*
local_var
=
op
->
buffer
->
data
;
*
shared_var
=
load
->
buffer
->
data
;
Stmt
MutateStore
(
const
BufferStoreNode
*
store
,
const
Var
&
loop_var
)
{
// Extract the shared memory offset from the load indices
if
(
!
load
->
indices
.
empty
())
{
*
shared_offset
=
load
->
indices
[
0
];
}
else
{
*
shared_offset
=
make_const
(
DataType
::
Int
(
32
),
0
);
}
Array
<
PrimExpr
>
new_indices
=
store
->
indices
;
return
true
;
}
PrimExpr
new_value
=
store
->
value
;
class
BLocalLayoutTransformer
:
public
StmtExprMutator
{
public:
BLocalLayoutTransformer
(
const
IRModule
&
module
)
:
module_
(
module
)
{}
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
override
{
// Check if this is a B_local store pattern BEFORE visiting
// to get the original buffer->data vars (not mutated by VisitStmt_)
Var
local_var
;
Var
shared_var
;
PrimExpr
shared_offset
;
if
(
!
IsBLocalStorePattern
(
op
,
&
local_var
,
&
shared_var
,
&
shared_offset
))
{
// Only visit if not our target pattern
return
Downcast
<
BufferStore
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
}
std
::
cout
<<
"[DEBUG BLocalLayoutTransformer VisitStmt_] BufferStoreNode buffer name: "
<<
op
->
buffer
->
name
<<
std
::
endl
;
// For ds_read_vector: ds_read_vector(dst, src, m, n, offset)
// m, n describe the 2D layout of the shared memory tile
// For B_local (16x32 tile): m=16, n=32
PrimExpr
m
=
make_const
(
DataType
::
Int
(
32
),
16
);
PrimExpr
n
=
make_const
(
DataType
::
Int
(
32
),
32
);
PrimExpr
offset
=
shared_offset
;
// Create the ds_read call
// ds_read_vector(local_ptr, shared_ptr, m, n, offset)
// Use the vars directly - don't call VisitExpr on them as that creates new Vars
Array
<
PrimExpr
>
ds_read_args
=
{
local_var
,
// dst: local buffer pointer
op
->
buffer
->
data
,
// src: shared memory pointer
m
,
// m: rows in shared memory tile
n
,
// n: columns in shared memory tile
offset
// offset: starting offset in shared memory
};
Call
ds_read_call
=
Call
(
DataType
::
Handle
(),
ds_read_vector
(),
ds_read_args
);
// Replace the BufferStore with the ds_read call
return
Evaluate
(
ds_read_call
);
}
// 修改切片跨度:
// 原来 j*vec : j*vec+vec
// 改为 j*vec : j*vec*expand + vec
private:
const
IRModule
&
module_
;
}
;
PrimExpr
idx
=
store
->
indices
[
0
];
//T.Ramp(j * 4, 1, 4) -> Ramp(j*8, 1, 4)
std
::
cout
<<
idx
<<
std
::
endl
;
/*!
* \brief Inject prefetch for B_local using ds_read_vector
*/
class
BLocalPrefetchInjector
:
public
StmtMutator
{
public:
BLocalPrefetchInjector
(
const
IRModule
&
module
)
:
module_
(
module
)
{}
Stmt
VisitStmt_
(
const
ForNode
*
op
)
override
{
if
(
op
->
kind
==
ForKind
::
kParallel
||
op
->
kind
==
ForKind
::
kSerial
||
op
->
kind
==
ForKind
::
kVectorized
)
{
Stmt
body
=
VisitStmt
(
op
->
body
);
// Check if body contains B_local stores
if
(
ContainsBLocalStore
(
body
))
{
// Inject prefetch before the loop
Stmt
prefetch
=
GenerateBLocalPrefetch
();
return
SeqStmt
({
prefetch
,
For
(
op
->
loop_var
,
op
->
min
,
op
->
extent
,
op
->
kind
,
body
,
op
->
thread_binding
,
op
->
annotations
)});
// 解析 j*vec 结构
// 假设结构为 j * vec + const
// 不改 RHS
// PrimExpr value = store->value;
// 修改写入向量宽度
// 原 value 是 Ramp(base=j*4, stride=1, lanes=4)
// 匹配 j * stride
// Ramp(base=j*8, stride=1, lanes=8)
if
(
const
auto
*
ramp
=
idx
.
as
<
RampNode
>
())
{
PrimExpr
base
=
ramp
->
base
;
PrimExpr
stride
=
ramp
->
stride
;
int
old_lanes
=
ramp
->
lanes
.
as
<
IntImmNode
>
()
->
value
;
int
new_lanes
=
old_lanes
*
expand_
;
// 匹配 base = j * stride_val
if
(
const
auto
*
mul
=
base
.
as
<
MulNode
>
())
{
if
(
mul
->
a
.
same_as
(
loop_var
))
{
int64_t
old_stride
=
mul
->
b
.
as
<
IntImmNode
>
()
->
value
;
int64_t
new_stride
=
old_stride
*
expand_
;
PrimExpr
new_base
=
loop_var
*
make_const
(
DataType
::
Int
(
32
),
new_stride
);
new_indices
.
Set
(
0
,
Ramp
(
new_base
,
stride
,
new_lanes
));
}
else
if
(
mul
->
b
.
same_as
(
loop_var
))
{
int64_t
old_stride
=
mul
->
a
.
as
<
IntImmNode
>
()
->
value
;
int64_t
new_stride
=
old_stride
*
expand_
;
PrimExpr
new_base
=
make_const
(
DataType
::
Int
(
32
),
new_stride
)
*
loop_var
;
new_indices
.
Set
(
0
,
Ramp
(
new_base
,
stride
,
new_lanes
));
}
}
}
return
For
(
op
->
loop_var
,
op
->
min
,
op
->
extent
,
op
->
kind
,
body
,
op
->
thread_binding
,
op
->
annotations
);
if
(
auto
*
load
=
new_value
.
as
<
BufferLoadNode
>
())
{
// BufferLoad with region access: B_shared[start : end]
// end - start = lanes,需要同步扩展
Array
<
PrimExpr
>
value_indices
=
load
->
indices
;
if
(
auto
*
old_ramp
=
load
->
indices
[
0
].
as
<
RampNode
>
())
{
PrimExpr
scalar_base
=
old_ramp
->
base
;
// 必须是 scalar
PrimExpr
stride
=
old_ramp
->
stride
;
//RHS 4 lane
int
old_lanes
=
old_ramp
->
lanes
.
as
<
IntImmNode
>
()
->
value
;
//RHS 8 lane
int
new_lanes
=
old_lanes
*
expand_
;
value_indices
.
Set
(
0
,
Ramp
(
scalar_base
,
stride
,
new_lanes
)
);
new_value
=
BufferLoad
(
load
->
buffer
,
value_indices
);
}
}
return
StmtMutator
::
VisitStmt_
(
op
);
return
BufferStore
(
store
->
buffer
,
new_value
,
new_indices
);
}
};
private:
/*!
 * \brief Build the statement that is placed in front of a loop whose body
 *        stores to B_local.
 * \return Currently a no-op `Evaluate(0)` statement.
 */
Stmt GenerateBLocalPrefetch() {
  // Placeholder: the real prefetch depends on the specific shared memory
  // layout and thread block configuration, so nothing is emitted yet.
  Stmt placeholder = Evaluate(0);
  return placeholder;
}
const
IRModule
&
module_
;
};
/*!
 * \brief Run the B_local layout transformer over a statement.
 *
 * Constructs a BLocalLayoutTransformer with the given expand factor and
 * applies it to \p stmt.
 *
 * \param stmt The statement to rewrite; it is moved into the mutator.
 * \param expand The expand factor forwarded to the transformer. The
 *        transformer divides loop extents by this value, so it must be
 *        strictly positive.
 * \return The statement produced by the transformer.
 */
Stmt InjectBLocalLayoutTransformPass(Stmt stmt, int expand) {
  // The transformer evaluates `old_extent % expand_`; a zero factor would be
  // an integer modulo by zero and a negative one would yield a negative
  // extent, so reject both here with a clear message.
  ICHECK(expand > 0) << "Expand factor must be positive, got " << expand;
  return BLocalLayoutTransformer(expand)(std::move(stmt));
}
using
namespace
tir
::
transform
;
...
...
@@ -209,33 +208,16 @@ tvm::transform::Pass InjectBLocalLayoutTransform() {
}
auto
*
n
=
f
.
CopyOnWrite
();
n
->
body
=
BLocalLayoutTransform
er
(
m
)
(
n
->
body
);
n
->
body
=
Inject
BLocalLayoutTransform
Pass
(
n
->
body
,
2
);
return
f
;
};
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.InjectBLocalLayoutTransform"
,
{});
}
/*!
 * \brief Create the pass that applies BLocalPrefetchInjector followed by
 *        BLocalLayoutTransformer to a PrimFunc body.
 *
 * Functions in modules for which IsDCUTarget returns false are returned
 * unchanged.
 *
 * \return A PrimFunc pass named "tl.InjectBLocalLayoutTransformWithPrefetch".
 */
tvm::transform::Pass InjectBLocalLayoutTransformWithPrefetch() {
  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
    // Only apply to DCU targets; everything else passes through untouched.
    if (IsDCUTarget(m)) {
      auto *func_node = f.CopyOnWrite();
      // Prefetch injection runs first, then the layout transformation.
      func_node->body = BLocalPrefetchInjector(m)(func_node->body);
      func_node->body = BLocalLayoutTransformer(m)(func_node->body);
    }
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0,
                            "tl.InjectBLocalLayoutTransformWithPrefetch", {});
}
// Register both pass constructors under the "tl.transform." prefix in the
// FFI global function table.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("tl.transform.InjectBLocalLayoutTransform",
           InjectBLocalLayoutTransform)
      .def("tl.transform.InjectBLocalLayoutTransformWithPrefetch",
           InjectBLocalLayoutTransformWithPrefetch);
}
}
// namespace tl
...
...
tilelang/engine/phase.py
View file @
dd95e41b
...
...
@@ -235,8 +235,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod
=
tilelang
.
transform
.
ConfigIndexBitwidth
()(
mod
)
mod
=
tir
.
transform
.
Simplify
()(
mod
)
mod
=
tilelang
.
transform
.
VectorizeLoop
(
enable_vectorize
=
allow_vectorize
(
pass_ctx
=
pass_ctx
))(
mod
)
# Transform B_local layout from shared memory thread-interleaved to local row-major
# mod = tilelang.transform.InjectBLocalLayoutTransform()(mod)
mod
=
tilelang
.
transform
.
StorageRewrite
()(
mod
)
mod
=
tir
.
transform
.
UnrollLoop
()(
mod
)
mod
=
tir
.
transform
.
RenormalizeSplitPattern
()(
mod
)
...
...
@@ -295,9 +293,10 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# Transform threadblock to persistent threadblock
mod
=
tilelang
.
transform
.
PersistThreadblock
()(
mod
)
if
dcu_async_copy_supported
(
target
):
print
(
"--------------support dcu async copy------------------"
)
mod
=
tilelang
.
transform
.
LowerSharedGlobalCopy
()(
mod
)
mod
=
tilelang
.
transform
.
FixDCUWaitCount
()(
mod
)
#
mod
=
tilelang
.
transform
.
InjectBLocalLayoutTransform
()(
mod
)
# mod = tilelang.transform.SimplifyDCUAsyncCopy()(mod)
print
(
"OptimizeForTarget3"
)
print
(
mod
)
...
...
tilelang/transform/__init__.py
View file @
dd95e41b
...
...
@@ -386,20 +386,6 @@ def InjectBLocalLayoutTransform():
return
_ffi_api
.
InjectBLocalLayoutTransform
()
# type: ignore
def InjectBLocalLayoutTransformWithPrefetch():
    """Create the B_local layout transform pass that also injects prefetch.

    Thin wrapper around ``_ffi_api.InjectBLocalLayoutTransformWithPrefetch``:
    like InjectBLocalLayoutTransform, but prefetch operations for B_local are
    injected before the main transformation.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.InjectBLocalLayoutTransformWithPrefetch()  # type: ignore
    return fpass
def
LowerDeviceStorageAccessInfo
():
"""Lower attached storage access information on device.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment