Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
0ff4f427
Unverified
Commit
0ff4f427
authored
Oct 16, 2025
by
Yuqi Dong
Committed by
GitHub
Oct 16, 2025
Browse files
[Feature]: Add test for atomicadd auto vectorize and remove useless code (#1019)
* update * format * rabbit
parent
bd1c7b39
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
99 additions
and
87 deletions
+99
-87
src/op/atomic_add.cc
src/op/atomic_add.cc
+2
-7
src/op/builtin.cc
src/op/builtin.cc
+5
-0
src/op/builtin.h
src/op/builtin.h
+7
-0
src/transform/atomicadd_vectorize.cc
src/transform/atomicadd_vectorize.cc
+84
-80
src/transform/atomicadd_vectorize.h
src/transform/atomicadd_vectorize.h
+1
-0
No files found.
src/op/atomic_add.cc
View file @
0ff4f427
...
@@ -272,7 +272,6 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
...
@@ -272,7 +272,6 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
PrimExpr
dst_predicate
=
MakePredicate
(
analyzer
,
loop_vars
,
dst
->
shape
,
1
);
PrimExpr
dst_predicate
=
MakePredicate
(
analyzer
,
loop_vars
,
dst
->
shape
,
1
);
Array
<
PrimExpr
>
new_args
;
Array
<
PrimExpr
>
new_args
;
new_args
.
push_back
(
StringImm
(
"AtomicAdd"
));
PrimExpr
src_value
=
BufferLoad
(
src
,
src_indices
);
PrimExpr
src_value
=
BufferLoad
(
src
,
src_indices
);
if
(
src
->
dtype
!=
dst
->
dtype
)
if
(
src
->
dtype
!=
dst
->
dtype
)
...
@@ -288,7 +287,7 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
...
@@ -288,7 +287,7 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
new_args
.
push_back
(
src_value
);
new_args
.
push_back
(
src_value
);
Call
atomicadd_call
=
Call
atomicadd_call
=
tvm
::
tir
::
Call
(
dst
->
dtype
,
builtin
::
call_extern
(),
new_args
);
tvm
::
tir
::
Call
(
dst
->
dtype
,
atomicadd_elem_op
(),
new_args
);
Stmt
body
=
tvm
::
tir
::
Evaluate
(
atomicadd_call
);
Stmt
body
=
tvm
::
tir
::
Evaluate
(
atomicadd_call
);
...
@@ -325,10 +324,6 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
...
@@ -325,10 +324,6 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
*/
*/
LayoutMap
AtomicAddNode
::
InferLayout
(
const
LayoutInferArgs
&
T
,
LayoutMap
AtomicAddNode
::
InferLayout
(
const
LayoutInferArgs
&
T
,
InferLevel
level
)
const
{
InferLevel
level
)
const
{
if
(
!
par_op_
.
defined
())
{
arith
::
Analyzer
analyzer
;
par_op_
=
ParallelOp
(
MakeSIMTLoop
(
&
analyzer
));
}
if
(
T
.
layout_map
.
count
(
src
)
&&
T
.
layout_map
.
count
(
dst
))
{
if
(
T
.
layout_map
.
count
(
src
)
&&
T
.
layout_map
.
count
(
dst
))
{
if
(
src
.
scope
()
==
"local.fragment"
&&
dst
.
scope
()
==
"local.fragment"
)
{
if
(
src
.
scope
()
==
"local.fragment"
&&
dst
.
scope
()
==
"local.fragment"
)
{
const
FragmentNode
*
src_layout
=
T
.
layout_map
[
src
].
as
<
FragmentNode
>
();
const
FragmentNode
*
src_layout
=
T
.
layout_map
[
src
].
as
<
FragmentNode
>
();
...
@@ -342,7 +337,7 @@ LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
...
@@ -342,7 +337,7 @@ LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
}
}
}
}
}
}
return
par_op_
->
InferLayout
(
T
,
level
)
;
return
{}
;
}
}
/**
/**
...
...
src/op/builtin.cc
View file @
0ff4f427
...
@@ -295,5 +295,10 @@ TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset)
...
@@ -295,5 +295,10 @@ TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset)
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
Integer
(
CallEffectKind
::
kOpaque
));
TIR_DEFINE_TL_BUILTIN
(
atomicadd_elem_op
)
.
set_num_inputs
(
2
)
.
set_attr
<
TCallEffectKind
>
(
"TCallEffectKind"
,
Integer
(
CallEffectKind
::
kOpaque
));
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/op/builtin.h
View file @
0ff4f427
...
@@ -501,6 +501,13 @@ TVM_DLL const Op &initialize_descriptor();
...
@@ -501,6 +501,13 @@ TVM_DLL const Op &initialize_descriptor();
* tilelang.
* tilelang.
*/
*/
TVM_DLL
const
Op
&
increase_descriptor_offset
();
TVM_DLL
const
Op
&
increase_descriptor_offset
();
/*!
* \brief tilelang intrinsic for element-wise atomic addition.
*
* This op is used to represent an element-wise atomic add operation in
* tilelang.
*/
TVM_DLL
const
Op
&
atomicadd_elem_op
();
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
...
...
src/transform/atomicadd_vectorize.cc
View file @
0ff4f427
...
@@ -23,25 +23,27 @@ AtomicAddVectorizePlanner::Plan(const For &node, int compute_capability) {
...
@@ -23,25 +23,27 @@ AtomicAddVectorizePlanner::Plan(const For &node, int compute_capability) {
PostOrderVisit
(
node
,
[
&
](
const
ObjectRef
&
obj
)
{
PostOrderVisit
(
node
,
[
&
](
const
ObjectRef
&
obj
)
{
if
(
const
auto
*
call
=
obj
.
as
<
CallNode
>
())
{
if
(
const
auto
*
call
=
obj
.
as
<
CallNode
>
())
{
if
(
call
->
op
==
builtin
::
call_extern
()
&&
call
->
args
.
size
()
>=
2
)
{
if
(
call
->
op
==
atomicadd_elem_op
())
{
const
auto
*
func_name
=
call
->
args
[
0
].
as
<
StringImmNode
>
();
if
(
call
->
args
.
size
()
<
2
)
{
if
(
!
func_name
)
// Fallback: unexpected arity
vectorize_size_max
=
1
;
DLOG
(
WARNING
)
<<
"[AtomicAddVectorizePlanner] atomicadd_elem_op "
"expects 2 args, got "
<<
call
->
args
.
size
()
<<
"; Fallback to no vectorize"
;
return
;
return
;
if
(
func_name
->
value
==
"AtomicAdd"
)
{
}
DataType
dtype
;
DataType
dtype
;
if
(
const
auto
*
load
=
call
->
args
[
1
].
as
<
BufferLoadNode
>
())
{
if
(
const
auto
*
load
=
call
->
args
[
0
].
as
<
BufferLoadNode
>
())
{
dtype
=
load
->
dtype
;
dtype
=
load
->
dtype
;
vectorize_size_max
=
GetVectorizeSizeMax
(
compute_capability
,
dtype
);
vectorize_size_max
=
GetVectorizeSizeMax
(
compute_capability
,
dtype
);
}
else
if
(
const
auto
*
ite
=
call
->
args
[
1
].
as
<
IfThenElseNode
>
())
{
}
else
if
(
const
auto
*
ite
=
call
->
args
[
0
].
as
<
IfThenElseNode
>
())
{
if
(
const
auto
*
then_load
=
ite
->
then_case
.
as
<
BufferLoadNode
>
())
{
if
(
const
auto
*
then_load
=
ite
->
then_case
.
as
<
BufferLoadNode
>
())
{
dtype
=
then_load
->
dtype
;
dtype
=
then_load
->
dtype
;
vectorize_size_max
=
vectorize_size_max
=
GetVectorizeSizeMax
(
compute_capability
,
dtype
);
GetVectorizeSizeMax
(
compute_capability
,
dtype
);
}
else
if
(
const
auto
*
else_load
=
}
else
if
(
const
auto
*
else_load
=
ite
->
else_case
.
as
<
BufferLoadNode
>
())
{
ite
->
else_case
.
as
<
BufferLoadNode
>
())
{
dtype
=
else_load
->
dtype
;
dtype
=
else_load
->
dtype
;
vectorize_size_max
=
vectorize_size_max
=
GetVectorizeSizeMax
(
compute_capability
,
dtype
);
GetVectorizeSizeMax
(
compute_capability
,
dtype
);
}
else
{
}
else
{
// fallback
// fallback
vectorize_size_max
=
1
;
vectorize_size_max
=
1
;
...
@@ -57,7 +59,6 @@ AtomicAddVectorizePlanner::Plan(const For &node, int compute_capability) {
...
@@ -57,7 +59,6 @@ AtomicAddVectorizePlanner::Plan(const For &node, int compute_capability) {
}
}
}
}
}
}
}
});
});
if
(
vectorize_size_max
<=
1
)
{
if
(
vectorize_size_max
<=
1
)
{
...
@@ -75,13 +76,12 @@ void AtomicAddVectorizePlanner::VisitStmt_(const ForNode *node) {
...
@@ -75,13 +76,12 @@ void AtomicAddVectorizePlanner::VisitStmt_(const ForNode *node) {
}
}
void
AtomicAddVectorizePlanner
::
VisitExpr_
(
const
CallNode
*
node
)
{
void
AtomicAddVectorizePlanner
::
VisitExpr_
(
const
CallNode
*
node
)
{
if
(
node
->
op
==
builtin
::
call_extern
()
&&
node
->
args
.
size
()
>=
2
)
{
if
(
node
->
op
==
atomicadd_elem_op
()
&&
!
node
->
args
.
empty
())
{
if
(
const
auto
*
func_name
=
node
->
args
[
0
].
as
<
StringImmNode
>
())
{
if
(
node
->
args
.
size
()
<
2
)
{
if
(
func_name
->
value
==
"AtomicAdd"
)
{
return
arith
::
IRVisitorWithAnalyzer
::
VisitExpr_
(
node
);
const
BufferLoadNode
*
buffer_load_dst
=
}
node
->
args
[
1
].
as
<
BufferLoadNode
>
();
const
BufferLoadNode
*
buffer_load_dst
=
node
->
args
[
0
].
as
<
BufferLoadNode
>
();
const
BufferLoadNode
*
buffer_load_src
=
const
BufferLoadNode
*
buffer_load_src
=
node
->
args
[
1
].
as
<
BufferLoadNode
>
();
node
->
args
[
2
].
as
<
BufferLoadNode
>
();
if
(
buffer_load_src
&&
buffer_load_src
->
buffer
.
defined
()
&&
if
(
buffer_load_src
&&
buffer_load_src
->
buffer
.
defined
()
&&
buffer_load_dst
&&
buffer_load_dst
->
buffer
.
defined
())
{
buffer_load_dst
&&
buffer_load_dst
->
buffer
.
defined
())
{
Buffer
dst_buffer
=
buffer_load_dst
->
buffer
;
Buffer
dst_buffer
=
buffer_load_dst
->
buffer
;
...
@@ -91,8 +91,6 @@ void AtomicAddVectorizePlanner::VisitExpr_(const CallNode *node) {
...
@@ -91,8 +91,6 @@ void AtomicAddVectorizePlanner::VisitExpr_(const CallNode *node) {
UpdateVectorSize
(
buffer_load_src
->
indices
,
src_buffer
);
UpdateVectorSize
(
buffer_load_src
->
indices
,
src_buffer
);
}
}
}
}
}
}
return
arith
::
IRVisitorWithAnalyzer
::
VisitExpr_
(
node
);
return
arith
::
IRVisitorWithAnalyzer
::
VisitExpr_
(
node
);
}
}
...
@@ -188,6 +186,8 @@ private:
...
@@ -188,6 +186,8 @@ private:
Stmt
VisitStmt_
(
const
ForNode
*
node
)
final
{
Stmt
VisitStmt_
(
const
ForNode
*
node
)
final
{
inner_for_
=
node
;
inner_for_
=
node
;
auto
ret
=
StmtExprMutator
::
VisitStmt_
(
node
);
auto
ret
=
StmtExprMutator
::
VisitStmt_
(
node
);
if
(
vector_size_
==
1
)
return
ret
;
if
(
inner_for_
==
node
)
{
if
(
inner_for_
==
node
)
{
For
fnode
=
ret
.
as
<
For
>
().
value
();
For
fnode
=
ret
.
as
<
For
>
().
value
();
auto
old_var
=
fnode
->
loop_var
;
auto
old_var
=
fnode
->
loop_var
;
...
@@ -210,48 +210,55 @@ private:
...
@@ -210,48 +210,55 @@ private:
}
}
PrimExpr
VisitExpr_
(
const
CallNode
*
node
)
final
{
PrimExpr
VisitExpr_
(
const
CallNode
*
node
)
final
{
if
(
dynamic_
)
{
bool
legal_vectorize
=
true
;
return
StmtExprMutator
::
VisitExpr_
(
node
);
if
(
dynamic_
)
}
legal_vectorize
=
false
;
if
(
vector_size_
==
2
||
vector_size_
==
4
)
{
if
(
!
(
node
->
op
==
atomicadd_elem_op
()))
if
(
node
->
op
==
builtin
::
call_extern
()
&&
node
->
args
.
size
()
>=
2
)
{
legal_vectorize
=
false
;
if
(
const
auto
*
func_name
=
node
->
args
[
0
].
as
<
StringImmNode
>
())
{
if
(
node
->
args
.
size
()
<
2
)
if
(
func_name
->
value
==
"AtomicAdd"
)
{
legal_vectorize
=
false
;
const
BufferLoadNode
*
temp_dst_node
=
if
(
legal_vectorize
)
{
node
->
args
[
1
].
as
<
BufferLoadNode
>
();
const
BufferLoadNode
*
temp_dst_node
=
node
->
args
[
0
].
as
<
BufferLoadNode
>
();
const
BufferLoadNode
*
temp_value_node
=
const
BufferLoadNode
*
temp_value_node
=
node
->
args
[
2
].
as
<
BufferLoadNode
>
();
node
->
args
[
1
].
as
<
BufferLoadNode
>
();
if
(
!
temp_dst_node
||
!
temp_value_node
)
{
if
(
!
temp_dst_node
||
!
temp_value_node
)
return
StmtExprMutator
::
VisitExpr_
(
node
)
;
legal_vectorize
=
false
;
}
}
const
BufferLoad
dst_node
=
if
(
legal_vectorize
)
{
Downcast
<
BufferLoad
>
(
node
->
args
[
1
].
as
<
BufferLoadNode
>
());
const
BufferLoad
dst_node
=
Downcast
<
BufferLoad
>
(
node
->
args
[
0
]);
const
BufferLoad
value_node
=
const
BufferLoad
value_node
=
Downcast
<
BufferLoad
>
(
node
->
args
[
1
]);
Downcast
<
BufferLoad
>
(
node
->
args
[
2
].
as
<
BufferLoadNode
>
());
Call
address_of_dst
=
Call
address_of_dst
=
Call
(
DataType
::
Handle
(),
builtin
::
address_of
(),
{
dst_node
});
Call
(
DataType
::
Handle
(),
builtin
::
address_of
(),
{
dst_node
});
Call
address_of_value
=
Call
address_of_value
=
Call
(
DataType
::
Handle
(),
builtin
::
address_of
(),
{
value_node
});
Call
(
DataType
::
Handle
(),
builtin
::
address_of
(),
{
value_node
});
Array
<
PrimExpr
>
new_args
;
Array
<
PrimExpr
>
new_args
;
if
(
vector_size_
==
2
)
{
if
(
vector_size_
==
4
)
{
new_args
.
push_back
(
StringImm
(
"AtomicAddx4"
));
}
else
if
(
vector_size_
==
2
)
{
new_args
.
push_back
(
StringImm
(
"AtomicAddx2"
));
new_args
.
push_back
(
StringImm
(
"AtomicAddx2"
));
}
else
{
}
else
{
new_args
.
push_back
(
StringImm
(
"AtomicAdd
x4
"
));
new_args
.
push_back
(
StringImm
(
"AtomicAdd"
));
}
}
new_args
.
push_back
(
address_of_dst
);
new_args
.
push_back
(
address_of_dst
);
new_args
.
push_back
(
address_of_value
);
new_args
.
push_back
(
address_of_value
);
Call
new_call
=
tvm
::
tir
::
Call
(
node
->
dtype
,
builtin
::
call_extern
(),
new_args
);
return
new_call
;
}
else
{
Array
<
PrimExpr
>
new_args
;
new_args
.
push_back
(
StringImm
(
"AtomicAdd"
));
for
(
auto
x
:
node
->
args
)
new_args
.
push_back
(
x
);
Call
new_call
=
Call
new_call
=
tvm
::
tir
::
Call
(
node
->
dtype
,
builtin
::
call_extern
(),
new_args
);
tvm
::
tir
::
Call
(
node
->
dtype
,
builtin
::
call_extern
(),
new_args
);
return
new_call
;
return
new_call
;
}
}
}
}
}
}
return
StmtExprMutator
::
VisitExpr_
(
node
);
}
const
ForNode
*
inner_for_
;
const
ForNode
*
inner_for_
;
const
int
vector_size_
;
const
int
vector_size_
;
...
@@ -263,9 +270,6 @@ For VectorizeAtomicAdd(const For &for_node, int compute_capability) {
...
@@ -263,9 +270,6 @@ For VectorizeAtomicAdd(const For &for_node, int compute_capability) {
AtomicAddVectorizePlanResult
res
=
{
1
,
false
,
0
};
AtomicAddVectorizePlanResult
res
=
{
1
,
false
,
0
};
AtomicAddVectorizePlanner
planner
;
AtomicAddVectorizePlanner
planner
;
res
=
planner
.
Plan
(
for_node
,
compute_capability
);
res
=
planner
.
Plan
(
for_node
,
compute_capability
);
int
vectorize_hint
=
res
.
vector_size
;
if
(
vectorize_hint
==
1
)
return
for_node
;
auto
rewriter
=
AtomicAddVectorizeRewriter
(
res
);
auto
rewriter
=
AtomicAddVectorizeRewriter
(
res
);
return
Downcast
<
For
>
(
rewriter
(
for_node
));
return
Downcast
<
For
>
(
rewriter
(
for_node
));
}
}
...
...
src/transform/atomicadd_vectorize.h
View file @
0ff4f427
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include "../layout/layout.h"
#include "../layout/layout.h"
#include "../layout/utils.h"
#include "../layout/utils.h"
#include "../op/builtin.h"
#include "arith/int_operator.h"
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "atomicadd_vectorize.h"
#include "atomicadd_vectorize.h"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment