OpenDAS / tilelang · Commit 340bfc50 (unverified)

[Bugfix] Fix atomicadd auto vectorize identify var error (#883)

Authored by Yuqi Dong on Oct 13, 2025; committed via GitHub on Oct 13, 2025.
Commit message body: "* update * update * update * update"
Parent commit: 4a229ddb

3 changed files, 355 additions (+) and 331 deletions (−):
- src/op/atomic_add.cc: +173 −84
- src/transform/atomicadd_vectorize.cc: +143 −245
- src/transform/atomicadd_vectorize.h: +39 −2
src/op/atomic_add.cc (+173 −84)

The first hunk pulls in the parallel-loop transform utilities used by the rewritten `Lower`:

```diff
@@ -13,6 +13,7 @@
 #include "../target/utils.h"
 #include "../transform/atomicadd_vectorize.h"
 #include "../transform/common/loop_fusion_utils.h"
+#include "../transform/common/loop_parallel_transform_utils.h"
 #include "../transform/loop_partition.h"
 #include "builtin.h"
```
The file-local `GetArchInt` helper is removed; its logic reappears as a lambda inside `Lower` (see the last hunk of this file):

```diff
@@ -21,31 +22,6 @@ namespace tl {
 
 using namespace tir;
 
-/**
- * @brief Extracts a numeric architecture identifier from a Target's "arch"
- * attribute.
- *
- * Reads the Target's "arch" string (must be defined) and, if it has the form
- * "sm_<N>", parses and returns N as an integer. For any other arch string,
- * returns 0.
- *
- * @param target Target whose "arch" attribute will be inspected (ICHECKs that
- * the attribute is defined).
- * @return int Parsed integer suffix when the arch is "sm_<N>", otherwise 0.
- */
-static int GetArchInt(Target target) {
-  int arch_int = 0;
-  auto s = target->GetAttr<String>("arch");
-  ICHECK(s.defined());
-  std::string arch = s.value();
-  if (arch.rfind("sm_", 0) == 0) {
-    arch_int = std::stoi(arch.substr(3));
-  } else {
-    arch_int = 0;
-  }
-  return arch_int;
-}
-
 /**
  * @brief Construct an AtomicAdd operator from call arguments and a buffer map.
  *
```
`InferLayout` (unchanged in substance) now precedes `Lower`:

```diff
@@ -328,6 +304,47 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   return Downcast<For>(body);
 }
 
+/**
+ * @brief Infer and return the layout map for the atomic add operator.
+ *
+ * Constructs a cached ParallelOp (by building the SIMT loop) if not already
+ * present, validates that local.fragment layouts for src and dst match when
+ * both are provided, and then delegates layout inference to the underlying
+ * ParallelOp.
+ *
+ * @param T Layout inference inputs, including an optional mapping of buffers
+ * to layouts.
+ * @param level Inference strictness level.
+ * @return LayoutMap The inferred layout mapping for buffers used by this
+ * operator.
+ *
+ * @note This method mutates the AtomicAddNode by creating and storing a
+ * ParallelOp on first invocation.
+ * @throws If both src and dst have layouts in `local.fragment` and their
+ * fragment layouts differ, an ICHECK failure is raised with diagnostic output.
+ */
+LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
+                                     InferLevel level) const {
+  if (!par_op_.defined()) {
+    arith::Analyzer analyzer;
+    par_op_ = ParallelOp(MakeSIMTLoop(&analyzer));
+  }
+  if (T.layout_map.count(src) && T.layout_map.count(dst)) {
+    if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
+      const FragmentNode *src_layout = T.layout_map[src].as<FragmentNode>();
+      const FragmentNode *dst_layout = T.layout_map[dst].as<FragmentNode>();
+      if (src_layout && dst_layout) {
+        ICHECK(src_layout->IsEqual(dst_layout, true))
+            << "Get different layout for " << src << " and " << dst
+            << "\nLHS = " << src_layout->DebugOutput()
+            << "\nRHS = " << dst_layout->DebugOutput()
+            << "\nYou may need to use a shared memory to transform the layout";
+      }
+    }
+  }
+  return par_op_->InferLayout(T, level);
+}
+
 /**
  * @brief Lower the atomic-add top-level operator into a parallel, vectorized
  * TIR loop.
```
`Lower` now derives the loop layout itself: it runs `ParallelLoopTransformer`, collects `local.fragment` accesses with a local `AtomicLoopNestCollector`, builds the loop layout from the highest-rank fragment buffer (falling back to `PlanLoopPartition`), clamps the planned vector width so it divides the flattened loop extent, and calls the simplified two-argument `VectorizeAtomicAdd`. The per-level `ParallelOp::InferLayout` loop, the predicate wrapping, and the duplicate copy of `InferLayout` that used to follow `Lower` are removed.

```diff
@@ -389,70 +406,142 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   }
   auto simt_loop = MakeSIMTLoop(analyzer);
   auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
-  auto par_op = ParallelOp(fused_loop);
-  std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
-                                    InferLevel::kFree};
-  for (auto level : levels) {
-    (par_op)->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
-                           false, T.buffer_remap},
-                          level);
-  }
-  auto loop_layout = par_op->GetLoopLayout();
-  Var thread_var = T.thread_var;
-  Range thread_bounds = T.thread_bounds;
-  auto thread_loop =
-      PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
-  auto vectorized_thread_loop = VectorizeAtomicAdd(
-      thread_loop, thread_var, thread_bounds, GetArchInt(target));
-  if (par_op->GetPredicate(T.thread_var).defined()) {
-    return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
-                      vectorized_thread_loop);
-  }
-  return vectorized_thread_loop;
-}
-
-/**
- * @brief Infer and return the layout map for the atomic add operator.
- * [... doc comment identical to the copy re-added before Lower ...]
- */
-LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
-                                     InferLevel level) const {
-  [... body identical to the definition re-added before Lower ...]
-}
+  auto transformed_loop =
+      Downcast<For>(ParallelLoopTransformer::Substitute(fused_loop));
+
+  auto GetArchInt = [&](const Target &tgt) -> int {
+    int arch_int = 0;
+    if (auto s = tgt->GetAttr<String>("arch")) {
+      std::string arch = s.value();
+      if (arch.rfind("sm_", 0) == 0)
+        arch_int = std::stoi(arch.substr(3));
+    }
+    return arch_int;
+  };
+
+  struct AtomicLoopNestCollector : tir::StmtExprVisitor {
+    Array<IterVar> loop_vars;
+    Map<Buffer, Array<PrimExpr>> indice_map;
+    std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> writes;
+    arith::Analyzer analyzer;
+
+    void Run(const Stmt &s) { StmtExprVisitor::VisitStmt(s); }
+
+    void VisitStmt_(const ForNode *op) final {
+      if (op->kind == ForKind::kParallel) {
+        loop_vars.push_back(IterVar(Range(op->min, op->extent), op->loop_var,
+                                    IterVarType::kDataPar));
+        analyzer.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
+      }
+      StmtExprVisitor::VisitStmt_(op);
+    }
+
+    void VisitStmt_(const BufferStoreNode *op) final {
+      if (op->buffer.scope() == "local.fragment") {
+        indice_map.Set(op->buffer, op->indices);
+        writes.insert(op->buffer);
+      }
+      StmtExprVisitor::VisitStmt_(op);
+    }
+
+    void VisitExpr_(const BufferLoadNode *op) final {
+      if (op->buffer.scope() == "local.fragment") {
+        indice_map.Set(op->buffer, op->indices);
+      }
+      StmtExprVisitor::VisitExpr_(op);
+    }
+  };
+
+  auto ComputeLoopLayoutFromBuffer =
+      [&](const Buffer &buf, const Array<PrimExpr> &indices,
+          const LayoutMap &layout_map, const Range &thread_bounds,
+          const Array<IterVar> &loop_vars) -> Fragment {
+    Fragment src = layout_map[buf].as<Fragment>().value();
+    Var rep;
+    auto rep_iter =
+        IterVar(Range(0, src->ReplicateExtent()), rep, IterVarType::kDataPar);
+    PrimExpr fth = src->ForwardThread(indices, rep);
+    fth = analyzer->Simplify(fth);
+    Fragment out = Fragment(loop_vars, /*forward_index=*/{}, fth, rep_iter)
+                       ->BindThreadRange(thread_bounds);
+    return out;
+  };
+
+  struct AtomicInferResult {
+    Fragment loop_layout;
+    Optional<PrimExpr> predicate;
+  };
+
+  auto AtomicAddInferLayout =
+      [&](const For &loop, const LayoutInferArgs &args) -> AtomicInferResult {
+    AtomicLoopNestCollector C;
+    C.Run(loop);
+    Optional<Buffer> read_src;
+    int best_rank = -1;
+    for (auto kv : C.indice_map) {
+      const Buffer &buf = kv.first;
+      if (buf.scope() != "local.fragment")
+        continue;
+      if (!args.layout_map.count(buf))
+        continue;
+      int rank = static_cast<int>(kv.second.size());
+      if (rank > best_rank) {
+        best_rank = rank;
+        read_src = buf;
+      }
+    }
+    AtomicAddVectorizePlanner planner;
+    int sm = GetArchInt(target);
+    auto plan = planner.Plan(loop, sm);
+    int vec = std::max(plan.vector_size, 1);
+    if (auto cw = loop->annotations.Get("coalesced_width")) {
+      if (const auto *imm = cw->as<IntImmNode>()) {
+        int expected = imm->value;
+        ICHECK_GT(expected, 0);
+        ICHECK(vec % expected == 0)
+            << "vector_size " << vec << " not divisible by coalesced_width "
+            << expected;
+        vec = expected;
+      } else {
+        LOG(FATAL) << "coalesced_width should be IntImmNode.";
+      }
+    }
+    PrimExpr total = 1;
+    for (Stmt s = loop; s.as<For>().has_value(); s = s.as<For>().value()->body)
+      total = total * s.as<For>().value()->extent;
+    PrimExpr denom = args.thread_bounds->extent * vec;
+    while (!analyzer->CanProve(floormod(total, denom) == 0) && vec > 1) {
+      vec >>= 1;
+      denom = args.thread_bounds->extent * vec;
+    }
+    if (vec < 1)
+      vec = 1;
+    Fragment loop_layout;
+    if (read_src) {
+      loop_layout = ComputeLoopLayoutFromBuffer(
+          read_src.value(), C.indice_map[read_src.value()], args.layout_map,
+          args.thread_bounds, C.loop_vars);
+    } else {
+      const For &remapped = loop;
+      loop_layout = PlanLoopPartition(remapped, vec, args.thread_bounds);
+    }
+    Optional<PrimExpr> pred;
+    if (plan.dynamic && plan.condition.defined()) {
+      pred = plan.condition;
+    }
+    DLOG(INFO) << "[AtomicAddInferLayout] vec=" << vec
+               << " loop_layout=" << loop_layout->DebugOutput();
+    return {loop_layout, pred};
+  };
+
+  auto ret = AtomicAddInferLayout(transformed_loop,
+                                  {T.target, T.thread_bounds, T.layout_map,
+                                   analyzer, false, T.buffer_remap});
+  Fragment loop_layout = ret.loop_layout;
+  auto thread_loop =
+      PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout);
+  auto vectorized_thread_loop =
+      VectorizeAtomicAdd(thread_loop, GetArchInt(target));
+  return vectorized_thread_loop;
+}
 
 TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
```
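For readers skimming the new `Lower`, the two numeric pieces — the `sm_<N>` parsing in the `GetArchInt` lambda and the power-of-two halving that keeps the vector width dividing the loop extent — behave as in this standalone sketch. Plain integers stand in for the `PrimExpr`/`Analyzer` machinery; `parse_arch_int` and `legalize_vec` are illustrative names, not functions from the codebase.

```cpp
#include <cassert>
#include <string>

// parse_arch_int mirrors the GetArchInt lambda: "sm_<N>" -> N, anything else -> 0.
static int parse_arch_int(const std::string &arch) {
  if (arch.rfind("sm_", 0) == 0)
    return std::stoi(arch.substr(3));
  return 0;
}

// legalize_vec mirrors the halving loop in AtomicAddInferLayout: shrink the
// candidate width until threads * vec evenly divides the flattened extent.
static int legalize_vec(int total, int threads, int vec) {
  while (vec > 1 && total % (threads * vec) != 0)
    vec >>= 1;
  return vec < 1 ? 1 : vec;
}

int main() {
  assert(parse_arch_int("sm_90") == 90);  // Hopper-class target -> 90
  assert(parse_arch_int("gfx942") == 0);  // non-"sm_" arch strings fall back to 0
  assert(legalize_vec(/*total=*/4096, /*threads=*/128, /*vec=*/4) == 4);
  assert(legalize_vec(/*total=*/192, /*threads=*/128, /*vec=*/4) == 1);
  return 0;
}
```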
src/transform/atomicadd_vectorize.cc (+143 −245)

The implementation file now gets everything it needs from its own header:

```diff
@@ -3,18 +3,7 @@
  * \brief A tool to automatically vectorize atomic add
  */
 
-#include "../layout/layout.h"
-#include "../layout/utils.h"
-#include "arith/int_operator.h"
-#include "arith/ir_visitor_with_analyzer.h"
-#include "common/loop_vectorization_utils.h"
-#include <numeric>
-#include <tvm/arith/analyzer.h>
-#include <tvm/arith/iter_affine_map.h>
-#include <tvm/tir/builtin.h>
-#include <tvm/tir/op.h>
-#include <tvm/tir/stmt_functor.h>
-#include <utility>
+#include "atomicadd_vectorize.h"
 
 namespace tvm {
 namespace tl {
```
The planner, previously a file-local class whose `Plan` needed the thread variable and thread bounds, is now declared in the header and defined out of line. `Plan(node, compute_capability)` first scans the loop for the `AtomicAdd` extern call to pick a dtype- and architecture-dependent upper bound on the vector width, then runs the visitor; `UpdateVectorSize` checks vectorizability against the innermost loop variable instead of the externally supplied thread variable. The rewriter constructor likewise loses the thread/bx/by/stride parameters.

```diff
@@ -23,132 +12,151 @@ using namespace tir;
 using arith::IRMutatorWithAnalyzer;
 using arith::IRVisitorWithAnalyzer;
 
-struct AtomicAddVectorizePlanResult {
-  int vector_size;
-  bool dynamic;
-  PrimExpr condition;
-};
-
-class AtomicAddVectorizePlanner : public arith::IRVisitorWithAnalyzer {
-public:
-  AtomicAddVectorizePlanner() = default;
-  int max_vector_size = 1;
-
-  AtomicAddVectorizePlanResult Plan(const For &node, Var thread_var,
-                                    Range thread_bounds, int vectorize_hint) {
-    this->max_vector_size = vectorize_hint;
-    this->thread_var = std::move(thread_var);
-    this->thread_bounds = std::move(thread_bounds);
-    this->operator()(node);
-    return {vector_size_, dynamic_, condition_};
-  }
-
-private:
-  void VisitStmt_(const ForNode *node) final {
-    inner_for_ = node;
-    iter_map_.Set(node->loop_var, Range(node->min, node->extent));
-    arith::IRVisitorWithAnalyzer::VisitStmt_(node);
-  }
-
-  void VisitExpr_(const CallNode *node) final {
-    if (node->op == builtin::call_extern() && node->args.size() >= 2) {
-      if (const auto *func_name = node->args[0].as<StringImmNode>()) {
-        if (func_name->value == "AtomicAdd") {
-          const BufferLoadNode *buffer_load_dst =
-              node->args[1].as<BufferLoadNode>();
-          const BufferLoadNode *buffer_load_src =
-              node->args[2].as<BufferLoadNode>();
-          if (buffer_load_src && buffer_load_src->buffer.defined() &&
-              buffer_load_dst && buffer_load_dst->buffer.defined()) {
-            Buffer dst_buffer = buffer_load_dst->buffer;
-            Array<PrimExpr> indices_dst = buffer_load_dst->indices;
-            UpdateVectorSize(indices_dst, dst_buffer);
-            Buffer src_buffer = buffer_load_src->buffer;
-            Array<PrimExpr> indices_src = buffer_load_src->indices;
-            UpdateVectorSize(indices_src, src_buffer);
-          }
-        }
-      }
-    }
-    return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
-  }
-
-  void UpdateVectorSize(const Array<PrimExpr> &indices, const Buffer &buffer) {
-    if (!inner_for_)
-      return;
-    auto extent_ptr = inner_for_->extent.as<IntImmNode>();
-    if (!extent_ptr)
-      return;
-
-    const DataType &access_type = buffer->dtype;
-    // i // 2, i % 8 can also be vectorized as factor 16
-    // so we should disable this GCD optimization
-    max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
-    auto last_dim = buffer->shape.back();
-    auto mod_set = analyzer_.modular_set(last_dim);
-    // when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
-    if (buffer->shape.back().as<IntImmNode>()) {
-      max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
-      auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
-      if (gcd_base < Downcast<IntImm>(last_dim)->value) {
-        max_vector_size = gcd_base;
-      }
-      vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);
-      PrimExpr elem_offset = 0;
-      PrimExpr stride = 1;
-      for (int i = indices.size() - 1; i >= 0; --i) {
-        elem_offset = elem_offset + indices[i] * stride;
-        stride = stride * buffer->shape[i];
-      }
-      PrimExpr thread_extent = thread_bounds->extent;
-      while (!IndiceCanVectorize(elem_offset, thread_var, thread_extent,
-                                 vector_size_, &analyzer_)) {
-        vector_size_ /= 2;
-      }
-    } else if (vector_size_ <= 4) {
-      // dynamic shape load: get the vectorization condition
-      dynamic_ = true;
-      PrimExpr offset = buffer.OffsetOf(indices).back();
-      condition_ = (truncmod(offset, vector_size_) == 0);
-    }
-  }
-
-  const ForNode *inner_for_;
-  Map<Var, Range> iter_map_;
-  bool has_nonlocal_memory_access_ = false;
-  int vector_size_ = 4;
-  Var thread_var;
-  Range thread_bounds;
-  bool dynamic_ = false;
-  PrimExpr condition_;
-};
+AtomicAddVectorizePlanner::AtomicAddVectorizePlanner() = default;
+
+AtomicAddVectorizePlanResult
+AtomicAddVectorizePlanner::Plan(const For &node, int compute_capability) {
+  int vectorize_size_max = 1;
+  this->vector_size_ = 4;
+  this->dynamic_ = false;
+  this->condition_ = PrimExpr();
+
+  PostOrderVisit(node, [&](const ObjectRef &obj) {
+    if (const auto *call = obj.as<CallNode>()) {
+      if (call->op == builtin::call_extern() && call->args.size() >= 2) {
+        const auto *func_name = call->args[0].as<StringImmNode>();
+        if (!func_name)
+          return;
+        if (func_name->value == "AtomicAdd") {
+          DataType dtype;
+          if (const auto *load = call->args[1].as<BufferLoadNode>()) {
+            dtype = load->dtype;
+            vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
+          } else if (const auto *ite = call->args[1].as<IfThenElseNode>()) {
+            if (const auto *then_load = ite->then_case.as<BufferLoadNode>()) {
+              dtype = then_load->dtype;
+              vectorize_size_max =
+                  GetVectorizeSizeMax(compute_capability, dtype);
+            } else if (const auto *else_load =
+                           ite->else_case.as<BufferLoadNode>()) {
+              dtype = else_load->dtype;
+              vectorize_size_max =
+                  GetVectorizeSizeMax(compute_capability, dtype);
+            } else {
+              // fallback
+              vectorize_size_max = 1;
+              DLOG(WARNING) << "[AtomicAddVectorizePlanner] IfThenElse case "
+                               "has no BufferLoad; Fallback to no vectorize";
+            }
+          } else {
+            // fallback
+            vectorize_size_max = 1;
+            DLOG(WARNING) << "[AtomicAddVectorizePlanner] Unexpected arg1 type "
+                          << call->args[1]->GetTypeKey()
+                          << "; Fallback to no vectorize";
+          }
+        }
+      }
+    }
+  });
+
+  if (vectorize_size_max <= 1) {
+    return {1, dynamic_, condition_};
+  }
+
+  this->max_vector_size = vectorize_size_max;
+  this->operator()(node);
+  return {vector_size_, dynamic_, condition_};
+}
+
+void AtomicAddVectorizePlanner::VisitStmt_(const ForNode *node) {
+  inner_for_ = node;
+  arith::IRVisitorWithAnalyzer::VisitStmt_(node);
+}
+
+void AtomicAddVectorizePlanner::VisitExpr_(const CallNode *node) {
+  if (node->op == builtin::call_extern() && node->args.size() >= 2) {
+    if (const auto *func_name = node->args[0].as<StringImmNode>()) {
+      if (func_name->value == "AtomicAdd") {
+        const BufferLoadNode *buffer_load_dst =
+            node->args[1].as<BufferLoadNode>();
+        const BufferLoadNode *buffer_load_src =
+            node->args[2].as<BufferLoadNode>();
+        if (buffer_load_src && buffer_load_src->buffer.defined() &&
+            buffer_load_dst && buffer_load_dst->buffer.defined()) {
+          Buffer dst_buffer = buffer_load_dst->buffer;
+          UpdateVectorSize(buffer_load_dst->indices, dst_buffer);
+          Buffer src_buffer = buffer_load_src->buffer;
+          UpdateVectorSize(buffer_load_src->indices, src_buffer);
+        }
+      }
+    }
+  }
+  return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
+}
+
+int AtomicAddVectorizePlanner::GetVectorizeSizeMax(int compute_capability,
+                                                   DataType dtype) {
+  if (dtype == DataType::Float(16)) {
+    return 2;
+  }
+  if (dtype == DataType::BFloat(16)) {
+    return compute_capability > 75 ? 2 : 1;
+  }
+  if (dtype == DataType::Float(32)) {
+    return compute_capability >= 90 ? 4 : 1;
+  }
+  return 1;
+}
+
+void AtomicAddVectorizePlanner::UpdateVectorSize(const Array<PrimExpr> &indices,
+                                                 const Buffer &buffer) {
+  if (!inner_for_)
+    return;
+  auto extent_ptr = inner_for_->extent.as<IntImmNode>();
+  if (!extent_ptr)
+    return;
+
+  const DataType &access_type = buffer->dtype;
+  max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
+  auto last_dim = buffer->shape.back();
+  auto mod_set = analyzer_.modular_set(last_dim);
+  // conditionally tail vectorize
+  if (buffer->shape.back().as<IntImmNode>()) {
+    max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
+    auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
+    // If gcd_base is equal to the last dimension,
+    // we should analyze the second-to-last dimension
+    // in relation to the last dimension.
+    if (gcd_base < Downcast<IntImm>(last_dim)->value) {
+      max_vector_size = gcd_base;
+    }
+    vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);
+    PrimExpr elem_offset = 0;
+    PrimExpr stride = 1;
+    for (int i = indices.size() - 1; i >= 0; --i) {
+      elem_offset = elem_offset + indices[i] * stride;
+      stride = stride * buffer->shape[i];
+    }
+    while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
+                               inner_for_->extent, vector_size_, &analyzer_)) {
+      vector_size_ /= 2;
+    }
+  } else if (vector_size_ <= 4) {
+    dynamic_ = true;
+    PrimExpr offset = buffer.OffsetOf(indices).back();
+    condition_ = (truncmod(offset, vector_size_) == 0);
+  }
+}
 
 class AtomicAddVectorizeRewriter : public StmtExprMutator {
 public:
-  AtomicAddVectorizeRewriter(const AtomicAddVectorizePlanResult &plan,
-                             Var thread_var, PrimExpr by_var, PrimExpr bx_var,
-                             const Range &thread_bounds, int stride_y,
-                             int stride_x)
-      : vector_size_(plan.vector_size), condition_(plan.condition),
-        dynamic_(plan.dynamic), tx_var_(std::move(thread_var)),
-        by_var_(std::move(by_var)), bx_var_(std::move(bx_var)),
-        stride_y_(stride_y), stride_x_(stride_x) {
-    const int64_t *tx_ext = as_const_int(thread_bounds->extent);
-    ICHECK(tx_ext)
-        << "thread_bounds->extent must be a constant for vectorization.";
-    extent_tx_ = static_cast<int>(*tx_ext);
-  }
+  AtomicAddVectorizeRewriter(const AtomicAddVectorizePlanResult &plan)
+      : vector_size_(plan.vector_size), dynamic_(plan.dynamic),
+        condition_(plan.condition) {}
 
 private:
   /**
```
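A plain-integer walk-through of the GCD capping that `UpdateVectorSize` performs for a statically shaped buffer may help. `zero_aware_gcd` below mirrors the gcd-with-zero convention of `arith::ZeroAwareGCD`, and the modular-set numbers assume TVM reports a constant last dimension `c` as `{coeff = 0, base = c}`; this is an illustrative sketch, not project code.

```cpp
#include <cassert>
#include <numeric>

// Gcd with zero returns the other operand (illustrative re-implementation).
static int zero_aware_gcd(int a, int b) {
  if (a == 0) return b;
  if (b == 0) return a;
  return std::gcd(a, b);
}

int main() {
  int max_vec = 4;       // dtype/arch upper bound, e.g. fp32 on sm_90
  int inner_extent = 8;  // constant extent of the innermost loop
  max_vec = zero_aware_gcd(max_vec, inner_extent);  // still 4

  // Statically shaped last dimension of 6 elements, assumed modular set
  // {coeff = 0, base = 6}.
  int coeff = 0, base = 6;
  max_vec = zero_aware_gcd(max_vec, coeff);      // gcd(4, 0) = 4
  int gcd_base = zero_aware_gcd(max_vec, base);  // gcd(4, 6) = 2
  if (gcd_base < 6)
    max_vec = gcd_base;
  int vector_size = zero_aware_gcd(max_vec, /*current=*/4);
  assert(vector_size == 2);  // rows of 6 elements cap the width at 2 lanes
  return 0;
}
```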
Inside the rewriter, the innermost loop is now strip-mined with a fresh local variable instead of the `iter_var_` member:

```diff
@@ -179,10 +187,11 @@ private:
   */
  Stmt VisitStmt_(const ForNode *node) final {
    inner_for_ = node;
-   iter_var_ = Var(node->loop_var->name_hint + "_outer");
    auto ret = StmtExprMutator::VisitStmt_(node);
-   if (inner_for_ == node) {
+   if (inner_for_ == node) { // rewrite the innermost loop
      For fnode = ret.as<For>().value();
+     auto old_var = fnode->loop_var;
+     auto new_var = Var(old_var->name_hint);
      auto extent_ptr = as_const_int(fnode->extent);
      ICHECK(extent_ptr) << fnode->extent;
      int extent = *extent_ptr;
@@ -191,9 +200,9 @@ private:
      ICHECK(is_zero(fnode->min));
      if (!dynamic_) {
        Map<Var, PrimExpr> vmap;
-       vmap.Set(fnode->loop_var, iter_var_);
+       vmap.Set(old_var, new_var * vector_size_);
        Stmt body = Substitute(fnode->body, vmap);
-       return For(iter_var_, 0, extent / vector_size_, fnode->kind, body,
+       return For(new_var, 0, extent / vector_size_, fnode->kind, body,
                   fnode->thread_binding, fnode->annotations, fnode->span);
      }
    }
```
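The effect of the rewritten strip-mining (substitute the loop variable with `new_var * vector_size_` and shrink the extent) can be seen with plain integers; this is an illustrative sketch, not project code.

```cpp
#include <cassert>

int main() {
  // Original loop: for (i = 0; i < extent; ++i) issues one scalar AtomicAdd.
  // After the rewrite: for (i_new = 0; i_new < extent / vec; ++i_new) with the
  // substitution i -> i_new * vec, so each iteration addresses one aligned
  // group of vec contiguous elements handled by a single vectorized call.
  const int extent = 16, vec = 4;
  int elements_covered = 0;
  for (int i_new = 0; i_new < extent / vec; ++i_new) {
    int i = i_new * vec;   // base element index seen by the rewritten call
    assert(i % vec == 0);  // always aligned to the vector width
    elements_covered += vec;
  }
  assert(elements_covered == extent);  // nothing dropped, nothing duplicated
  return 0;
}
```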
The `AtomicAdd` call rewrite no longer reconstructs the destination and value indices from the bx/by block offsets and strides; it simply wraps the original operands:

```diff
@@ -208,57 +217,18 @@ private:
    if (node->op == builtin::call_extern() && node->args.size() >= 2) {
      if (const auto *func_name = node->args[0].as<StringImmNode>()) {
        if (func_name->value == "AtomicAdd") {
-         // Matrix[by * stride_y + i / (stride_x / (tx_txtent *
-         // vector_size_)) + tx_var_ / (stride_x / vector_size_),
-         // bx * stride_x + (i % (stride_x / (tx_extent *
-         // vector_size_)) * (tx_extent * vector_size_) + (tx_var_ %
-         // (stride / vector_size_)) * vector_size_]
-         const BufferLoadNode *old_dst_node =
+         const BufferLoadNode *temp_dst_node =
              node->args[1].as<BufferLoadNode>();
-         const BufferLoadNode *old_value_node =
+         const BufferLoadNode *temp_value_node =
              node->args[2].as<BufferLoadNode>();
-         if (!old_dst_node || !old_value_node) {
+         if (!temp_dst_node || !temp_value_node) {
            return StmtExprMutator::VisitExpr_(node);
          }
-         Array<PrimExpr> dst_indices, value_indices;
-         if ((extent_tx_ * vector_size_) > stride_x_) {
-           dst_indices.push_back(
-               by_var_ * stride_y_ +
-               iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
-               truncdiv(tx_var_, stride_x_ / vector_size_));
-           dst_indices.push_back(bx_var_ * stride_x_ +
-                                 truncmod(tx_var_, stride_x_ / vector_size_) *
-                                     vector_size_);
-           value_indices.push_back(
-               iter_var_ * (extent_tx_ * vector_size_ / stride_x_) +
-               truncdiv(tx_var_ * vector_size_, stride_x_));
-           value_indices.push_back(
-               truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
-         } else {
-           dst_indices.push_back(
-               by_var_ * stride_y_ +
-               truncdiv(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
-               truncdiv(tx_var_, stride_x_ / vector_size_));
-           dst_indices.push_back(
-               bx_var_ * stride_x_ +
-               truncmod(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
-                   (extent_tx_ * vector_size_) +
-               truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
-           value_indices.push_back(
-               truncdiv(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) +
-               truncdiv(tx_var_, stride_x_ / vector_size_));
-           value_indices.push_back(
-               truncmod(iter_var_, stride_x_ / (extent_tx_ * vector_size_)) *
-                   (extent_tx_ * vector_size_) +
-               truncmod(tx_var_, stride_x_ / vector_size_) * vector_size_);
-         }
-         BufferLoad dst_node =
-             BufferLoad(old_dst_node->buffer, dst_indices,
-                        old_dst_node->predicate, old_dst_node->span);
-         BufferLoad value_node =
-             BufferLoad(old_value_node->buffer, value_indices,
-                        old_value_node->predicate, old_value_node->span);
+         const BufferLoad dst_node = Downcast<BufferLoad>(node->args[1]);
+         const BufferLoad value_node = Downcast<BufferLoad>(node->args[2]);
          Call address_of_dst =
              Call(DataType::Handle(), builtin::address_of(), {dst_node});
          Call address_of_value =
```
Finally, the rewriter drops the bx/by/stride bookkeeping, `GetVectorizeSizeMax` becomes a planner method (see the earlier hunk), and `VectorizeAtomicAdd` no longer needs the thread variable, the thread bounds, or the name_hint-based search for "bx"/"by" offsets — presumably the "identify var" error named in the commit title:

```diff
@@ -287,89 +257,17 @@ private:
   const int vector_size_;
   const PrimExpr condition_;
   const bool dynamic_;
-  const PrimExpr by_var_, bx_var_;
-  int stride_y_, stride_x_;
-  const Var tx_var_;
-  Var iter_var_;
-  int extent_tx_;
 };
 
-static int GetVectorizeSizeMax(int compute_capability, DataType dtype) {
-  if (dtype == DataType::Float(16)) {
-    return 2;
-  }
-  if (dtype == DataType::BFloat(16)) {
-    if (compute_capability > 75) {
-      return 2;
-    } else {
-      return 1;
-    }
-  }
-  if (dtype == DataType::Float(32)) {
-    if (compute_capability >= 90) {
-      return 4;
-    } else {
-      return 1;
-    }
-  }
-  return 1;
-}
-
-For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
-                       const Range &thread_bounds, int compute_capability) {
-  int vectorize_size_max = 1;
-  int stride_x = -1, stride_y = -1;
-  PrimExpr bx_var, by_var;
-
-  PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
-    if (const auto *call = obj.as<CallNode>()) {
-      if (call->op == builtin::call_extern() && call->args.size() >= 2) {
-        const auto *func_name = call->args[0].as<StringImmNode>();
-        if (func_name->value == "AtomicAdd") {
-          DataType dtype = call->args[1].as<BufferLoadNode>()->dtype;
-          vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype);
-        }
-      }
-    }
-    if (const MulNode *mul = obj.as<MulNode>()) {
-      const VarNode *var = nullptr;
-      const IntImmNode *imm = nullptr;
-      PrimExpr var_expr;
-      if ((var = mul->a.as<VarNode>()) && (imm = mul->b.as<IntImmNode>())) {
-        var_expr = mul->a;
-      } else if ((var = mul->b.as<VarNode>()) &&
-                 (imm = mul->a.as<IntImmNode>())) {
-        var_expr = mul->b;
-      }
-      if (var && imm) {
-        if (var->name_hint == "bx") {
-          stride_x = imm->value;
-          bx_var = var_expr;
-        } else if (var->name_hint == "by") {
-          stride_y = imm->value;
-          by_var = var_expr;
-        }
-      }
-    }
-  });
-
-  if (vectorize_size_max != 1) {
-    int vectorize_hint = vectorize_size_max;
-    AtomicAddVectorizePlanResult res = {1, false, 0};
-    AtomicAddVectorizePlanner planner;
-    res = planner.Plan(for_node, thread_var, thread_bounds, vectorize_hint);
-    vectorize_hint = res.vector_size;
-    if (vectorize_hint == 1 || stride_x == -1 || stride_y == -1 ||
-        !bx_var.defined() || !by_var.defined())
-      return for_node;
-    auto rewriter = AtomicAddVectorizeRewriter(
-        res, thread_var, by_var, bx_var, thread_bounds, stride_y, stride_x);
-    return Downcast<For>(rewriter(for_node));
-  } else {
-    return for_node;
-  }
-}
+For VectorizeAtomicAdd(const For &for_node, int compute_capability) {
+  AtomicAddVectorizePlanResult res = {1, false, 0};
+  AtomicAddVectorizePlanner planner;
+  res = planner.Plan(for_node, compute_capability);
+  int vectorize_hint = res.vector_size;
+  if (vectorize_hint == 1)
+    return for_node;
+  auto rewriter = AtomicAddVectorizeRewriter(res);
+  return Downcast<For>(rewriter(for_node));
+}
 
 } // namespace tl
```
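The dtype/architecture table encoded in `GetVectorizeSizeMax` boils down to the following standalone mirror (`vectorize_size_max` here is an illustrative re-statement, not the planner method itself; 1 means "do not vectorize"):

```cpp
#include <cassert>

// Widest atomic-add the planner will attempt, per element type and
// compute capability (illustrative mirror of GetVectorizeSizeMax).
static int vectorize_size_max(int compute_capability, int bits, bool bfloat) {
  if (bits == 16 && !bfloat)
    return 2;                                 // float16
  if (bits == 16 && bfloat)
    return compute_capability > 75 ? 2 : 1;   // bfloat16
  if (bits == 32 && !bfloat)
    return compute_capability >= 90 ? 4 : 1;  // float32
  return 1;                                   // anything else stays scalar
}

int main() {
  assert(vectorize_size_max(80, 16, false) == 2);  // fp16 on sm_80
  assert(vectorize_size_max(75, 16, true) == 1);   // bf16 needs newer than sm_75
  assert(vectorize_size_max(90, 32, false) == 4);  // fp32 x4 only from sm_90
  assert(vectorize_size_max(86, 32, false) == 1);
  return 0;
}
```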
src/transform/atomicadd_vectorize.h (+39 −2)

The header now carries the includes and the planner declarations shared by the implementation file and atomic_add.cc, and `VectorizeAtomicAdd` loses the `thread_var`/`thread_bounds` parameters:

```diff
@@ -6,16 +6,53 @@
 #ifndef TVM_TL_ATOMICADD_VECTORIZE_H_
 #define TVM_TL_ATOMICADD_VECTORIZE_H_
 
+#include "../layout/layout.h"
+#include "../layout/utils.h"
+#include "arith/int_operator.h"
+#include "arith/ir_visitor_with_analyzer.h"
+#include "atomicadd_vectorize.h"
+#include "common/loop_vectorization_utils.h"
+#include <numeric>
 #include <tvm/arith/analyzer.h>
+#include <tvm/arith/iter_affine_map.h>
+#include <tvm/tir/builtin.h>
 #include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+#include <utility>
 
 namespace tvm {
 namespace tl {
 
 using namespace tir;
 
-For VectorizeAtomicAdd(const For &for_node, const Var &thread_var,
-                       const Range &thread_bounds, int compute_capability);
+For VectorizeAtomicAdd(const For &for_node, int compute_capability);
+
+struct AtomicAddVectorizePlanResult {
+  int vector_size;
+  bool dynamic;
+  PrimExpr condition;
+};
+
+class AtomicAddVectorizePlanner : public arith::IRVisitorWithAnalyzer {
+public:
+  AtomicAddVectorizePlanner();
+  AtomicAddVectorizePlanResult Plan(const For &node, int compute_capability);
+
+private:
+  void VisitStmt_(const ForNode *node) final;
+  void VisitExpr_(const CallNode *node) final;
+  int GetVectorizeSizeMax(int compute_capability, DataType dtype);
+  void UpdateVectorSize(const Array<PrimExpr> &indices, const Buffer &buffer);
+
+  const ForNode *inner_for_ = nullptr;
+  bool has_nonlocal_memory_access_ = false;
+  int vector_size_ = 4;
+  int max_vector_size = 1;
+  bool dynamic_ = false;
+  PrimExpr condition_;
+};
 
 } // namespace tl
 } // namespace tvm
```