Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
74e57416
Commit
74e57416
authored
Apr 09, 2026
by
wangziyang
Browse files
add inject_blocal_layout
parent
15599a93
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
380 additions
and
40 deletions
+380
-40
src/transform/inject_blocal_layout_transform.cc
src/transform/inject_blocal_layout_transform.cc
+242
-0
src/transform/inject_ds_read.cc
src/transform/inject_ds_read.cc
+11
-35
src/transform/inject_utils.cc
src/transform/inject_utils.cc
+38
-0
src/transform/inject_utils.h
src/transform/inject_utils.h
+42
-0
tilelang/engine/phase.py
tilelang/engine/phase.py
+6
-3
tilelang/intrinsics/mfma_macro_generator.py
tilelang/intrinsics/mfma_macro_generator.py
+1
-2
tilelang/transform/__init__.py
tilelang/transform/__init__.py
+40
-0
No files found.
src/transform/inject_blocal_layout_transform.cc
0 → 100644
View file @
74e57416
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file inject_blocal_layout_transform.cc
* \brief Transform B_local layout from shared memory thread-interleaved layout
* to local row-major layout using ds_read_vector hardware instructions.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
#include "tir/ir/buffer_common.h"
#include "tvm/tir/stmt.h"
#include "inject_utils.h"
namespace
tvm
{
namespace
tl
{
using
namespace
tir
;
/*!
 * \brief Check whether a statement subtree contains at least one store
 *        into a buffer whose name contains "B_local".
 * \param stmt The statement to scan.
 * \return true if such a BufferStore exists anywhere under \p stmt.
 */
bool ContainsBLocalStore(const Stmt &stmt) {
  bool has_blocal_store = false;
  tir::PreOrderVisit(stmt, [&has_blocal_store](const ObjectRef &node) -> bool {
    // Once a match is found, prune the rest of the traversal.
    if (has_blocal_store) {
      return false;
    }
    const auto *store = node.as<BufferStoreNode>();
    if (store != nullptr) {
      std::string buf_name = store->buffer->name;
      if (buf_name.find("B_local") != std::string::npos) {
        has_blocal_store = true;
        return false;
      }
    }
    return true;
  });
  return has_blocal_store;
}
/*!
 * \brief Check if a BufferStore matches the B_local copy pattern.
 *
 * Pattern to match:
 *   B_local[index] = B_shared[index_expr]
 *
 * Where B_shared[index_expr] is a complex expression involving:
 *   - thread_binding (threadIdx.x, threadIdx.y, etc.)
 *   - ki (iteration variable)
 *   - j and local_id (loop variables)
 *
 * \param op The store to inspect.
 * \param local_var Out: data var of the local (destination) buffer.
 * \param shared_var Out: data var of the shared (source) buffer.
 * \param shared_offset Out: first index of the load, or 0 if none.
 * \return true and fills the out-params when the pattern matches.
 */
bool IsBLocalStorePattern(const BufferStoreNode *op, Var *local_var,
                          Var *shared_var, PrimExpr *shared_offset) {
  // Check if store is to a local buffer named B_local (substring match).
  std::string buffer_name = op->buffer->name;
  if (buffer_name.find("B_local") == std::string::npos) {
    return false;
  }
  // Must have exactly one index: B_local[index]
  if (op->indices.size() != 1) {
    return false;
  }
  // The stored value must be a direct BufferLoad.
  const BufferLoadNode *load = op->value.as<BufferLoadNode>();
  if (load == nullptr) {
    return false;
  }
  // The load must come from a B_shared buffer (substring match).
  // NOTE: removed the leftover debug std::cout — a compiler pass must not
  // write to stdout on every visited store (and <iostream> was never
  // included by this file).
  std::string load_buffer_name = load->buffer->name;
  if (load_buffer_name.find("B_shared") == std::string::npos) {
    return false;
  }
  // Report the underlying data vars of both buffers.
  *local_var = op->buffer->data;
  *shared_var = load->buffer->data;
  // Extract the shared memory offset from the load indices; default to 0
  // when the load has no indices.
  if (!load->indices.empty()) {
    *shared_offset = load->indices[0];
  } else {
    *shared_offset = make_const(DataType::Int(32), 0);
  }
  return true;
}
/*!
 * \brief Rewrites `B_local[...] = B_shared[...]` stores into
 *        ds_read_vector intrinsic calls.
 */
class BLocalLayoutTransformer : public StmtExprMutator {
 public:
  explicit BLocalLayoutTransformer(const IRModule &module) : module_(module) {}

  Stmt VisitStmt_(const BufferStoreNode *op) override {
    // Check the pattern BEFORE visiting children so we capture the
    // original buffer->data vars (not ones remapped by mutation).
    Var local_var;
    Var shared_var;
    PrimExpr shared_offset;
    if (!IsBLocalStorePattern(op, &local_var, &shared_var, &shared_offset)) {
      // Not our target pattern — fall back to the default mutation.
      return Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    }
    // For ds_read_vector: ds_read_vector(dst, src, m, n, offset)
    // m, n describe the 2D layout of the shared memory tile.
    // NOTE(review): the 16x32 tile shape is hard-coded; confirm it matches
    // the actual B_local tile configuration before enabling more shapes.
    PrimExpr m = make_const(DataType::Int(32), 16);
    PrimExpr n = make_const(DataType::Int(32), 32);
    // Use the captured vars directly — do not re-visit them, as that could
    // create fresh Vars that no longer alias the buffers.
    // BUG FIX: the original passed `op->buffer->data` (the destination,
    // identical to local_var) as the src argument and never used
    // shared_var; the source must be the shared-memory buffer.
    Array<PrimExpr> ds_read_args = {
        local_var,     // dst: local buffer pointer
        shared_var,    // src: shared memory pointer
        m,             // m: rows in shared memory tile
        n,             // n: columns in shared memory tile
        shared_offset  // offset: starting offset in shared memory
    };
    Call ds_read_call = Call(DataType::Handle(), ds_read_vector(), ds_read_args);
    // Replace the BufferStore with the ds_read call.
    return Evaluate(ds_read_call);
  }

 private:
  const IRModule &module_;
};
/*!
 * \brief Inject prefetch for B_local using ds_read_vector
 */
class BLocalPrefetchInjector : public StmtMutator {
 public:
  BLocalPrefetchInjector(const IRModule &module) : module_(module) {}

  Stmt VisitStmt_(const ForNode *op) override {
    const bool handled_kind = op->kind == ForKind::kParallel ||
                              op->kind == ForKind::kSerial ||
                              op->kind == ForKind::kVectorized;
    if (!handled_kind) {
      return StmtMutator::VisitStmt_(op);
    }
    Stmt new_body = VisitStmt(op->body);
    For rebuilt(op->loop_var, op->min, op->extent, op->kind, new_body,
                op->thread_binding, op->annotations);
    // Only loops that actually copy into B_local get a prefetch injected
    // immediately before them.
    if (!ContainsBLocalStore(new_body)) {
      return rebuilt;
    }
    return SeqStmt({GenerateBLocalPrefetch(), rebuilt});
  }

 private:
  // Placeholder: the actual implementation depends on the specific
  // shared memory layout and thread block configuration.
  Stmt GenerateBLocalPrefetch() { return Evaluate(0); }

  const IRModule &module_;
};
using
namespace
tir
::
transform
;
/*!
 * \brief Create the pass that rewrites B_local copy stores into
 *        ds_read_vector calls on AMD DCU targets.
 * \return The PrimFunc pass "tl.InjectBLocalLayoutTransform".
 */
tvm::transform::Pass InjectBLocalLayoutTransform() {
  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
    // Only apply to DCU targets; other targets pass through unchanged.
    // NOTE: removed the leftover debug std::cout — a pass must stay quiet
    // on the skip path (and <iostream> was never included by this file).
    if (!IsDCUTarget(m)) {
      return f;
    }
    auto *n = f.CopyOnWrite();
    n->body = BLocalLayoutTransformer(m)(n->body);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.InjectBLocalLayoutTransform", {});
}
/*!
 * \brief Create the pass that injects B_local prefetches and then applies
 *        the B_local layout rewrite on AMD DCU targets.
 * \return The PrimFunc pass "tl.InjectBLocalLayoutTransformWithPrefetch".
 */
tvm::transform::Pass InjectBLocalLayoutTransformWithPrefetch() {
  auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
    // Only apply to DCU targets
    if (!IsDCUTarget(m)) {
      return f;
    }
    auto *func_node = f.CopyOnWrite();
    // Prefetch injection runs first, then the layout rewrite.
    func_node->body = BLocalPrefetchInjector(m)(func_node->body);
    func_node->body = BLocalLayoutTransformer(m)(func_node->body);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0,
                            "tl.InjectBLocalLayoutTransformWithPrefetch", {});
}
// Register both passes with the TVM FFI so they are reachable from Python
// (e.g. tilelang.transform.InjectBLocalLayoutTransform in
// tilelang/transform/__init__.py).
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.InjectBLocalLayoutTransform",
                        InjectBLocalLayoutTransform);
  refl::GlobalDef().def("tl.transform.InjectBLocalLayoutTransformWithPrefetch",
                        InjectBLocalLayoutTransformWithPrefetch);
}
}
// namespace tl
}
// namespace tvm
src/transform/inject_ds_read.cc
View file @
74e57416
...
...
@@ -33,31 +33,13 @@
#include "../op/builtin.h"
#include "tir/ir/buffer_common.h"
#include "tvm/tir/stmt.h"
#include "inject_utils.h"
namespace
tvm
{
namespace
tl
{
using
namespace
tir
;
/*!
 * \brief Check if the target is AMD DCU (gfx936, gfx942, etc.)
 * \param module The module whose PrimFuncs are inspected for a "target"
 *        attribute carrying an "mcpu" entry.
 * \return true when the first function with an mcpu reports a gfx936 chip.
 */
bool IsDCUTarget(const IRModule &module) {
  for (const auto &kv : module->functions) {
    const auto *prim_func = kv.second.as<PrimFuncNode>();
    if (prim_func == nullptr) {
      continue;
    }
    auto opt_target = prim_func->GetAttr<Target>("target");
    if (!opt_target) {
      continue;
    }
    Target target = opt_target.value();
    if (!target->attrs.count("mcpu")) {
      continue;
    }
    std::string mcpu = Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
    // If mcpu starts with "gfx936", it is a DCU; decide on the first
    // function that carries an mcpu attribute.
    return mcpu.find("gfx936") == 0;
  }
  return false;
}
class
DSReadInjector
:
public
StmtExprMutator
{
public:
/*!
...
...
@@ -87,7 +69,7 @@ class DSReadInjector : public StmtExprMutator {
/*!
* \brief Visit BufferStoreNode to inject ds_read_vector call
* Pattern: local_buffer[...] = shared_buffer[...] (BufferLoad)
* Parameters m, n, offset are passed via a
preceding CallNode
* Parameters m, n, offset are passed via a
CallNode (tl.ds_read_config)
*/
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
override
{
std
::
cout
<<
"[DEBUG VisitStmt_] Visiting BufferStoreNode"
<<
std
::
endl
;
...
...
@@ -118,24 +100,18 @@ class DSReadInjector : public StmtExprMutator {
return
StmtExprMutator
::
VisitStmt_
(
op
);
}
// Found pattern: local = BufferLoad(shared)
// The m, n, offset parameters should come from a CallNode in the IR
// For now, use default values that will be replaced when CallNode is processed
std
::
cout
<<
"[DEBUG BufferStore] Injecting ds_read_vector call!"
<<
std
::
endl
;
// Get parameters from the Store's indices or use default values
// In a full implementation, these would come from a preceding CallNode
PrimExpr
m
=
make_const
(
DataType
::
Int
(
32
),
16
);
// For A_shared, use the actual shared memory base pointer
PrimExpr
m
=
make_const
(
DataType
::
Int
(
32
),
32
);
PrimExpr
n
=
make_const
(
DataType
::
Int
(
32
),
16
);
PrimExpr
offset
=
make_const
(
DataType
::
Int
(
32
),
0
);
//
Visit all arguments to transform any nested expressions
//
Use buffer data vars directly
Array
<
PrimExpr
>
new_args
=
{
VisitExpr
(
load
->
buffer
.
access_ptr
(
0
,
DataType
::
Handle
(),
1
,
0
)),
// src
VisitExpr
(
op
->
buffer
->
data
),
// dst
VisitExpr
(
m
)
,
VisitExpr
(
n
)
,
VisitExpr
(
offset
)
load
->
buffer
->
data
,
// src
op
->
buffer
->
data
,
// dst
m
,
n
,
offset
};
// Create the ds_read call
...
...
src/transform/inject_utils.cc
0 → 100644
View file @
74e57416
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file inject_utils.cc
* \brief Common utilities for injection transforms.
*/
#include "inject_utils.h"
#include "../target/utils.h"
#include <iostream>
namespace
tvm
{
namespace
tl
{
using
namespace
tir
;
/*!
 * \brief Check if the target is AMD DCU (gfx936, gfx942, etc.)
 * \param module Unused; kept so the signature matches existing callers.
 *        The decision is made from the ambient Target context instead.
 * \return true when the current Target is a DCU.
 */
bool IsDCUTarget(const IRModule &module) {
  (void)module;  // silence unused-parameter warnings; see brief above
  return TargetIsDCU(Target::Current(false));
}
}
// namespace tl
}
// namespace tvm
src/transform/inject_utils.h
0 → 100644
View file @
74e57416
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file inject_utils.h
* \brief Common utilities for injection transforms.
*/
#ifndef TVM_TL_INJECT_UTILS_H_
#define TVM_TL_INJECT_UTILS_H_
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
namespace
tvm
{
namespace
tl
{
/*!
 * \brief Check if the target is AMD DCU (gfx936, gfx942, etc.)
 * \param module The module being transformed. NOTE(review): the current
 *        implementation in inject_utils.cc does not consult this argument
 *        and reads the ambient Target context instead — confirm callers
 *        do not rely on per-module targets.
 * \return true when compiling for a DCU target.
 */
bool IsDCUTarget(const IRModule &module);
}
// namespace tl
}
// namespace tvm
#endif // TVM_TL_INJECT_UTILS_H_
tilelang/engine/phase.py
View file @
74e57416
...
...
@@ -181,9 +181,7 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
def
OptimizeForTarget
(
mod
:
IRModule
,
target
:
Target
)
->
IRModule
:
# print("********************")
# print(mod)
# print("********************")
pass_ctx
=
tilelang
.
transform
.
get_pass_context
()
# Lower the barrier.arrive into specific initialization slot
mod
=
tilelang
.
transform
.
LowerSharedBarrier
()(
mod
)
...
...
@@ -229,6 +227,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod
=
tilelang
.
transform
.
ConfigIndexBitwidth
()(
mod
)
mod
=
tir
.
transform
.
Simplify
()(
mod
)
mod
=
tilelang
.
transform
.
VectorizeLoop
(
enable_vectorize
=
allow_vectorize
(
pass_ctx
=
pass_ctx
))(
mod
)
# Transform B_local layout from shared memory thread-interleaved to local row-major
mod
=
tilelang
.
transform
.
InjectBLocalLayoutTransform
()(
mod
)
mod
=
tilelang
.
transform
.
StorageRewrite
()(
mod
)
mod
=
tir
.
transform
.
UnrollLoop
()(
mod
)
mod
=
tir
.
transform
.
RenormalizeSplitPattern
()(
mod
)
...
...
@@ -265,11 +265,14 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod
=
tilelang
.
transform
.
MergeSharedMemoryAllocations
(
enable_aggressive_merge
=
enable_aggressive_merge
)(
mod
)
mod
=
tilelang
.
transform
.
ThreadSync
(
"shared"
)(
mod
)
mod
=
tilelang
.
transform
.
ThreadSync
(
"shared.dyn"
)(
mod
)
# Inject PTX async copy must behind the thread sync pass
# as ptx async copy won't be recognized as a valid buffer load
mod
=
tilelang
.
transform
.
InjectPTXAsyncCopy
()(
mod
)
# Inject ds_read for shared to register memory copy on DCU
mod
=
tilelang
.
transform
.
InjectDSRead
()(
mod
)
print
(
mod
)
if
allow_tma_and_warp_specialized
(
pass_ctx
=
pass_ctx
,
target
=
target
):
mod
=
tilelang
.
transform
.
AnnotateWarpGroupRegAlloc
()(
mod
)
mod
=
tilelang
.
transform
.
MakePackedAPI
()(
mod
)
...
...
tilelang/intrinsics/mfma_macro_generator.py
View file @
74e57416
...
...
@@ -115,8 +115,7 @@ class MatrixCoreIntrinEmitter:
if
a_dtype
.
bits
==
32
:
self
.
k_dim
=
4
elif
a_dtype
.
bits
in
{
16
,
8
}:
# self.k_dim = 16
self
.
k_dim
=
256
self
.
k_dim
=
16
else
:
raise
ValueError
(
f
"Unsupported a_dtype =
{
a_dtype
}
"
)
...
...
tilelang/transform/__init__.py
View file @
74e57416
...
...
@@ -360,6 +360,46 @@ def InjectDSRead():
return
_ffi_api
.
InjectDSRead
()
# type: ignore
def InjectBLocalLayoutTransform():
    """Transform B_local layout from shared memory thread-interleaved to local row-major.

    This pass specifically handles the B_local buffer layout transformation in GEMM kernels
    for AMD DCU (gfx936, gfx942, etc.). It converts complex indexed BufferStore patterns
    from shared memory into vectorized ds_read_vector hardware instructions.

    B Layout Transformation:
    - Shared Memory Layout (per thread in warp, 16 elements):
        Thread 0: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15]
        Thread 1: [16,17,18,... ,31 ]
        ...
    - Local Register Layout (16x32, row-major):
        Row 0: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15, 0, 1,...]
        Row 1: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15, 0, 1,...]
        ...

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    # Removed the unconditional print(): library code must not write to
    # stdout every time the pass object is constructed.
    return _ffi_api.InjectBLocalLayoutTransform()  # type: ignore
def InjectBLocalLayoutTransform():
    """Transform B_local layout with prefetch injection.

    This pass is similar to InjectBLocalLayoutTransform but also injects
    prefetch operations for B_local before the main transformation.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.InjectBLocalLayoutTransformWithPrefetch()  # type: ignore
def
LowerDeviceStorageAccessInfo
():
"""Lower attached storage access information on device.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment