Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
eff4082d
Commit
eff4082d
authored
May 06, 2026
by
wangziyang
Browse files
fix ds_read pass
parent
dd95e41b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
99 additions
and
117 deletions
+99
-117
src/op/builtin.cc
src/op/builtin.cc
+1
-1
src/target/codegen_hip.cc
src/target/codegen_hip.cc
+8
-10
src/tl_templates/dcu_hip/copy.h
src/tl_templates/dcu_hip/copy.h
+32
-24
src/transform/inject_ds_read.cc
src/transform/inject_ds_read.cc
+48
-73
tilelang/engine/phase.py
tilelang/engine/phase.py
+5
-0
tilelang/language/builtin.py
tilelang/language/builtin.py
+5
-9
No files found.
src/op/builtin.cc
View file @
eff4082d
...
@@ -218,7 +218,7 @@ TIR_DEFINE_TL_BUILTIN(tma_store_wait)
...
@@ -218,7 +218,7 @@ TIR_DEFINE_TL_BUILTIN(tma_store_wait)
Integer
(
CallEffectKind
::
kOpaque
));
Integer
(
CallEffectKind
::
kOpaque
));
TIR_DEFINE_TL_BUILTIN
(
ds_read_vector
)
TIR_DEFINE_TL_BUILTIN
(
ds_read_vector
)
.
set_num_inputs
(
5
)
.
set_num_inputs
(
3
)
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
Integer
(
CallEffectKind
::
kOpaque
));
...
...
src/target/codegen_hip.cc
View file @
eff4082d
...
@@ -839,18 +839,16 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
...
@@ -839,18 +839,16 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
func_name
+=
"_trans"
;
func_name
+=
"_trans"
;
print_extern_call_stmt
(
func_name
,
2
);
print_extern_call_stmt
(
func_name
,
2
);
}
else
if
(
op
->
op
.
same_as
(
tl
::
ds_read_vector
())){
}
else
if
(
op
->
op
.
same_as
(
tl
::
ds_read_vector
())){
//ds_read_b64 %1, %2 offset:%3
// ds_read_m32x16_b16 %0, %1 offset:0
// ds_read_m32x16_b16 %0, %1 offset:%2
printf
(
"[DEBUG VisitExpr_] Branch: ds_read_vector
\n
"
);
printf
(
"[DEBUG VisitExpr_] Branch: ds_read_vector
\n
"
);
std
::
string
dst
=
this
->
PrintExpr
(
op
->
args
[
0
]);
std
::
string
dst
=
this
->
PrintExpr
(
op
->
args
[
0
]);
std
::
string
lds_base_ptr
=
this
->
PrintExpr
(
op
->
args
[
1
]);
std
::
string
local_offset
=
this
->
PrintExpr
(
op
->
args
[
1
]);
std
::
string
m
=
this
->
PrintExpr
(
op
->
args
[
2
]);
std
::
string
lds_offset
=
this
->
PrintExpr
(
op
->
args
[
2
]);
std
::
string
n
=
this
->
PrintExpr
(
op
->
args
[
3
]);
os
<<
"tl::ds_read_vector("
std
::
string
offset
=
this
->
PrintExpr
(
op
->
args
[
4
]);
<<
dst
<<
" + "
<<
local_offset
this
->
PrintIndent
();
<<
", "
this
->
stream
<<
"tl::ds_read_vector<"
<<
m
<<
", "
<<
n
<<
", "
<<
offset
<<
">"
<<
lds_offset
<<
"(*reinterpret_cast<float4_*>("
<<
dst
<<
"), "
<<
")"
;
<<
"reinterpret_cast<uintptr_t>("
<<
lds_base_ptr
<<
"));
\n
"
;
}
else
if
(
op
->
op
.
same_as
(
tl
::
wait_wgmma
()))
{
}
else
if
(
op
->
op
.
same_as
(
tl
::
wait_wgmma
()))
{
printf
(
"[DEBUG VisitExpr_] Branch: wait_wgmma
\n
"
);
printf
(
"[DEBUG VisitExpr_] Branch: wait_wgmma
\n
"
);
this
->
PrintIndent
();
this
->
PrintIndent
();
...
...
src/tl_templates/dcu_hip/copy.h
View file @
eff4082d
...
@@ -176,36 +176,44 @@ template <int N>
...
@@ -176,36 +176,44 @@ template <int N>
// }
// }
// }
// }
template
<
int
M
,
int
N
,
int
offset
>
TL_DEVICE
void
ds_read_vector
(
void
*
dst
,
uint32_t
lds_base_ptr
)
TL_DEVICE
void
ds_read_vector
(
float4_
&
dst
,
uint32_t
lds_base_ptr
)
{
{
if
constexpr
(
M
==
16
&&
N
==
32
)
asm
volatile
(
"ds_read_m32x16_b16 %0, %1 offset:0
\n\t
"
{
const
int
offset_in_bytes
=
offset
*
sizeof
(
half_t
);
asm
volatile
(
"ds_read_m32x16_b16 %0, %1 offset:%2
\n\t
"
:
"+v"
(
dst
)
:
"+v"
(
dst
)
:
"v"
(
lds_base_ptr
),
:
"v"
(
lds_base_ptr
),
"n"
(
offset_in_bytes
)
:
"memory"
);
:
"memory"
);
}
else
if
constexpr
(
M
==
32
&&
N
==
16
)
{
const
int
offset_in_bytes0
=
offset
*
sizeof
(
half_t
);
const
int
offset_in_bytes1
=
offset_in_bytes0
+
4096
;
float2_
&
front
=
*
reinterpret_cast
<
float2_
*>
(
&
dst
);
float2_
&
rear
=
*
(
reinterpret_cast
<
float2_
*>
(
&
dst
)
+
1
);
asm
volatile
(
"ds_read_b64 %1, %2 offset:%3
\n\t
"
"ds_read_b64 %0, %2 offset:%4
\n\t
"
:
"+v"
(
rear
),
"+v"
(
front
)
:
"v"
(
lds_base_ptr
),
"n"
(
offset_in_bytes0
),
"n"
(
offset_in_bytes1
)
:
"memory"
);
}
}
}
// template <int M, int N, int offset>
// TL_DEVICE void ds_read_vector(float4_& dst, uint32_t lds_base_ptr)
// {
// if constexpr (M == 16 && N == 32)
// {
// const int offset_in_bytes = offset * sizeof(half_t);
// asm volatile("ds_read_m32x16_b16 %0, %1 offset:%2\n\t"
// : "+v"(dst)
// : "v"(lds_base_ptr),
// "n"(offset_in_bytes)
// : "memory");
// }
// else if constexpr (M == 32 && N == 16)
// {
// const int offset_in_bytes0 = offset * sizeof(half_t);
// const int offset_in_bytes1 = offset_in_bytes0 + 4096;
// float2_& front = *reinterpret_cast<float2_*>(&dst);
// float2_& rear = *(reinterpret_cast<float2_*>(&dst) + 1);
// asm volatile(
// "ds_read_b64 %1, %2 offset:%3\n\t"
// "ds_read_b64 %0, %2 offset:%4\n\t"
// : "+v"(rear), "+v"(front)
// : "v"(lds_base_ptr), "n"(offset_in_bytes0), "n"(offset_in_bytes1)
// : "memory"
// );
// }
// }
template
<
int
N
>
template
<
int
N
>
TL_DEVICE
void
cp_async_gs_conditional
(
void
*
lds_base_ptr
,
TL_DEVICE
void
cp_async_gs_conditional
(
void
*
lds_base_ptr
,
void
*
global_base_ptr
,
bool
cond
)
{
void
*
global_base_ptr
,
bool
cond
)
{
...
...
src/transform/inject_ds_read.cc
View file @
eff4082d
...
@@ -42,88 +42,63 @@ using namespace tir;
...
@@ -42,88 +42,63 @@ using namespace tir;
class
DSReadInjector
:
public
StmtExprMutator
{
class
DSReadInjector
:
public
StmtExprMutator
{
public:
public:
/*!
bool
IsBLocalBuffer
(
const
Buffer
&
buffer
)
{
* \brief Visit EvaluateNode to handle explicit ds_read_vector call
std
::
string
name
=
buffer
->
name
;
* ds_read_vector Call is wrapped in Evaluate to become a statement
return
name
.
find
(
"B_local"
)
!=
std
::
string
::
npos
;
* Parameters m, n, offset are passed explicitly via CallNode args
}
*/
Stmt
VisitStmt_
(
const
EvaluateNode
*
op
)
override
{
private:
std
::
cout
<<
"[DEBUG VisitStmt_] Visiting EvaluateNode"
<<
std
::
endl
;
const
CallNode
*
call
=
op
->
value
.
as
<
CallNode
>
();
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
std
::
cout
<<
"[DEBUG VisitStmt_] CallNode ptr: "
<<
call
<<
std
::
endl
;
Buffer
buffer
=
op
->
buffer
;
if
(
call
!=
nullptr
&&
call
->
op
.
same_as
(
ds_read_vector
()))
{
ICHECK
(
call
->
args
.
size
()
==
5
)
if
(
!
IsBLocalBuffer
(
buffer
))
{
<<
"ds_read_vector expects 5 arguments: (dst, src, m, n, offset)"
;
// Print args for debugging - these are the actual CallNode args passed in
std
::
cout
<<
"[DEBUG ds_read_vector] args[0] (dst): "
<<
call
->
args
[
0
]
<<
std
::
endl
;
std
::
cout
<<
"[DEBUG ds_read_vector] args[1] (src): "
<<
call
->
args
[
1
]
<<
std
::
endl
;
std
::
cout
<<
"[DEBUG ds_read_vector] args[2] (m): "
<<
call
->
args
[
2
]
<<
std
::
endl
;
std
::
cout
<<
"[DEBUG ds_read_vector] args[3] (n): "
<<
call
->
args
[
3
]
<<
std
::
endl
;
std
::
cout
<<
"[DEBUG ds_read_vector] args[4] (offset): "
<<
call
->
args
[
4
]
<<
std
::
endl
;
}
// Continue with default traversal (don't replace the existing call)
return
StmtExprMutator
::
VisitStmt_
(
op
);
return
StmtExprMutator
::
VisitStmt_
(
op
);
}
}
/*!
const
BufferLoadNode
*
load
=
op
->
value
.
as
<
BufferLoadNode
>
();
* \brief Visit BufferStoreNode to inject ds_read_vector call
if
(
!
load
)
{
* Pattern: local_buffer[...] = shared_buffer[...] (BufferLoad)
return
StmtExprMutator
::
VisitStmt_
(
op
);
* Parameters m, n, offset are passed via a CallNode (tl.ds_read_config)
}
*/
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
override
{
std
::
cout
<<
"[DEBUG VisitStmt_] Visiting BufferStoreNode"
<<
std
::
endl
;
// Check if the store is to a local register (not shared memory)
bool
is_local
=
op
->
buffer
.
scope
()
==
"local"
||
op
->
buffer
.
scope
()
==
"local.fragment"
;
std
::
cout
<<
"[DEBUG BufferStore] is_local: "
<<
is_local
<<
", scope: "
<<
op
->
buffer
.
scope
()
<<
std
::
endl
;
if
(
!
is_local
)
{
return
StmtExprMutator
::
VisitStmt_
(
op
);
}
// Check if the value is a BufferLoad from shared memory
// local offset
const
BufferLoadNode
*
load
=
op
->
value
.
as
<
BufferLoadNode
>
();
ICHECK
(
op
->
indices
.
size
()
==
1
);
if
(
load
==
nullptr
)
{
PrimExpr
local_index
=
op
->
indices
[
0
];
std
::
cout
<<
"[DEBUG BufferStore] value is not BufferLoad"
<<
std
::
endl
;
PrimExpr
local_offset
;
return
StmtExprMutator
::
VisitStmt_
(
op
);
}
if
(
const
RampNode
*
ramp
=
local_index
.
as
<
RampNode
>
())
{
local_offset
=
ramp
->
base
;
}
else
{
local_offset
=
local_index
;
}
bool
is_shared_load
=
load
->
buffer
.
scope
()
==
"shared"
||
// lds offset
load
->
buffer
.
scope
()
==
"shared.dyn"
;
ICHECK
(
load
->
indices
.
size
()
==
1
)
;
std
::
cout
<<
"[DEBUG BufferStore] is_shared_load: "
<<
is_shared_load
PrimExpr
lds_index
=
load
->
indices
[
0
];
<<
", load scope: "
<<
load
->
buffer
.
scope
()
<<
std
::
endl
;
PrimExpr
lds_offset
;
if
(
!
is_shared_load
)
{
if
(
const
RampNode
*
ramp
=
lds_index
.
as
<
RampNode
>
())
{
return
StmtExprMutator
::
VisitStmt_
(
op
);
lds_offset
=
ramp
->
base
;
}
}
else
{
lds_offset
=
lds_index
;
}
// For A_shared, use the actual shared memory base pointer
// buffer pointer
PrimExpr
m
=
make_const
(
DataType
::
Int
(
32
),
32
);
PrimExpr
buffer_ptr
=
buffer
->
data
;
PrimExpr
n
=
make_const
(
DataType
::
Int
(
32
),
16
);
PrimExpr
offset
;
Array
<
PrimExpr
>
args
=
{
// Extract the shared memory offset from the load indices
buffer_ptr
,
if
(
!
load
->
indices
.
empty
())
{
local_offset
,
offset
=
load
->
indices
[
0
];
lds_offset
}
else
{
};
offset
=
make_const
(
DataType
::
Int
(
32
),
0
);
}
Call
call
=
Call
(
DataType
::
Handle
(),
ds_read_vector
(),
args
);
// Use buffer data vars directly
return
Evaluate
(
call
);
Array
<
PrimExpr
>
new_args
=
{
load
->
buffer
->
data
,
// src
op
->
buffer
->
data
,
// dst
m
,
n
,
offset
};
// Create the ds_read call
Call
ds_read_call
=
Call
(
DataType
::
Handle
(),
ds_read_vector
(),
new_args
);
return
Evaluate
(
ds_read_call
);
}
}
};
};
...
...
tilelang/engine/phase.py
View file @
eff4082d
...
@@ -297,6 +297,11 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
...
@@ -297,6 +297,11 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod
=
tilelang
.
transform
.
LowerSharedGlobalCopy
()(
mod
)
mod
=
tilelang
.
transform
.
LowerSharedGlobalCopy
()(
mod
)
mod
=
tilelang
.
transform
.
FixDCUWaitCount
()(
mod
)
mod
=
tilelang
.
transform
.
FixDCUWaitCount
()(
mod
)
mod
=
tilelang
.
transform
.
InjectBLocalLayoutTransform
()(
mod
)
mod
=
tilelang
.
transform
.
InjectBLocalLayoutTransform
()(
mod
)
print
(
"InjectBLocalLayoutTransform ............"
)
print
(
mod
)
mod
=
tilelang
.
transform
.
InjectDSRead
()(
mod
)
print
(
"InjectDSRead ............"
)
print
(
mod
)
# mod = tilelang.transform.SimplifyDCUAsyncCopy()(mod)
# mod = tilelang.transform.SimplifyDCUAsyncCopy()(mod)
print
(
"OptimizeForTarget3"
)
print
(
"OptimizeForTarget3"
)
print
(
mod
)
print
(
mod
)
...
...
tilelang/language/builtin.py
View file @
eff4082d
...
@@ -90,7 +90,7 @@ def __ldg(load_or_buf: BufferLoad | tir.Buffer, index: PrimExpr | int | None = N
...
@@ -90,7 +90,7 @@ def __ldg(load_or_buf: BufferLoad | tir.Buffer, index: PrimExpr | int | None = N
def
ds_read_vector
(
dst
:
tir
.
Var
,
shared_ptr
:
tir
.
Var
,
m
:
int
,
n
:
int
,
offset
:
int
)
->
Call
:
def
ds_read_vector
(
dst
:
tir
.
Var
,
local_offset
:
tir
.
Var
,
shared_ptr
:
tir
.
Var
)
->
Call
:
"""
"""
Load from shared memory using ds_read_b64 instruction.
Load from shared memory using ds_read_b64 instruction.
...
@@ -104,14 +104,12 @@ def ds_read_vector(dst: tir.Var, shared_ptr: tir.Var, m: int, n: int, offset: in
...
@@ -104,14 +104,12 @@ def ds_read_vector(dst: tir.Var, shared_ptr: tir.Var, m: int, n: int, offset: in
Load from shared memory using ds_read_m32x16_b16 instruction.
Load from shared memory using ds_read_m32x16_b16 instruction.
The ds_read_vector intrinsic has signature:
The ds_read_vector intrinsic has signature:
ds_read_vector
<M,N,offset>(float4 & dst, int lds_base
_ptr)
ds_read_vector
(dst + local_offset, shared
_ptr)
Args:
Args:
dst: Destination pointer (register/local buffer).
dst: Destination pointer (register/local buffer).
local_offset: Local offset in bytes for the destination register.
lds_base_ptr: Source pointer (shared memory buffer data).
lds_base_ptr: Source pointer (shared memory buffer data).
M: Number of columns in the matrix to load (for ds_read_m32x16_b16 / ds_read_b64).
N: Number of rows in the matrix to load (for ds_read_m32x16_b16 / ds_read_b64).
offset: address offset into shared memory.
Returns:
Returns:
Call: A TIR call intrinsic for the ds_read_b64 instruction.
Call: A TIR call intrinsic for the ds_read_b64 instruction.
...
@@ -120,10 +118,8 @@ def ds_read_vector(dst: tir.Var, shared_ptr: tir.Var, m: int, n: int, offset: in
...
@@ -120,10 +118,8 @@ def ds_read_vector(dst: tir.Var, shared_ptr: tir.Var, m: int, n: int, offset: in
"handle"
,
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.ds_read_vector"
),
tir
.
op
.
Op
.
get
(
"tl.ds_read_vector"
),
dst
,
dst
,
shared_ptr
,
local_offset
,
m
,
shared_ptr
n
,
offset
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment