Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
15599a93
Commit
15599a93
authored
Apr 03, 2026
by
wangziyang
Browse files
print MatrixCore init local size
parent
3852d58b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
83 additions
and
107 deletions
+83
-107
src/transform/inject_ds_read.cc
src/transform/inject_ds_read.cc
+73
-106
tilelang/engine/phase.py
tilelang/engine/phase.py
+6
-0
tilelang/intrinsics/mfma_layout.py
tilelang/intrinsics/mfma_layout.py
+2
-0
tilelang/intrinsics/mfma_macro_generator.py
tilelang/intrinsics/mfma_macro_generator.py
+2
-1
No files found.
src/transform/inject_ds_read.cc
View file @
15599a93
...
...
@@ -21,6 +21,7 @@
* \brief Replace shared memory BufferLoad with ds_read hardware instructions
* \file inject_ds_read.cc
*/
#include <iostream>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
...
...
@@ -57,138 +58,104 @@ bool IsDCUTarget(const IRModule& module) {
return
false
;
}
class
DSReadInjector
:
public
StmtMutator
{
class
DSReadInjector
:
public
Stmt
Expr
Mutator
{
public:
Stmt
VisitStmt_
(
const
BufferStoreNode
*
store
)
final
{
/*!
* \brief Visit EvaluateNode to handle explicit ds_read_vector call
* ds_read_vector Call is wrapped in Evaluate to become a statement
* Parameters m, n, offset are passed explicitly via CallNode args
*/
Stmt
VisitStmt_
(
const
EvaluateNode
*
op
)
override
{
std
::
cout
<<
"[DEBUG VisitStmt_] Visiting EvaluateNode"
<<
std
::
endl
;
const
CallNode
*
call
=
op
->
value
.
as
<
CallNode
>
();
std
::
cout
<<
"[DEBUG VisitStmt_] CallNode ptr: "
<<
call
<<
std
::
endl
;
if
(
call
!=
nullptr
&&
call
->
op
.
same_as
(
ds_read_vector
()))
{
ICHECK
(
call
->
args
.
size
()
==
5
)
<<
"ds_read_vector expects 5 arguments: (dst, src, m, n, offset)"
;
// Print args for debugging - these are the actual CallNode args passed in
std
::
cout
<<
"[DEBUG ds_read_vector] args[0] (dst): "
<<
call
->
args
[
0
]
<<
std
::
endl
;
std
::
cout
<<
"[DEBUG ds_read_vector] args[1] (src): "
<<
call
->
args
[
1
]
<<
std
::
endl
;
std
::
cout
<<
"[DEBUG ds_read_vector] args[2] (m): "
<<
call
->
args
[
2
]
<<
std
::
endl
;
std
::
cout
<<
"[DEBUG ds_read_vector] args[3] (n): "
<<
call
->
args
[
3
]
<<
std
::
endl
;
std
::
cout
<<
"[DEBUG ds_read_vector] args[4] (offset): "
<<
call
->
args
[
4
]
<<
std
::
endl
;
}
// Continue with default traversal (don't replace the existing call)
return
StmtExprMutator
::
VisitStmt_
(
op
);
}
/*!
* \brief Visit BufferStoreNode to inject ds_read_vector call
* Pattern: local_buffer[...] = shared_buffer[...] (BufferLoad)
* Parameters m, n, offset are passed via a preceding CallNode
*/
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
override
{
std
::
cout
<<
"[DEBUG VisitStmt_] Visiting BufferStoreNode"
<<
std
::
endl
;
// Check if the store is to a local register (not shared memory)
bool
is_local
=
store
->
buffer
.
scope
()
==
"local"
||
store
->
buffer
.
scope
()
==
"local.fragment"
;
bool
is_local
=
op
->
buffer
.
scope
()
==
"local"
||
op
->
buffer
.
scope
()
==
"local.fragment"
;
std
::
cout
<<
"[DEBUG BufferStore] is_local: "
<<
is_local
<<
", scope: "
<<
op
->
buffer
.
scope
()
<<
std
::
endl
;
if
(
!
is_local
)
{
return
StmtMutator
::
VisitStmt_
(
store
);
return
Stmt
Expr
Mutator
::
VisitStmt_
(
op
);
}
// Check if the value is a BufferLoad from shared memory
if
(
auto
*
load
=
store
->
value
.
as
<
BufferLoadNode
>
())
{
bool
is_shared_load
=
load
->
buffer
.
scope
()
==
"shared"
||
load
->
buffer
.
scope
()
==
"shared.dyn"
;
if
(
!
is_shared_load
)
{
return
StmtMutator
::
VisitStmt_
(
store
);
}
// Skip if indices are vectorized (contain Ramp expressions)
// ds_read is a scalar instruction, cannot handle vectorized indices
if
(
HasVectorizedIndices
(
store
->
indices
)
||
HasVectorizedIndices
(
load
->
indices
))
{
return
StmtMutator
::
VisitStmt_
(
store
);
}
// Check if the buffer is large enough for ds_read_vector
// ds_read_vector<32, 16> with half_t reads 16 bytes (8 elements)
// For small buffers (less than 16 bytes), skip this transformation
if
(
store
->
buffer
.
defined
())
{
const
auto
&
buffer_shape
=
store
->
buffer
->
shape
;
if
(
buffer_shape
.
size
()
==
1
)
{
if
(
auto
*
int_shape
=
buffer_shape
[
0
].
as
<
IntImmNode
>
())
{
int
extent
=
int_shape
->
value
;
int
dtype_bytes
=
load
->
dtype
.
bytes
();
// ds_read_vector<32,16> with half_t reads 16 bytes minimum
// For buffers smaller than what ds_read_vector needs, skip
if
(
extent
*
dtype_bytes
<
16
)
{
return
StmtMutator
::
VisitStmt_
(
store
);
}
}
}
}
// Analyze the load pattern to determine which ds_read to use
return
InjectDSRead
(
store
,
load
);
const
BufferLoadNode
*
load
=
op
->
value
.
as
<
BufferLoadNode
>
();
if
(
load
==
nullptr
)
{
std
::
cout
<<
"[DEBUG BufferStore] value is not BufferLoad"
<<
std
::
endl
;
return
StmtExprMutator
::
VisitStmt_
(
op
);
}
return
StmtMutator
::
VisitStmt_
(
store
);
}
private:
bool
is_shared_load
=
load
->
buffer
.
scope
()
==
"shared"
||
load
->
buffer
.
scope
()
==
"shared.dyn"
;
std
::
cout
<<
"[DEBUG BufferStore] is_shared_load: "
<<
is_shared_load
<<
", load scope: "
<<
load
->
buffer
.
scope
()
<<
std
::
endl
;
// PrimExpr VisitExpr_(const CallNode *op) {
// Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
// if (call->op.same_as(builtin::tvm_access_ptr())) {
// return RewriteBufferAccess(call, {1});
// }
// return call;
// }
/*!
* \brief Check if any index expression contains a Ramp (vectorized) expression
*/
bool
HasVectorizedIndices
(
const
Array
<
PrimExpr
>&
indices
)
{
for
(
const
auto
&
idx
:
indices
)
{
if
(
idx
.
as
<
RampNode
>
())
{
return
true
;
}
}
return
false
;
}
Stmt
InjectDSRead
(
const
BufferStoreNode
*
store
,
const
BufferLoadNode
*
load
)
{
const
Buffer
&
shared_buf
=
load
->
buffer
;
const
Buffer
&
local_buf
=
store
->
buffer
;
// Analyze indices to determine the byte offset
// PrimExpr offset = load->indices.size() > 0 ? load->indices[0] : make_zero(DataType::UInt(0));
// Calculate buffer size in bytes
int
buffer_bytes
=
0
;
if
(
local_buf
.
defined
()
&&
local_buf
->
shape
.
size
()
==
1
)
{
if
(
auto
*
int_shape
=
local_buf
->
shape
[
0
].
as
<
IntImmNode
>
())
{
int
num_elements
=
int_shape
->
value
;
int
dtype_bytes
=
local_buf
->
dtype
.
bytes
();
buffer_bytes
=
num_elements
*
dtype_bytes
;
}
if
(
!
is_shared_load
)
{
return
StmtExprMutator
::
VisitStmt_
(
op
);
}
// Determine which ds_read to use based on buffer size
// ds_read_b64 loads 8 bytes (64 bits) = 1 element for half_t, 2 for float32
// ds_read_m32x16_b16 loads 32 bytes (256 bits)
int
dtype_bits
=
local_buf
->
dtype
.
bits
();
int
m
=
16
;
// For buffer < 16 bytes, use single ds_read_b64 (M=32, N=1)
// For buffer >= 16 bytes, use double ds_read_b64 (M=32, N=16)
// ds_read_b64 reads 8 bytes per call
int
n
=
(
buffer_bytes
>=
32
)
?
32
:
16
;
int
offset
=
0
;
return
EmitDSRead
(
local_buf
,
shared_buf
,
m
,
n
,
offset
);
}
Stmt
EmitDSRead
(
const
Buffer
&
local_buf
,
const
Buffer
&
shared_buf
,
int
m
,
int
n
,
int
offset
)
{
// ds_read_vector takes: (dst, shared_ptr, m, n, offset)
Array
<
PrimExpr
>
args
=
{
local_buf
->
data
,
// dst: local buffer data pointer
shared_buf
.
access_ptr
(
0
,
DataType
::
Handle
(),
1
,
0
),
// src: shared buffer data pointer
make_const
(
DataType
::
Int
(
32
),
m
),
make_const
(
DataType
::
Int
(
32
),
n
),
make_const
(
DataType
::
Int
(
32
),
offset
)
// byte_offset: offset into shared memory
// Found pattern: local = BufferLoad(shared)
// The m, n, offset parameters should come from a CallNode in the IR
// For now, use default values that will be replaced when CallNode is processed
std
::
cout
<<
"[DEBUG BufferStore] Injecting ds_read_vector call!"
<<
std
::
endl
;
// Get parameters from the Store's indices or use default values
// In a full implementation, these would come from a preceding CallNode
PrimExpr
m
=
make_const
(
DataType
::
Int
(
32
),
16
);
PrimExpr
n
=
make_const
(
DataType
::
Int
(
32
),
16
);
PrimExpr
offset
=
make_const
(
DataType
::
Int
(
32
),
0
);
// Visit all arguments to transform any nested expressions
Array
<
PrimExpr
>
new_args
=
{
VisitExpr
(
load
->
buffer
.
access_ptr
(
0
,
DataType
::
Handle
(),
1
,
0
)),
// src
VisitExpr
(
op
->
buffer
->
data
),
// dst
VisitExpr
(
m
),
VisitExpr
(
n
),
VisitExpr
(
offset
)
};
Stmt
ds_read_stmt
=
Evaluate
(
Call
(
DataType
::
Handle
(),
ds_read_vector
(),
args
));
return
ds_read_stmt
;
// Create the ds_read call
Call
ds_read_call
=
Call
(
DataType
::
Handle
(),
ds_read_vector
(),
new_args
);
return
Evaluate
(
ds_read_call
);
}
};
using
namespace
tir
::
transform
;
tvm
::
transform
::
Pass
InjectDSRead
()
{
auto
pass_func
=
[
=
](
PrimFunc
f
,
const
IRModule
&
m
,
const
PassContext
&
ctx
)
{
std
::
cout
<<
"[DEBUG InjectDSRead] Pass is being executed"
<<
std
::
endl
;
// Only apply to DCU targets
if
(
!
IsDCUTarget
(
m
))
{
std
::
cout
<<
"[DEBUG InjectDSRead] Not a DCU target, skipping"
<<
std
::
endl
;
return
f
;
}
std
::
cout
<<
"[DEBUG InjectDSRead] Is DCU target, applying injector"
<<
std
::
endl
;
auto
*
n
=
f
.
CopyOnWrite
();
n
->
body
=
DSReadInjector
()(
n
->
body
);
return
f
;
...
...
tilelang/engine/phase.py
View file @
15599a93
...
...
@@ -181,6 +181,9 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
def
OptimizeForTarget
(
mod
:
IRModule
,
target
:
Target
)
->
IRModule
:
# print("********************")
# print(mod)
# print("********************")
pass_ctx
=
tilelang
.
transform
.
get_pass_context
()
# Lower the barrier.arrive into specific initialization slot
mod
=
tilelang
.
transform
.
LowerSharedBarrier
()(
mod
)
...
...
@@ -204,6 +207,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod
=
tilelang
.
transform
.
RewriteWgmmaSync
()(
mod
)
mod
=
tilelang
.
transform
.
InjectFenceProxy
()(
mod
)
else
:
mod
=
tilelang
.
transform
.
IfStmtBinding
()(
mod
)
mod
=
tilelang
.
transform
.
PlanAndUpdateBufferAllocationLocation
()(
mod
)
mod
=
tilelang
.
transform
.
PipelinePlanning
()(
mod
)
...
...
@@ -214,12 +218,14 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# so we need to inject a fence proxy before it
mod
=
tilelang
.
transform
.
InjectFenceProxy
()(
mod
)
mod
=
tilelang
.
transform
.
LowerOpaqueBlock
()(
mod
)
mod
=
tilelang
.
transform
.
Simplify
()(
mod
)
mod
=
tir
.
transform
.
NarrowDataType
(
32
)(
mod
)
mod
=
tilelang
.
transform
.
FlattenBuffer
()(
mod
)
# ConfigIndexBitwidth must be applied after FlattenBuffer
# as it will flatten index computing
mod
=
tilelang
.
transform
.
ConfigIndexBitwidth
()(
mod
)
mod
=
tir
.
transform
.
Simplify
()(
mod
)
mod
=
tilelang
.
transform
.
VectorizeLoop
(
enable_vectorize
=
allow_vectorize
(
pass_ctx
=
pass_ctx
))(
mod
)
...
...
tilelang/intrinsics/mfma_layout.py
View file @
15599a93
...
...
@@ -38,6 +38,7 @@ def shared_16x16_to_ldmatrix_64x4_layout(ind):
def
thread_id_shared_access_64x4_to_16x16_layout_A
(
thread_id
,
local_id
):
print
(
"mfma_layout thread_id_shared_access_64x4_to_16x16_layout_A:"
,
thread_id
,
local_id
)
i
=
thread_id
%
16
j
=
(
thread_id
//
16
)
*
4
+
local_id
return
i
,
j
...
...
@@ -50,6 +51,7 @@ def shared_16x16_to_local_64x4_layout_A(i, j):
def
thread_id_shared_access_64x4_to_16x16_layout_B
(
thread_id
,
local_id
):
print
(
"mfma_layout thread_id_shared_access_64x4_to_16x16_layout_B:"
,
thread_id
,
local_id
)
i
=
local_id
+
(
thread_id
//
16
)
*
4
j
=
thread_id
%
16
return
i
,
j
...
...
tilelang/intrinsics/mfma_macro_generator.py
View file @
15599a93
...
...
@@ -115,7 +115,8 @@ class MatrixCoreIntrinEmitter:
if
a_dtype
.
bits
==
32
:
self
.
k_dim
=
4
elif
a_dtype
.
bits
in
{
16
,
8
}:
self
.
k_dim
=
16
# self.k_dim = 16
self
.
k_dim
=
256
else
:
raise
ValueError
(
f
"Unsupported a_dtype =
{
a_dtype
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment