Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
bd1c7b39
Unverified
Commit
bd1c7b39
authored
Oct 16, 2025
by
Yu Cheng
Committed by
GitHub
Oct 16, 2025
Browse files
[Refactor] Use `has_simt_copy` to decide whether to insert `set_max_nreg` (#982)
parent
8f001e02
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
6 deletions
+24
-6
examples/deepseek_v32/fp8_lighting_indexer.py
examples/deepseek_v32/fp8_lighting_indexer.py
+0
-2
src/transform/annotate_warp_group_reg_alloc.cc
src/transform/annotate_warp_group_reg_alloc.cc
+22
-3
tilelang/engine/phase.py
tilelang/engine/phase.py
+2
-1
No files found.
examples/deepseek_v32/fp8_lighting_indexer.py
View file @
bd1c7b39
...
...
@@ -136,8 +136,6 @@ def mqa_attn_return_logits(
cu_k_s_min
=
T
.
alloc_local
([
1
],
index_dtype
)
cu_k_e_max
=
T
.
alloc_local
([
1
],
index_dtype
)
T
.
no_set_max_nreg
()
cu_k_s_min
[
0
]
=
2147483647
cu_k_e_max
[
0
]
=
-
2147483648
...
...
src/transform/annotate_warp_group_reg_alloc.cc
View file @
bd1c7b39
...
...
@@ -59,6 +59,27 @@ private:
bool
warp_specialized_
=
false
;
};
/*!
 * \brief Visitor that reports whether a statement contains a BufferStore
 *        whose destination buffer lives outside global storage scope.
 *
 * The result feeds the `has_simt_copy` flag that gates register-hint
 * injection: hints are only injected when no such store is found.
 */
class SimtCopyDetector : public StmtExprVisitor {
public:
  /*!
   * \brief Scan \p stmt for stores into non-global buffers.
   * \param stmt The statement (sub)tree to inspect.
   * \return true if at least one BufferStore targets a non-global buffer.
   */
  static bool Detect(const Stmt &stmt) {
    SimtCopyDetector detector;
    detector.VisitStmt(stmt);
    return detector.has_simt_copy_;
  }

private:
  void VisitStmt_(const BufferStoreNode *op) final {
    const auto scope =
        runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data));
    // Compare the scope rank instead of the stringified scope: to_string()
    // appends the scope tag, so a tagged global scope would not compare equal
    // to the literal "global". This also avoids building a std::string for
    // every store visited.
    if (scope.rank != runtime::StorageRank::kGlobal) {
      has_simt_copy_ = true;
      // The children of a BufferStore are expressions only, so there is
      // nothing further to discover beneath this node.
      return;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  /*! \brief Set once a store into a non-global buffer has been seen. */
  bool has_simt_copy_{false};
};
class
SetMaxNRegInjector
:
public
StmtExprMutator
{
public:
static
PrimFunc
Inject
(
PrimFunc
f
)
{
...
...
@@ -113,9 +134,7 @@ private:
auto
dec_reg_stmt
=
Evaluate
(
0
);
// Only inject if we have valid register hints and no SIMT copy
// For now, we assume no SIMT copy detection is available here
// TODO: Add SIMT copy detection if needed
bool
has_simt_copy
=
false
;
// Placeholder
bool
has_simt_copy
=
SimtCopyDetector
::
Detect
(
producer_body
);
if
(
dec_reg
>=
0
&&
inc_reg
>=
0
&&
!
has_simt_copy
)
{
auto
inc_reg_num
=
...
...
tilelang/engine/phase.py
View file @
bd1c7b39
...
...
@@ -135,7 +135,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod
=
tilelang
.
transform
.
MultiVersionBuffer
()(
mod
)
mod
=
tilelang
.
transform
.
WarpSpecialized
()(
mod
)
mod
=
tilelang
.
transform
.
InjectTmaBarrier
()(
mod
)
mod
=
tilelang
.
transform
.
AnnotateWarpGroupRegAlloc
()(
mod
)
# if tma is not enabled, we can also do pipeline planning
# to get better performance with async copy
mod
=
tilelang
.
transform
.
PipelinePlanning
()(
mod
)
...
...
@@ -206,6 +205,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# Inject PTX async copy must behind the thread sync pass
# as ptx async copy won't be recognized as a valid buffer load
mod
=
tilelang
.
transform
.
InjectPTXAsyncCopy
()(
mod
)
if
allow_tma_and_warp_specialized
(
pass_ctx
=
pass_ctx
,
target
=
target
):
mod
=
tilelang
.
transform
.
AnnotateWarpGroupRegAlloc
()(
mod
)
mod
=
tilelang
.
transform
.
MakePackedAPI
()(
mod
)
mod
=
tilelang
.
transform
.
LowerDeviceKernelLaunch
()(
mod
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment