Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
549416f7
Commit
549416f7
authored
Jan 11, 2025
by
LeiWang1999
Browse files
Merge branch 'main' of
https://github.com/microsoft/TileLang
into main
parents
4d63633a
7fad4e88
Changes
90
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
853 additions
and
711 deletions
+853
-711
src/transform/legalize_vectorized_loop.cc
src/transform/legalize_vectorized_loop.cc
+9
-8
src/transform/loop_partition.cc
src/transform/loop_partition.cc
+29
-24
src/transform/loop_partition.h
src/transform/loop_partition.h
+5
-4
src/transform/loop_vectorize.cc
src/transform/loop_vectorize.cc
+60
-48
src/transform/loop_vectorize.h
src/transform/loop_vectorize.h
+7
-7
src/transform/lower_hopper_intrin.cc
src/transform/lower_hopper_intrin.cc
+41
-30
src/transform/lower_tile_op.cc
src/transform/lower_tile_op.cc
+66
-45
src/transform/multi_version_buffer_rewriter.cc
src/transform/multi_version_buffer_rewriter.cc
+74
-59
src/transform/pipeline_planning.cc
src/transform/pipeline_planning.cc
+77
-55
src/transform/simplify.cc
src/transform/simplify.cc
+139
-127
src/transform/thread_partial_sync.cc
src/transform/thread_partial_sync.cc
+69
-56
src/transform/warp_specialized_rewriter.cc
src/transform/warp_specialized_rewriter.cc
+218
-151
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
+0
-1
testing/python/kernel/test_tilelang_gemm.py
testing/python/kernel/test_tilelang_gemm.py
+6
-12
testing/python/kernel/test_tilelang_gemm_mma_intrinsic.py
testing/python/kernel/test_tilelang_gemm_mma_intrinsic.py
+1
-2
testing/python/primitives/test_tilelang_primitives_mma.py
testing/python/primitives/test_tilelang_primitives_mma.py
+13
-18
tilelang/intrinsics/__init__.py
tilelang/intrinsics/__init__.py
+1
-2
tilelang/intrinsics/mma_macro_generator.py
tilelang/intrinsics/mma_macro_generator.py
+31
-47
tilelang/language/allocate.py
tilelang/language/allocate.py
+1
-0
tilelang/language/copy.py
tilelang/language/copy.py
+6
-15
No files found.
src/transform/legalize_vectorized_loop.cc
View file @
549416f7
...
...
@@ -30,8 +30,8 @@
#include <queue>
#include "arith/ir_mutator_with_analyzer.h"
#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
...
...
@@ -43,25 +43,26 @@ using arith::IRMutatorWithAnalyzer;
// Class to legalize vectorized loops by transforming them appropriately
class
LoopVectorizedLegalizer
:
IRMutatorWithAnalyzer
{
public:
public:
// Static method to substitute and transform the given PrimFunc
static
PrimFunc
Substitute
(
PrimFunc
f
)
{
arith
::
Analyzer
analyzer
;
// Create an instance of the legalizer with the analyzer
LoopVectorizedLegalizer
substituter
(
&
analyzer
);
// Get a mutable copy of the function node
PrimFuncNode
*
fptr
=
f
.
CopyOnWrite
();
PrimFuncNode
*
fptr
=
f
.
CopyOnWrite
();
// Apply the legalizer to the function body
fptr
->
body
=
substituter
.
VisitStmt
(
f
->
body
);
return
f
;
}
private:
private:
// Constructor initializing the base class with the analyzer
LoopVectorizedLegalizer
(
arith
::
Analyzer
*
analyzer
)
:
arith
::
IRMutatorWithAnalyzer
(
analyzer
)
{}
LoopVectorizedLegalizer
(
arith
::
Analyzer
*
analyzer
)
:
arith
::
IRMutatorWithAnalyzer
(
analyzer
)
{}
// Override the VisitStmt_ method to handle ForNode (loop statements)
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
// Visit and potentially modify the loop node
For
for_node
=
Downcast
<
For
>
(
IRMutatorWithAnalyzer
::
VisitStmt_
(
op
));
// If the loop is not vectorized, proceed with the default behavior
...
...
@@ -90,5 +91,5 @@ tvm::transform::Pass LegalizeVectorizedLoop() {
TVM_REGISTER_GLOBAL
(
"tl.transform.LegalizeVectorizedLoop"
)
.
set_body_typed
(
LegalizeVectorizedLoop
);
}
// namespace tl
}
// namespace tvm
}
// namespace tl
}
// namespace tvm
src/transform/loop_partition.cc
View file @
549416f7
...
...
@@ -32,29 +32,32 @@ namespace tl {
using
namespace
tir
;
class
BufferIndiceSimplify
:
public
StmtExprMutator
{
public:
BufferIndiceSimplify
(
arith
::
Analyzer
*
analyzer
)
:
analyzer_
(
analyzer
)
{}
public:
BufferIndiceSimplify
(
arith
::
Analyzer
*
analyzer
)
:
analyzer_
(
analyzer
)
{}
private:
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
node
)
final
{
private:
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
node
)
final
{
auto
visited
=
StmtExprMutator
::
VisitExpr_
(
node
);
auto
n
=
visited
.
as
<
BufferLoad
>
().
value
();
auto
nptr
=
n
.
CopyOnWrite
();
nptr
->
indices
=
nptr
->
indices
.
Map
([
&
](
const
auto
&
e
)
{
return
analyzer_
->
Simplify
(
e
);
});
nptr
->
indices
=
nptr
->
indices
.
Map
(
[
&
](
const
auto
&
e
)
{
return
analyzer_
->
Simplify
(
e
);
});
return
n
;
}
Stmt
VisitStmt_
(
const
BufferStoreNode
*
node
)
final
{
Stmt
VisitStmt_
(
const
BufferStoreNode
*
node
)
final
{
auto
visited
=
StmtExprMutator
::
VisitStmt_
(
node
);
auto
n
=
visited
.
as
<
BufferStore
>
().
value
();
auto
nptr
=
n
.
CopyOnWrite
();
nptr
->
indices
=
nptr
->
indices
.
Map
([
&
](
const
auto
&
e
)
{
return
analyzer_
->
Simplify
(
e
);
});
nptr
->
indices
=
nptr
->
indices
.
Map
(
[
&
](
const
auto
&
e
)
{
return
analyzer_
->
Simplify
(
e
);
});
return
n
;
}
arith
::
Analyzer
*
analyzer_
;
arith
::
Analyzer
*
analyzer_
;
};
// Rewrite the parallel loop into a common loop, which is mapped to threads
For
PartitionLoop
(
For
op
,
Var
thread_var
,
arith
::
Analyzer
*
analyzer
,
Fragment
loop_layout
)
{
For
PartitionLoop
(
For
op
,
Var
thread_var
,
arith
::
Analyzer
*
analyzer
,
Fragment
loop_layout
)
{
ICHECK
(
loop_layout
.
defined
());
ICHECK
(
thread_var
.
defined
());
int
old_loop_depth
=
loop_layout
->
InputDim
();
...
...
@@ -71,7 +74,8 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer* analyzer, Fragment lo
Map
<
Var
,
PrimExpr
>
vmap
;
Stmt
body
=
op
;
auto
inv_loop
=
loop_layout
->
Inverse
();
auto
indices
=
inv_loop
->
Forward
(
vars
.
Map
([](
const
Var
&
v
)
{
return
PrimExpr
(
v
);
}));
auto
indices
=
inv_loop
->
Forward
(
vars
.
Map
([](
const
Var
&
v
)
{
return
PrimExpr
(
v
);
}));
for
(
int
i
=
0
;
i
<
old_loop_depth
;
i
++
)
{
ICHECK
(
body
.
as
<
For
>
().
defined
());
For
loop
=
body
.
as
<
For
>
().
value
();
...
...
@@ -82,8 +86,8 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer* analyzer, Fragment lo
// substitute and re-construct the serial loop
body
=
Substitute
(
body
,
vmap
);
for
(
int
i
=
new_loop_depth
-
1
;
i
>=
0
;
i
--
)
{
body
=
For
(
vars
[
i
],
make_zero
(
vars
[
i
]
->
dtype
),
inv_loop
->
InputShape
()[
i
],
ForKind
::
kSerial
,
body
);
body
=
For
(
vars
[
i
],
make_zero
(
vars
[
i
]
->
dtype
),
inv_loop
->
InputShape
()[
i
],
ForKind
::
kSerial
,
body
);
analyzer
->
Bind
(
vars
[
i
],
Range
(
0
,
inv_loop
->
InputShape
()[
i
]));
}
...
...
@@ -95,11 +99,11 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer* analyzer, Fragment lo
}
class
LoopPramaUnroller
:
public
StmtExprMutator
{
public:
public:
LoopPramaUnroller
()
=
default
;
private:
Stmt
VisitStmt_
(
const
ForNode
*
node
)
final
{
private:
Stmt
VisitStmt_
(
const
ForNode
*
node
)
final
{
if
(
node
->
kind
==
ForKind
::
kSerial
)
{
For
new_for
=
GetRef
<
For
>
(
node
);
auto
for_ptr
=
new_for
.
CopyOnWrite
();
...
...
@@ -112,7 +116,7 @@ class LoopPramaUnroller : public StmtExprMutator {
};
class
LoopPartitioner
:
public
StmtExprVisitor
{
public:
public:
LoopPartitioner
()
=
default
;
Fragment
Partition
(
For
op
,
int
num_thread
,
int
vectorize_size
)
{
...
...
@@ -129,17 +133,18 @@ class LoopPartitioner : public StmtExprVisitor {
ICHECK
(
loop_size_full
%
vectorize_size
==
0
);
PrimExpr
access_idx
=
FloorDiv
(
flattened
,
vectorize_size
);
PrimExpr
thd
=
FloorMod
(
access_idx
,
num_thread
);
PrimExpr
idx
=
FloorDiv
(
access_idx
,
num_thread
)
*
vectorize_size
+
FloorMod
(
flattened
,
vectorize_size
);
PrimExpr
idx
=
FloorDiv
(
access_idx
,
num_thread
)
*
vectorize_size
+
FloorMod
(
flattened
,
vectorize_size
);
return
Fragment
(
loop_vars_
,
{
idx
},
{
thd
},
{});
}
private:
void
VisitStmt_
(
const
ForNode
*
node
)
final
{
private:
void
VisitStmt_
(
const
ForNode
*
node
)
final
{
if
(
node
->
kind
==
ForKind
::
kParallel
)
{
body_
=
node
->
body
;
loop_vars_
.
push_back
(
IterVar
(
Range
::
FromMinExtent
(
node
->
min
,
node
->
extent
),
node
->
loop_var
,
IterVarType
::
kDataPar
));
loop_vars_
.
push_back
(
IterVar
(
Range
::
FromMinExtent
(
node
->
min
,
node
->
extent
),
node
->
loop_var
,
IterVarType
::
kDataPar
));
}
StmtExprVisitor
::
VisitStmt_
(
node
);
}
...
...
@@ -160,5 +165,5 @@ For LoopPragmaUnroll(For stmt) {
return
unrolled
;
}
}
// namespace tl
}
// namespace tvm
}
// namespace tl
}
// namespace tvm
src/transform/loop_partition.h
View file @
549416f7
...
...
@@ -36,13 +36,14 @@ namespace tl {
using
namespace
tir
;
For
PartitionLoop
(
For
op
,
Var
thread_var
,
arith
::
Analyzer
*
analyzer
,
Fragment
loop_layout
);
For
PartitionLoop
(
For
op
,
Var
thread_var
,
arith
::
Analyzer
*
analyzer
,
Fragment
loop_layout
);
Fragment
PlanLoopPartition
(
For
op
,
size_t
num_thread
,
int
vectorize_size
);
For
LoopPragmaUnroll
(
For
stmt
);
}
// namespace tl
}
// namespace tvm
}
// namespace tl
}
// namespace tvm
#endif
// TVM_TL_LOOP_PARTITION_H_
#endif // TVM_TL_LOOP_PARTITION_H_
src/transform/loop_vectorize.cc
View file @
549416f7
...
...
@@ -30,10 +30,10 @@
#include <numeric>
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "../layout/layout.h"
#include "../layout/utils.h"
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_vectorization_utils.h"
namespace
tvm
{
...
...
@@ -48,10 +48,10 @@ struct VectorizePlanResult {
};
class
VectorizePlanner
:
public
arith
::
IRVisitorWithAnalyzer
{
public:
public:
VectorizePlanner
()
=
default
;
int
Plan
(
const
For
&
node
)
{
int
Plan
(
const
For
&
node
)
{
this
->
operator
()(
node
);
// Always Enable vectorization
// if (!has_nonlocal_memory_access_) return 1;
...
...
@@ -62,18 +62,19 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
PrimExpr
GetCondition
()
{
return
condition_
;
}
private:
void
VisitStmt_
(
const
ForNode
*
node
)
final
{
private:
void
VisitStmt_
(
const
ForNode
*
node
)
final
{
inner_for_
=
node
;
iter_map_
.
Set
(
node
->
loop_var
,
Range
(
node
->
min
,
node
->
extent
));
arith
::
IRVisitorWithAnalyzer
::
VisitStmt_
(
node
);
}
void
VisitExpr_
(
const
BufferLoadNode
*
node
)
final
{
void
VisitExpr_
(
const
BufferLoadNode
*
node
)
final
{
if
(
node
->
buffer
.
scope
()
==
"shared"
||
node
->
buffer
.
scope
()
==
"global"
||
node
->
buffer
.
scope
()
==
"shared.dyn"
)
has_nonlocal_memory_access_
=
true
;
if
(
node
->
buffer
->
shape
.
size
()
==
1
&&
node
->
buffer
->
shape
[
0
].
as
<
IntImmNode
>
()
->
value
==
1
)
{
if
(
node
->
buffer
->
shape
.
size
()
==
1
&&
node
->
buffer
->
shape
[
0
].
as
<
IntImmNode
>
()
->
value
==
1
)
{
// TODO(lei): This should be improved as
// constant buffer that tl hack to use as local register.
return
arith
::
IRVisitorWithAnalyzer
::
VisitExpr_
(
node
);
...
...
@@ -82,7 +83,7 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
return
arith
::
IRVisitorWithAnalyzer
::
VisitExpr_
(
node
);
}
void
VisitStmt_
(
const
BufferStoreNode
*
node
)
final
{
void
VisitStmt_
(
const
BufferStoreNode
*
node
)
final
{
if
(
node
->
buffer
.
scope
()
==
"shared"
||
node
->
buffer
.
scope
()
==
"global"
||
node
->
buffer
.
scope
()
==
"shared.dyn"
)
has_nonlocal_memory_access_
=
true
;
...
...
@@ -90,12 +91,12 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
return
arith
::
IRVisitorWithAnalyzer
::
VisitStmt_
(
node
);
}
void
VisitStmt_
(
const
IfThenElseNode
*
node
)
final
{
void
VisitStmt_
(
const
IfThenElseNode
*
node
)
final
{
CheckConditionVectorized
(
node
->
condition
);
return
arith
::
IRVisitorWithAnalyzer
::
VisitStmt_
(
node
);
}
void
VisitExpr_
(
const
CallNode
*
node
)
final
{
void
VisitExpr_
(
const
CallNode
*
node
)
final
{
if
(
node
->
op
==
builtin
::
if_then_else
())
{
CheckConditionVectorized
(
node
->
args
[
0
]);
}
else
if
(
node
->
op
==
builtin
::
call_extern
())
{
...
...
@@ -105,16 +106,18 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
return
arith
::
IRVisitorWithAnalyzer
::
VisitExpr_
(
node
);
}
void
CheckConditionVectorized
(
const
PrimExpr
&
cond
)
{
void
CheckConditionVectorized
(
const
PrimExpr
&
cond
)
{
// TODO: perform some checks here
}
void
UpdateVectorSize
(
const
Array
<
PrimExpr
>
indices
,
const
Buffer
&
buffer
)
{
if
(
!
inner_for_
)
return
;
void
UpdateVectorSize
(
const
Array
<
PrimExpr
>
indices
,
const
Buffer
&
buffer
)
{
if
(
!
inner_for_
)
return
;
auto
extent_ptr
=
inner_for_
->
extent
.
as
<
IntImmNode
>
();
if
(
!
extent_ptr
)
return
;
if
(
!
extent_ptr
)
return
;
const
DataType
&
access_type
=
buffer
->
dtype
;
const
DataType
&
access_type
=
buffer
->
dtype
;
// i // 2, i % 8 can also be vectorized as factor 16
int
max_vector_size
=
128
/
access_type
.
bits
();
// so we should disable this GCD optimization
...
...
@@ -122,7 +125,8 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
auto
last_dim
=
buffer
->
shape
.
back
();
auto
mod_set
=
analyzer_
.
modular_set
(
last_dim
);
// when dynamic shape like [m, k]: coeff=1, base=0, GCD will block conditionally tail vectorize
// when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
// conditionally tail vectorize
if
(
buffer
->
shape
.
back
().
as
<
IntImmNode
>
())
{
max_vector_size
=
arith
::
ZeroAwareGCD
(
max_vector_size
,
mod_set
->
coeff
);
...
...
@@ -142,8 +146,9 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
elem_offset
=
elem_offset
+
indices
[
i
]
*
stride
;
stride
=
stride
*
buffer
->
shape
[
i
];
}
while
(
!
IndiceCanVectorize
(
elem_offset
,
inner_for_
->
loop_var
,
inner_for_
->
extent
,
vector_size_
,
&
analyzer_
))
{
while
(
!
IndiceCanVectorize
(
elem_offset
,
inner_for_
->
loop_var
,
inner_for_
->
extent
,
vector_size_
,
&
analyzer_
))
{
vector_size_
/=
2
;
}
}
else
if
(
vector_size_
<=
vector_load_bits_max_
/
buffer
->
dtype
.
bits
())
{
...
...
@@ -156,7 +161,7 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
static
const
int
vector_load_bits_max_
=
128
;
const
ForNode
*
inner_for_
;
const
ForNode
*
inner_for_
;
Map
<
Var
,
Range
>
iter_map_
;
bool
has_nonlocal_memory_access_
=
false
;
int
vector_size_
=
128
;
...
...
@@ -166,12 +171,12 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
};
class
VectorizeDynamicCallRemover
:
public
StmtExprMutator
{
public:
public:
VectorizeDynamicCallRemover
(
Var
inner_var
,
int
vector_size
)
:
inner_var_
(
inner_var
),
vector_size_
(
vector_size
)
{}
private:
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
final
{
private:
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
final
{
if
(
op
->
op
.
same_as
(
builtin
::
if_then_else
()))
{
PrimExpr
cond
=
this
->
VisitExpr
(
op
->
args
[
0
]);
Map
<
Var
,
PrimExpr
>
vmap
;
...
...
@@ -191,15 +196,16 @@ class VectorizeDynamicCallRemover : public StmtExprMutator {
};
class
VectorizeRewriter
:
public
StmtExprMutator
{
public:
public:
VectorizeRewriter
(
VectorizePlanResult
plan
)
:
vector_size_
(
plan
.
vector_size
),
condition_
(
plan
.
condition
),
dynamic_
(
plan
.
dynamic
)
{}
:
vector_size_
(
plan
.
vector_size
),
condition_
(
plan
.
condition
),
dynamic_
(
plan
.
dynamic
)
{}
private:
Stmt
VisitStmt_
(
const
ForNode
*
node
)
final
{
private:
Stmt
VisitStmt_
(
const
ForNode
*
node
)
final
{
inner_for_
=
node
;
auto
ret
=
StmtExprMutator
::
VisitStmt_
(
node
);
if
(
inner_for_
==
node
)
{
// rewrite the innermost loop
if
(
inner_for_
==
node
)
{
// rewrite the innermost loop
For
fnode
=
ret
.
as
<
For
>
().
value
();
auto
old_var
=
fnode
->
loop_var
;
auto
extent_ptr
=
as_const_int
(
fnode
->
extent
);
...
...
@@ -208,7 +214,7 @@ class VectorizeRewriter : public StmtExprMutator {
ICHECK
(
extent
%
vector_size_
==
0
)
<<
"extent: "
<<
extent
<<
" vector_size_: "
<<
vector_size_
;
ICHECK
(
is_zero
(
fnode
->
min
));
if
(
!
dynamic_
)
{
// check dynamic shape
if
(
!
dynamic_
)
{
// check dynamic shape
if
(
extent
==
vector_size_
)
{
fnode
.
CopyOnWrite
()
->
kind
=
ForKind
::
kVectorized
;
return
fnode
;
...
...
@@ -219,8 +225,8 @@ class VectorizeRewriter : public StmtExprMutator {
vmap
.
Set
(
fnode
->
loop_var
,
outer_var
*
vector_size_
+
inner_var
);
Stmt
body
=
Substitute
(
fnode
->
body
,
vmap
);
body
=
For
(
inner_var
,
0
,
vector_size_
,
ForKind
::
kVectorized
,
body
);
body
=
For
(
outer_var
,
0
,
extent
/
vector_size_
,
fnode
->
kind
,
body
,
fnode
->
thread_binding
,
fnode
->
annotations
,
fnode
->
span
);
body
=
For
(
outer_var
,
0
,
extent
/
vector_size_
,
fnode
->
kind
,
body
,
fnode
->
thread_binding
,
fnode
->
annotations
,
fnode
->
span
);
return
body
;
}
}
else
{
...
...
@@ -237,11 +243,13 @@ class VectorizeRewriter : public StmtExprMutator {
VectorizeDynamicCallRemover
remover
(
inner_var
,
vector_size_
);
body
=
remover
(
body
);
For
vectorize_for
=
For
(
inner_var
,
0
,
vector_size_
,
ForKind
::
kVectorized
,
body
);
For
serial_for
=
For
(
inner_var
,
0
,
vector_size_
,
ForKind
::
kSerial
,
body
);
For
vectorize_for
=
For
(
inner_var
,
0
,
vector_size_
,
ForKind
::
kVectorized
,
body
);
For
serial_for
=
For
(
inner_var
,
0
,
vector_size_
,
ForKind
::
kSerial
,
body
);
body
=
IfThenElse
(
condition
,
vectorize_for
,
serial_for
);
body
=
For
(
outer_var
,
0
,
extent
/
vector_size_
,
fnode
->
kind
,
body
,
fnode
->
thread_binding
,
fnode
->
annotations
,
fnode
->
span
);
body
=
For
(
outer_var
,
0
,
extent
/
vector_size_
,
fnode
->
kind
,
body
,
fnode
->
thread_binding
,
fnode
->
annotations
,
fnode
->
span
);
return
body
;
}
}
else
{
...
...
@@ -249,15 +257,15 @@ class VectorizeRewriter : public StmtExprMutator {
}
}
const
ForNode
*
inner_for_
;
const
ForNode
*
inner_for_
;
const
int
vector_size_
;
const
PrimExpr
condition_
;
const
bool
dynamic_
;
};
int
GetVectorizeSize
(
const
For
&
loop
)
{
return
VectorizePlanner
().
Plan
(
loop
);
}
int
GetVectorizeSize
(
const
For
&
loop
)
{
return
VectorizePlanner
().
Plan
(
loop
);
}
VectorizePlanResult
GetVectorizePlanResult
(
const
For
&
loop
)
{
VectorizePlanResult
GetVectorizePlanResult
(
const
For
&
loop
)
{
VectorizePlanner
planner
;
int
vector_size
=
planner
.
Plan
(
loop
);
bool
dynamic
=
planner
.
GetDynamic
();
...
...
@@ -265,16 +273,19 @@ VectorizePlanResult GetVectorizePlanResult(const For& loop) {
return
{
vector_size
,
dynamic
,
condition
};
}
bool
IndiceCanVectorize
(
PrimExpr
expr
,
Var
var
,
PrimExpr
iter_var_size
,
int
target_vectorized_size
,
arith
::
Analyzer
*
analyzer
)
{
bool
IndiceCanVectorize
(
PrimExpr
expr
,
Var
var
,
PrimExpr
iter_var_size
,
int
target_vectorized_size
,
arith
::
Analyzer
*
analyzer
)
{
ICHECK
(
target_vectorized_size
>=
1
);
if
(
target_vectorized_size
==
1
)
return
true
;
if
(
!
analyzer
->
CanProveEqual
(
FloorMod
(
iter_var_size
,
target_vectorized_size
),
0
))
return
false
;
if
(
target_vectorized_size
==
1
)
return
true
;
if
(
!
analyzer
->
CanProveEqual
(
FloorMod
(
iter_var_size
,
target_vectorized_size
),
0
))
return
false
;
Var
v0
(
"v0"
),
v1
(
"v1"
);
analyzer
->
Bind
(
v0
,
Range
(
0
,
target_vectorized_size
));
analyzer
->
Bind
(
v1
,
Range
(
0
,
FloorDiv
(
iter_var_size
,
target_vectorized_size
)));
PrimExpr
expr_transformed
=
analyzer
->
Simplify
(
Substitute
(
expr
,
{{
var
,
v0
+
v1
*
target_vectorized_size
}}));
PrimExpr
expr_transformed
=
analyzer
->
Simplify
(
Substitute
(
expr
,
{{
var
,
v0
+
v1
*
target_vectorized_size
}}));
Vectorizer
vectorizer
(
v0
,
IntImm
(
v0
->
dtype
,
target_vectorized_size
));
PrimExpr
expr_vectorized
=
vectorizer
.
VisitExpr
(
expr_transformed
);
...
...
@@ -290,16 +301,17 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, int targ
}
}
For
VectorizeLoop
(
const
For
&
loop
,
int
vectorize_hint
)
{
For
VectorizeLoop
(
const
For
&
loop
,
int
vectorize_hint
)
{
VectorizePlanResult
res
{
128
,
false
,
0
};
if
(
vectorize_hint
<=
0
)
{
res
=
GetVectorizePlanResult
(
loop
);
vectorize_hint
=
res
.
vector_size
;
}
if
(
vectorize_hint
==
1
)
return
loop
;
if
(
vectorize_hint
==
1
)
return
loop
;
auto
rewriter
=
VectorizeRewriter
(
res
);
return
Downcast
<
For
>
(
rewriter
(
loop
));
}
}
// namespace tl
}
// namespace tvm
}
// namespace tl
}
// namespace tvm
src/transform/loop_vectorize.h
View file @
549416f7
...
...
@@ -35,13 +35,13 @@ namespace tl {
using
namespace
tir
;
int
GetVectorizeSize
(
const
For
&
loop
);
For
VectorizeLoop
(
const
For
&
loop
,
int
vectorize_hint
=
-
1
);
int
GetVectorizeSize
(
const
For
&
loop
);
For
VectorizeLoop
(
const
For
&
loop
,
int
vectorize_hint
=
-
1
);
bool
IndiceCanVectorize
(
PrimExpr
expr
,
Var
var
,
PrimExpr
iter_var_size
,
int
target_vectorized_size
,
arith
::
Analyzer
*
analyzer
);
bool
IndiceCanVectorize
(
PrimExpr
expr
,
Var
var
,
PrimExpr
iter_var_size
,
int
target_vectorized_size
,
arith
::
Analyzer
*
analyzer
);
}
// namespace tl
}
// namespace tvm
}
// namespace tl
}
// namespace tvm
#endif
// TVM_TL_LOOP_VECTORIZE_H_
#endif // TVM_TL_LOOP_VECTORIZE_H_
src/transform/lower_hopper_intrin.cc
View file @
549416f7
...
...
@@ -37,15 +37,15 @@ namespace tl {
using
namespace
tir
;
class
LowerHopperIntrin
:
public
StmtExprMutator
{
public:
static
PrimFunc
Substitute
(
PrimFunc
&
f
)
{
PrimFuncNode
*
fptr
=
f
.
CopyOnWrite
();
public:
static
PrimFunc
Substitute
(
PrimFunc
&
f
)
{
PrimFuncNode
*
fptr
=
f
.
CopyOnWrite
();
LowerHopperIntrin
substituter
;
fptr
->
body
=
substituter
.
VisitStmt
(
f
->
body
);
for
(
auto
[
call
,
var
]
:
substituter
.
desc_map_
)
{
// Should allocate 128 bytes for TensorMap on stack
Call
alloc_desc
=
Call
(
DataType
::
Handle
(),
builtin
::
tvm_stack_alloca
(),
{
StringImm
(
"arg_value"
),
16
});
Call
alloc_desc
=
Call
(
DataType
::
Handle
(),
builtin
::
tvm_stack_alloca
(),
{
StringImm
(
"arg_value"
),
16
});
Array
<
PrimExpr
>
init_desc_args
;
if
(
call
->
op
.
same_as
(
CreateTMADescriptorOp
()))
{
init_desc_args
.
push_back
(
StringImm
(
tvm_tensormap_create_tiled
));
...
...
@@ -55,15 +55,19 @@ class LowerHopperIntrin : public StmtExprMutator {
CHECK
(
0
)
<<
call
->
op
;
}
init_desc_args
.
push_back
(
var
);
init_desc_args
.
insert
(
init_desc_args
.
end
(),
call
->
args
.
begin
(),
call
->
args
.
end
());
Call
init_desc
=
Call
(
DataType
::
Handle
(),
builtin
::
tvm_call_packed
(),
init_desc_args
);
fptr
->
body
=
LetStmt
(
var
,
alloc_desc
,
SeqStmt
({
Evaluate
(
init_desc
),
fptr
->
body
}));
init_desc_args
.
insert
(
init_desc_args
.
end
(),
call
->
args
.
begin
(),
call
->
args
.
end
());
Call
init_desc
=
Call
(
DataType
::
Handle
(),
builtin
::
tvm_call_packed
(),
init_desc_args
);
fptr
->
body
=
LetStmt
(
var
,
alloc_desc
,
SeqStmt
({
Evaluate
(
init_desc
),
fptr
->
body
}));
}
return
f
;
}
Stmt
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
// Insert the prefetch TMA descriptor statement TO the beginning of the kernel
Stmt
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
// Insert the prefetch TMA descriptor statement TO the beginning of the
// kernel
if
(
op
->
attr_key
==
tir
::
attr
::
thread_extent
)
{
IterVar
iv
=
Downcast
<
IterVar
>
(
op
->
node
);
if
(
iv
->
thread_tag
==
"threadIdx.x"
)
{
...
...
@@ -73,18 +77,22 @@ class LowerHopperIntrin : public StmtExprMutator {
}
else
{
Array
<
Stmt
>
stmt_seq
;
if
(
!
init_mbarrier_calls_
.
empty
())
{
auto
alloc_mbarrier
=
Evaluate
(
Call
(
DataType
::
Handle
(),
builtin
::
create_barriers
(),
{
static_cast
<
int
>
(
init_mbarrier_calls_
.
size
())}));
auto
alloc_mbarrier
=
Evaluate
(
Call
(
DataType
::
Handle
(),
builtin
::
create_barriers
(),
{
static_cast
<
int
>
(
init_mbarrier_calls_
.
size
())}));
stmt_seq
.
push_back
(
alloc_mbarrier
);
}
auto
stmts
=
prefetch_calls_
;
stmts
.
insert
(
stmts
.
end
(),
init_mbarrier_calls_
.
begin
(),
init_mbarrier_calls_
.
end
());
auto
init_stmt
=
IfThenElse
(
EQ
(
iv
->
var
,
0
),
stmts
.
size
()
>
1
?
SeqStmt
(
stmts
)
:
stmts
[
0
]);
stmts
.
insert
(
stmts
.
end
(),
init_mbarrier_calls_
.
begin
(),
init_mbarrier_calls_
.
end
());
auto
init_stmt
=
IfThenElse
(
EQ
(
iv
->
var
,
0
),
stmts
.
size
()
>
1
?
SeqStmt
(
stmts
)
:
stmts
[
0
]);
stmt_seq
.
push_back
(
init_stmt
);
if
(
!
init_mbarrier_calls_
.
empty
())
{
Stmt
mem_sync
=
Evaluate
(
Call
(
DataType
::
Handle
(),
builtin
::
tvm_storage_sync
(),
{
StringImm
(
"shared"
)}));
Stmt
mem_sync
=
Evaluate
(
Call
(
DataType
::
Handle
(),
builtin
::
tvm_storage_sync
(),
{
StringImm
(
"shared"
)}));
stmt_seq
.
push_back
(
mem_sync
);
}
stmt_seq
.
push_back
(
body
);
...
...
@@ -98,7 +106,7 @@ class LowerHopperIntrin : public StmtExprMutator {
return
StmtExprMutator
::
VisitStmt_
(
op
);
}
PrimExpr
VisitExpr_
(
const
CallNode
*
call
)
final
{
PrimExpr
VisitExpr_
(
const
CallNode
*
call
)
final
{
if
(
call
->
op
.
same_as
(
CreateTMADescriptorOp
())
||
call
->
op
.
same_as
(
CreateTMAIm2ColDescriptorOp
()))
{
Var
var
;
...
...
@@ -107,10 +115,12 @@ class LowerHopperIntrin : public StmtExprMutator {
var
=
iter
->
second
;
}
else
{
String
name
=
call
->
args
[
2
].
as
<
Var
>
().
value
()
->
name_hint
;
var
=
Var
(
name
+
"_desc"
,
PointerType
(
PrimType
(
cuTensorMapType
()),
"grid_constant"
));
var
=
Var
(
name
+
"_desc"
,
PointerType
(
PrimType
(
cuTensorMapType
()),
"grid_constant"
));
desc_map_
[
GetRef
<
Call
>
(
call
)]
=
var
;
prefetch_calls_
.
push_back
(
Evaluate
(
Call
(
DataType
::
Handle
(),
builtin
::
call_extern
(),
{
StringImm
(
"tl::prefetch_tma_descriptor"
),
var
})));
prefetch_calls_
.
push_back
(
Evaluate
(
Call
(
DataType
::
Handle
(),
builtin
::
call_extern
(),
{
StringImm
(
"tl::prefetch_tma_descriptor"
),
var
})));
}
return
var
;
}
else
if
(
call
->
op
.
same_as
(
CreateListofMBarrierOp
()))
{
...
...
@@ -118,24 +128,25 @@ class LowerHopperIntrin : public StmtExprMutator {
int
num_barriers
=
static_cast
<
int
>
(
call
->
args
.
size
());
for
(
int
i
=
0
;
i
<
num_barriers
;
i
++
)
{
PrimExpr
mbarrier
=
Call
(
DataType
::
Handle
(),
GetMBarrierOp
(),
{
i
});
init_mbarrier_calls_
.
push_back
(
Evaluate
(
Call
(
DataType
::
Handle
(),
builtin
::
ptx_init_barrier_thread_count
(),
{
mbarrier
,
call
->
args
[
i
]})));
init_mbarrier_calls_
.
push_back
(
Evaluate
(
Call
(
DataType
::
Handle
(),
builtin
::
ptx_init_barrier_thread_count
(),
{
mbarrier
,
call
->
args
[
i
]})));
}
return
0
;
}
else
if
(
call
->
op
.
same_as
(
SyncThreadsPartialOp
()))
{
int
barrier_id
=
init_mbarrier_calls_
.
size
();
PrimExpr
mbarrier
=
Call
(
DataType
::
Handle
(),
GetMBarrierOp
(),
{
barrier_id
});
init_mbarrier_calls_
.
push_back
(
Evaluate
(
Call
(
DataType
::
Handle
(),
builtin
::
ptx_init_barrier_thread_count
(),
{
mbarrier
,
call
->
args
[
0
]})));
PrimExpr
mbarrier
=
Call
(
DataType
::
Handle
(),
GetMBarrierOp
(),
{
barrier_id
});
init_mbarrier_calls_
.
push_back
(
Evaluate
(
Call
(
DataType
::
Handle
(),
builtin
::
ptx_init_barrier_thread_count
(),
{
mbarrier
,
call
->
args
[
0
]})));
return
Call
(
DataType
::
Handle
(),
SyncThreadsPartialOp
(),
{
mbarrier
});
}
else
{
return
StmtExprMutator
::
VisitExpr_
(
call
);
}
}
private:
private:
Array
<
Stmt
>
prefetch_calls_
;
Array
<
Stmt
>
init_mbarrier_calls_
;
std
::
unordered_map
<
Call
,
Var
,
StructuralHash
,
ExprDeepEqual
>
desc_map_
;
...
...
@@ -154,5 +165,5 @@ tvm::transform::Pass LowerHopperIntrin() {
TVM_REGISTER_GLOBAL
(
"tl.transform.LowerHopperIntrin"
)
.
set_body_typed
(
LowerHopperIntrin
);
}
// namespace tl
}
// namespace tvm
}
// namespace tl
}
// namespace tvm
src/transform/lower_tile_op.cc
View file @
549416f7
...
...
@@ -27,10 +27,10 @@
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include "arith/ir_mutator_with_analyzer.h"
#include "../layout/layout.h"
#include "../layout/utils.h"
#include "../op/op.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "loop_partition.h"
namespace
tvm
{
...
...
@@ -38,8 +38,9 @@ namespace tl {
using
namespace
tir
;
static
Buffer
makeBufferWithLayout
(
const
Buffer
&
buffer
,
const
Layout
&
layout
)
{
const
auto
*
ptr_type
=
TVM_TYPE_AS
(
buffer
->
data
->
type_annotation
,
PointerTypeNode
);
static
Buffer
makeBufferWithLayout
(
const
Buffer
&
buffer
,
const
Layout
&
layout
)
{
const
auto
*
ptr_type
=
TVM_TYPE_AS
(
buffer
->
data
->
type_annotation
,
PointerTypeNode
);
Type
new_type
;
// convert fragments to normal local buffer
if
(
ptr_type
->
storage_scope
==
"local.fragment"
)
{
...
...
@@ -53,32 +54,33 @@ static Buffer makeBufferWithLayout(const Buffer& buffer, const Layout& layout) {
}
else
{
new_var
=
Var
(
buffer
->
data
->
name_hint
,
new_type
);
}
return
Buffer
(
new_var
,
buffer
->
dtype
,
layout
->
OutputShape
(),
{},
buffer
->
elem_offset
,
buffer
->
name
,
buffer
->
data_alignment
,
buffer
->
offset_factor
,
buffer
->
buffer_type
);
return
Buffer
(
new_var
,
buffer
->
dtype
,
layout
->
OutputShape
(),
{},
buffer
->
elem_offset
,
buffer
->
name
,
buffer
->
data_alignment
,
buffer
->
offset_factor
,
buffer
->
buffer_type
);
}
class
LowerTileOpPass
:
arith
::
IRMutatorWithAnalyzer
{
public:
public:
static
PrimFunc
Substitute
(
PrimFunc
f
)
{
arith
::
Analyzer
analyzer
;
LowerTileOpPass
substituter
(
&
analyzer
);
// Trace the buffer map for tvm_access_ptr
substituter
.
buffer_map_
.
insert
(
f
->
buffer_map
.
begin
(),
f
->
buffer_map
.
end
());
for
(
const
auto
&
[
_
,
buffer
]
:
f
->
buffer_map
)
{
for
(
const
auto
&
[
_
,
buffer
]
:
f
->
buffer_map
)
{
substituter
.
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
}
auto
target
=
f
->
GetAttr
<
Target
>
(
tvm
::
attr
::
kTarget
);
ICHECK
(
target
.
defined
())
<<
"LowerTileOpPass: Require the target attribute"
;
substituter
.
target_
=
target
.
value
();
PrimFuncNode
*
fptr
=
f
.
CopyOnWrite
();
PrimFuncNode
*
fptr
=
f
.
CopyOnWrite
();
fptr
->
body
=
substituter
.
VisitStmt
(
f
->
body
);
return
f
;
}
private:
private:
using
arith
::
IRMutatorWithAnalyzer
::
IRMutatorWithAnalyzer
;
Stmt
VisitStmt_
(
const
BlockNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
BlockNode
*
op
)
final
{
// Record the mapping from buffer data var to buffer for later lookup
for
(
auto
buffer
:
op
->
alloc_buffers
)
{
buffer_map_
.
insert
({
buffer
->
data
,
buffer
});
...
...
@@ -91,7 +93,9 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
}
Map
<
Var
,
Layout
>
vmap
;
if
(
op
->
annotations
.
count
(
attr
::
kLayoutMap
))
{
auto
layout_map
=
op
->
annotations
.
at
(
attr
::
kLayoutMap
).
as
<
Map
<
Buffer
,
Layout
>>
().
value
();
auto
layout_map
=
op
->
annotations
.
at
(
attr
::
kLayoutMap
)
.
as
<
Map
<
Buffer
,
Layout
>>
()
.
value
();
for
(
auto
[
buffer
,
layout
]
:
layout_map
)
{
buffer_remap_
.
Set
(
buffer
,
makeBufferWithLayout
(
buffer
,
layout
));
layout_map_
.
Set
(
buffer
,
layout
);
...
...
@@ -105,7 +109,8 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
block_ptr
->
alloc_buffers
.
Set
(
i
,
buffer_remap_
[
buffer
]);
}
}
for
(
const
auto
&
buffer
:
workspaces_
)
block_ptr
->
alloc_buffers
.
push_back
(
buffer
);
for
(
const
auto
&
buffer
:
workspaces_
)
block_ptr
->
alloc_buffers
.
push_back
(
buffer
);
workspaces_
.
clear
();
block_ptr
->
annotations
.
erase
(
attr
::
kLayoutMap
);
return
block
;
...
...
@@ -113,18 +118,19 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
int
CheckAndGetBufferRowSize
(
Buffer
buffer
)
{
CHECK
(
buffer
->
shape
.
size
()
>=
2
)
<<
"The dimension of Buffer
\"
"
<<
buffer
->
name
<<
"
\"
with shape "
<<
buffer
->
shape
<<
" should be at least 2"
;
<<
"The dimension of Buffer
\"
"
<<
buffer
->
name
<<
"
\"
with shape "
<<
buffer
->
shape
<<
" should be at least 2"
;
auto
dim
=
buffer
->
shape
.
size
();
auto
buffer_row_size
=
buffer
->
shape
[
dim
-
1
].
as
<
IntImmNode
>
()
->
value
;
return
buffer_row_size
;
}
PrimExpr
HandleAccessPtrAndOffset
(
PrimExpr
access_ptr
,
Optional
<
PrimExpr
>
offset
=
NullOpt
,
PrimExpr
HandleAccessPtrAndOffset
(
PrimExpr
access_ptr
,
Optional
<
PrimExpr
>
offset
=
NullOpt
,
DataType
dtype
=
DataType
::
Int
(
32
))
{
// The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and
accumulate it to
// smem_offset
// The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and
//
accumulate it to
smem_offset
CHECK
(
access_ptr
->
IsInstance
<
CallNode
>
())
<<
"Invalid access ptr for permuted layout: "
<<
access_ptr
;
auto
access_ptr_call
=
Downcast
<
Call
>
(
access_ptr
);
...
...
@@ -136,8 +142,10 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
Array
<
PrimExpr
>
shape
=
load
->
buffer
->
shape
;
CHECK_EQ
(
indices
.
size
(),
shape
.
size
())
<<
"Indices size and shape size must match for general N-dimensional buffer "
<<
"but got indices size: "
<<
indices
.
size
()
<<
" and shape size: "
<<
shape
.
size
();
<<
"Indices size and shape size must match for general N-dimensional "
"buffer "
<<
"but got indices size: "
<<
indices
.
size
()
<<
" and shape size: "
<<
shape
.
size
();
PrimExpr
elem_offset
=
0
;
PrimExpr
stride
=
1
;
...
...
@@ -147,13 +155,16 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
stride
*=
shape
[
i
];
}
PrimExpr
smem_offset
=
elem_offset
+
(
offset
.
defined
()
?
offset
.
value
()
:
0
);
PrimExpr
smem_offset
=
elem_offset
+
(
offset
.
defined
()
?
offset
.
value
()
:
0
);
auto
new_buffer
=
buffer_remap_
[
load
->
buffer
];
auto
buffer_map_iter
=
buffer_map_
.
find
(
Downcast
<
Var
>
(
load
->
buffer
->
data
));
auto
buffer_map_iter
=
buffer_map_
.
find
(
Downcast
<
Var
>
(
load
->
buffer
->
data
));
CHECK
(
buffer_map_iter
!=
buffer_map_
.
end
())
<<
"The buffer corresponding to data Var "
<<
access_ptr_call
->
args
[
0
]
<<
" is not found"
;
<<
"The buffer corresponding to data Var "
<<
access_ptr_call
->
args
[
0
]
<<
" is not found"
;
int
buffer_row_size
=
CheckAndGetBufferRowSize
(
buffer_map_iter
->
second
);
(
void
)
buffer_row_size
;
...
...
@@ -163,11 +174,13 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
PrimExpr
remaining_offset
=
smem_offset
;
for
(
int
i
=
static_cast
<
int
>
(
shape
.
size
())
-
1
;
i
>=
0
;
--
i
)
{
multi_dim_indices
.
insert
(
multi_dim_indices
.
begin
(),
floormod
(
remaining_offset
,
shape
[
i
]));
multi_dim_indices
.
insert
(
multi_dim_indices
.
begin
(),
floormod
(
remaining_offset
,
shape
[
i
]));
remaining_offset
=
floordiv
(
remaining_offset
,
shape
[
i
]);
}
auto
forward_indices
=
layout_map_
[
load
->
buffer
]
->
Forward
(
multi_dim_indices
);
auto
forward_indices
=
layout_map_
[
load
->
buffer
]
->
Forward
(
multi_dim_indices
);
PrimExpr
new_offset
=
0
;
PrimExpr
stride_offset
=
1
;
for
(
int
i
=
static_cast
<
int
>
(
shape
.
size
())
-
1
;
i
>=
0
;
--
i
)
{
...
...
@@ -191,8 +204,9 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
return
access_ptr_call
;
}
PrimExpr
VisitExpr_
(
const
tir
::
CallNode
*
op
)
final
{
if
(
!
op
->
op
.
same_as
(
builtin
::
ptx_ldmatrix
())
&&
!
op
->
op
.
same_as
(
builtin
::
mma_store
()))
{
PrimExpr
VisitExpr_
(
const
tir
::
CallNode
*
op
)
final
{
if
(
!
op
->
op
.
same_as
(
builtin
::
ptx_ldmatrix
())
&&
!
op
->
op
.
same_as
(
builtin
::
mma_store
()))
{
return
Downcast
<
Call
>
(
IRMutatorWithAnalyzer
::
VisitExpr_
(
op
));
}
else
{
is_ptx_
=
true
;
...
...
@@ -212,15 +226,18 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
BufferLoad
load
=
Downcast
<
BufferLoad
>
(
address_of_call
->
args
[
0
]);
if
(
buffer_remap_
.
count
(
load
->
buffer
))
{
auto
new_access_ptr
=
HandleAccessPtrAndOffset
(
access_ptr
,
smem_offset
,
call
->
dtype
);
auto
new_access_ptr
=
HandleAccessPtrAndOffset
(
access_ptr
,
smem_offset
,
call
->
dtype
);
auto
new_call
=
call
.
CopyOnWrite
();
new_call
->
args
.
Set
(
5
,
new_access_ptr
);
new_call
->
args
.
Set
(
6
,
IntImm
(
smem_offset
->
dtype
,
0
));
}
}
else
if
(
call
->
op
.
same_as
(
builtin
::
mma_store
()))
{
// because we will directly store result to Buffer instead of calling mma_store now
// because we will directly store result to Buffer instead of calling
// mma_store now
auto
access_ptr
=
call
->
args
[
2
];
auto
new_access_ptr
=
HandleAccessPtrAndOffset
(
access_ptr
,
NullOpt
,
call
->
dtype
);
auto
new_access_ptr
=
HandleAccessPtrAndOffset
(
access_ptr
,
NullOpt
,
call
->
dtype
);
auto
new_call
=
call
.
CopyOnWrite
();
new_call
->
args
.
Set
(
2
,
new_access_ptr
);
}
else
{
...
...
@@ -230,7 +247,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
return
call
;
}
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
auto
load
=
Downcast
<
BufferLoad
>
(
IRMutatorWithAnalyzer
::
VisitExpr_
(
op
));
if
(
is_ptx_
)
{
return
load
;
...
...
@@ -243,7 +260,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
return
load
;
}
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
auto
store
=
Downcast
<
BufferStore
>
(
IRMutatorWithAnalyzer
::
VisitStmt_
(
op
));
if
(
buffer_remap_
.
count
(
store
->
buffer
))
{
auto
new_indices
=
layout_map_
[
store
->
buffer
]
->
Forward
(
store
->
indices
);
...
...
@@ -253,36 +270,40 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
return
store
;
}
PrimExpr
VisitExpr_
(
const
VarNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
VarNode
*
op
)
final
{
auto
var
=
Downcast
<
Var
>
(
IRMutatorWithAnalyzer
::
VisitExpr_
(
op
));
if
(
buffer_data_to_buffer_
.
count
(
var
))
{
auto
buffer
=
buffer_data_to_buffer_
[
var
];
if
(
buffer_remap_
.
count
(
buffer
))
return
buffer_remap_
[
buffer
]
->
data
;
if
(
buffer_remap_
.
count
(
buffer
))
return
buffer_remap_
[
buffer
]
->
data
;
}
return
var
;
}
Stmt
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
const
CallNode
*
call
=
op
->
value
.
as
<
CallNode
>
();
Stmt
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
const
CallNode
*
call
=
op
->
value
.
as
<
CallNode
>
();
// Do not analysis the call node to the global function.
if
(
call
&&
call
->
op
.
as
<
GlobalVarNode
>
())
return
Downcast
<
Evaluate
>
(
IRMutatorWithAnalyzer
::
VisitStmt_
(
op
));
auto
tile_op
=
ParseOperator
(
GetRef
<
Stmt
>
(
op
),
buffer_data_to_buffer_
);
if
(
tile_op
==
nullptr
)
return
IRMutatorWithAnalyzer
::
VisitStmt_
(
op
);
if
(
tile_op
==
nullptr
)
return
IRMutatorWithAnalyzer
::
VisitStmt_
(
op
);
AddWorkspaceCallback
callback
=
[
this
](
int
num_elem
,
DataType
dtype
)
{
auto
workspace
=
decl_buffer
({
PrimExpr
(
num_elem
)},
dtype
,
"workspace"
,
"shared.dyn"
);
auto
workspace
=
decl_buffer
({
PrimExpr
(
num_elem
)},
dtype
,
"workspace"
,
"shared.dyn"
);
workspaces_
.
push_back
(
workspace
);
return
workspace
.
access_ptr
(
2
);
// write
return
workspace
.
access_ptr
(
2
);
// write
};
auto
lowered
=
tile_op
->
Lower
(
LowerArgs
{
target_
,
thread_block_size_
,
thread_var_
,
callback
,
layout_map_
,
buffer_remap_
},
analyzer_
);
auto
lowered
=
tile_op
->
Lower
(
LowerArgs
{
target_
,
thread_block_size_
,
thread_var_
,
callback
,
layout_map_
,
buffer_remap_
},
analyzer_
);
return
IRMutatorWithAnalyzer
::
VisitStmt
(
lowered
);
}
Stmt
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
if
(
op
->
attr_key
==
tir
::
attr
::
thread_extent
)
{
IterVar
iv
=
Downcast
<
IterVar
>
(
op
->
node
);
ICHECK_NE
(
iv
->
thread_tag
.
length
(),
0U
);
...
...
@@ -321,7 +342,7 @@ tvm::transform::Pass LowerTileOp() {
}
TVM_REGISTER_GLOBAL
(
"tl.transform.LowerTileOp"
).
set_body_typed
(
LowerTileOp
);
}
// namespace transform
}
// namespace transform
}
// namespace tl
}
// namespace tvm
}
// namespace tl
}
// namespace tvm
src/transform/multi_version_buffer_rewriter.cc
View file @
549416f7
...
...
@@ -38,22 +38,23 @@ using namespace tir;
enum
class
Role
{
kConsumer
,
kProducer
,
kBoth
};
class
WarpSpecializedRoleMarker_
:
public
StmtVisitor
{
public:
public:
WarpSpecializedRoleMarker_
(
Map
<
Var
,
Buffer
>
buffer_data_to_buffer
)
:
buffer_data_to_buffer_
(
buffer_data_to_buffer
)
{}
Role
GetRole
(
const
StmtNode
*
stmt
)
const
{
Role
GetRole
(
const
StmtNode
*
stmt
)
const
{
auto
it
=
map_
.
find
(
stmt
);
ICHECK
(
it
!=
map_
.
end
());
return
it
->
second
;
}
Role
GetRole
(
const
Stmt
&
stmt
)
const
{
return
GetRole
(
stmt
.
get
());
}
Role
GetRole
(
const
Stmt
&
stmt
)
const
{
return
GetRole
(
stmt
.
get
());
}
void
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
void
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
Role
role
=
Role
::
kConsumer
;
if
(
auto
call
=
op
->
value
.
as
<
CallNode
>
())
{
if
(
call
->
op
.
same_as
(
TMALoadOp
())
||
call
->
op
.
same_as
(
TMALoadIm2ColOp
()))
{
if
(
call
->
op
.
same_as
(
TMALoadOp
())
||
call
->
op
.
same_as
(
TMALoadIm2ColOp
()))
{
role
=
Role
::
kProducer
;
has_bulk_copy_
=
true
;
}
...
...
@@ -61,8 +62,9 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor {
SetRole
(
op
,
role
);
}
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
bool
is_shared_store
=
op
->
buffer
.
scope
()
==
"shared.dyn"
||
op
->
buffer
.
scope
()
==
"shared"
;
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
bool
is_shared_store
=
op
->
buffer
.
scope
()
==
"shared.dyn"
||
op
->
buffer
.
scope
()
==
"shared"
;
if
(
!
is_shared_store
)
{
SetRole
(
op
,
Role
::
kConsumer
);
return
;
...
...
@@ -80,11 +82,12 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor {
break
;
}
}
if
(
role
==
Role
::
kProducer
)
has_simt_copy_
=
true
;
if
(
role
==
Role
::
kProducer
)
has_simt_copy_
=
true
;
SetRole
(
op
,
role
);
}
void
VisitStmt_
(
const
SeqStmtNode
*
op
)
final
{
void
VisitStmt_
(
const
SeqStmtNode
*
op
)
final
{
StmtVisitor
::
VisitStmt_
(
op
);
auto
role
=
GetRole
(
op
->
seq
[
0
]);
for
(
auto
stmt
:
op
->
seq
)
{
...
...
@@ -96,48 +99,48 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor {
SetRole
(
op
,
role
);
}
void
VisitStmt_
(
const
IfThenElseNode
*
op
)
final
{
void
VisitStmt_
(
const
IfThenElseNode
*
op
)
final
{
StmtVisitor
::
VisitStmt_
(
op
);
auto
role
=
GetRole
(
op
->
then_case
);
if
(
op
->
else_case
.
defined
())
{
auto
role_else
=
GetRole
(
op
->
else_case
.
value
());
if
(
role
!=
role_else
)
role
=
Role
::
kBoth
;
if
(
role
!=
role_else
)
role
=
Role
::
kBoth
;
}
SetRole
(
op
,
role
);
}
void
VisitStmt_
(
const
BlockRealizeNode
*
op
)
final
{
void
VisitStmt_
(
const
BlockRealizeNode
*
op
)
final
{
StmtVisitor
::
VisitStmt_
(
op
);
SetRole
(
op
,
GetRole
(
op
->
block
));
}
template
<
class
NodeType
>
void
HandleBodyStmt
(
const
NodeType
*
op
)
{
template
<
class
NodeType
>
void
HandleBodyStmt
(
const
NodeType
*
op
)
{
StmtVisitor
::
VisitStmt_
(
op
);
SetRole
(
op
,
GetRole
(
op
->
body
));
}
void
VisitStmt_
(
const
ForNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
LetStmtNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
AssertStmtNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
BlockNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
ForNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
LetStmtNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
AssertStmtNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
BlockNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
bool
HasProducer
()
{
return
has_simt_copy_
||
has_bulk_copy_
;
}
bool
HasSimtCopy
()
{
return
has_simt_copy_
;
}
private:
void
SetRole
(
const
StmtNode
*
stmt
,
Role
role
)
{
map_
[
stmt
]
=
role
;
}
private:
void
SetRole
(
const
StmtNode
*
stmt
,
Role
role
)
{
map_
[
stmt
]
=
role
;
}
Map
<
Var
,
Buffer
>
buffer_data_to_buffer_
;
std
::
unordered_map
<
const
StmtNode
*
,
Role
>
map_
;
std
::
unordered_map
<
const
StmtNode
*
,
Role
>
map_
;
bool
has_simt_copy_
=
false
;
bool
has_bulk_copy_
=
false
;
};
class
MultiVersionBufferRewriter
:
public
StmtExprMutator
{
public:
static
PrimFunc
Substitute
(
PrimFunc
&
f
)
{
public:
static
PrimFunc
Substitute
(
PrimFunc
&
f
)
{
auto
rewriter
=
MultiVersionBufferRewriter
();
rewriter
.
buffer_lca_
=
DetectBufferAccessLCA
(
f
);
for
(
auto
[
buffer
,
_
]
:
rewriter
.
buffer_lca_
)
{
...
...
@@ -148,40 +151,45 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
return
f
;
}
private:
private:
MultiVersionBufferRewriter
()
=
default
;
Array
<
Buffer
>
GetVersionedBuffers
(
Array
<
Stmt
>
seq_stmt
,
Array
<
Buffer
>
scoped_buffers
)
{
Array
<
Buffer
>
GetVersionedBuffers
(
Array
<
Stmt
>
seq_stmt
,
Array
<
Buffer
>
scoped_buffers
)
{
std
::
vector
<
Role
>
roles
;
Array
<
Array
<
BufferRegion
>>
reads
,
writes
;
auto
marker
=
WarpSpecializedRoleMarker_
(
buffer_data_to_buffer_
);
for
(
auto
stmt
:
seq_stmt
)
{
marker
(
stmt
);
Block
block
(
/*iter_vars=*/
{},
/*reads=*/
{},
/*writes=*/
{},
/*name_hint=*/
""
,
/*body*/
stmt
);
Block
block
(
/*iter_vars=*/
{},
/*reads=*/
{},
/*writes=*/
{},
/*name_hint=*/
""
,
/*body*/
stmt
);
auto
access
=
GetBlockAccessRegion
(
block
,
buffer_data_to_buffer_
);
reads
.
push_back
(
std
::
move
(
access
[
0
]));
writes
.
push_back
(
std
::
move
(
access
[
1
]));
roles
.
push_back
(
marker
.
GetRole
(
stmt
));
}
std
::
unordered_set
<
const
BufferNode
*>
consumer_used
,
producer_used
;
std
::
unordered_set
<
const
BufferNode
*>
consumer_used
,
producer_used
;
for
(
size_t
i
=
0
;
i
<
seq_stmt
.
size
();
i
++
)
{
if
(
roles
[
i
]
==
Role
::
kProducer
)
{
for
(
BufferRegion
br
:
writes
[
i
])
producer_used
.
insert
(
br
->
buffer
.
get
());
for
(
BufferRegion
br
:
writes
[
i
])
producer_used
.
insert
(
br
->
buffer
.
get
());
}
else
{
for
(
BufferRegion
br
:
reads
[
i
])
consumer_used
.
insert
(
br
->
buffer
.
get
());
for
(
BufferRegion
br
:
reads
[
i
])
consumer_used
.
insert
(
br
->
buffer
.
get
());
}
}
Array
<
Buffer
>
versioned_buffers
;
for
(
Buffer
buffer
:
scoped_buffers
)
{
if
(
consumer_used
.
count
(
buffer
.
get
())
&&
producer_used
.
count
(
buffer
.
get
()))
{
if
(
consumer_used
.
count
(
buffer
.
get
())
&&
producer_used
.
count
(
buffer
.
get
()))
{
versioned_buffers
.
push_back
(
buffer
);
}
}
return
versioned_buffers
;
}
static
Buffer
RewriteAllocBuffer
(
const
Buffer
&
buffer
,
int
num_versions
)
{
static
Buffer
RewriteAllocBuffer
(
const
Buffer
&
buffer
,
int
num_versions
)
{
ObjectPtr
<
BufferNode
>
new_buffer
=
make_object
<
BufferNode
>
(
*
(
buffer
.
get
()));
new_buffer
->
shape
.
insert
(
new_buffer
->
shape
.
begin
(),
PrimExpr
(
num_versions
));
if
(
new_buffer
->
strides
.
size
())
{
...
...
@@ -192,8 +200,9 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
return
Buffer
(
new_buffer
);
}
Stmt
VisitStmt_
(
const
BlockRealizeNode
*
op
)
final
{
BlockRealize
block_realize
=
Downcast
<
BlockRealize
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
Stmt
VisitStmt_
(
const
BlockRealizeNode
*
op
)
final
{
BlockRealize
block_realize
=
Downcast
<
BlockRealize
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
Block
block
=
block_realize
->
block
;
Array
<
Buffer
>
alloc_buffers
;
for
(
auto
buffer
:
block
->
alloc_buffers
)
{
...
...
@@ -209,24 +218,27 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
return
block_realize
;
}
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
auto
num_stages_anno
=
op
->
annotations
.
Get
(
"num_stages"
);
if
(
!
num_stages_anno
.
defined
())
return
StmtExprMutator
::
VisitStmt_
(
op
);
if
(
!
num_stages_anno
.
defined
())
return
StmtExprMutator
::
VisitStmt_
(
op
);
ICHECK
(
num_stages_anno
.
as
<
IntImmNode
>
());
int
num_stages
=
static_cast
<
int
>
(
num_stages_anno
.
as
<
IntImmNode
>
()
->
value
);
const
SeqStmtNode
*
pipeline_body_seq
=
op
->
body
.
as
<
SeqStmtNode
>
();
CHECK
(
pipeline_body_seq
)
<<
"ValueError: The body of the software pipeline
should be SeqStmt, got "
<<
op
->
body
->
GetTypeKey
();
const
SeqStmtNode
*
pipeline_body_seq
=
op
->
body
.
as
<
SeqStmtNode
>
();
CHECK
(
pipeline_body_seq
)
<<
"ValueError: The body of the software pipeline "
"
should be SeqStmt, got "
<<
op
->
body
->
GetTypeKey
();
Array
<
Buffer
>
scoped_buffers
=
{};
for
(
auto
[
buffer
,
stmt
]
:
buffer_lca_
)
{
if
(
stmt
.
defined
()
&&
stmt
.
value
().
get
()
==
op
)
scoped_buffers
.
push_back
(
buffer
);
if
(
stmt
.
defined
()
&&
stmt
.
value
().
get
()
==
op
)
scoped_buffers
.
push_back
(
buffer
);
}
Array
<
Buffer
>
versioned_buffers
=
GetVersionedBuffers
(
pipeline_body_seq
->
seq
,
scoped_buffers
);
Array
<
Buffer
>
versioned_buffers
=
GetVersionedBuffers
(
pipeline_body_seq
->
seq
,
scoped_buffers
);
for
(
auto
buffer
:
versioned_buffers
)
{
Var
buffer_var
=
buffer
->
data
;
...
...
@@ -239,33 +251,33 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
return
for_node
;
}
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
BufferLoad
load
=
Downcast
<
BufferLoad
>
(
StmtExprMutator
::
VisitExpr_
(
op
));
auto
it
=
buffer_remap_
.
find
(
load
->
buffer
);
if
(
it
==
buffer_remap_
.
end
())
{
return
std
::
move
(
load
);
}
const
Buffer
&
new_buffer
=
(
*
it
).
second
;
auto
*
n
=
load
.
CopyOnWrite
();
const
Buffer
&
new_buffer
=
(
*
it
).
second
;
auto
*
n
=
load
.
CopyOnWrite
();
n
->
buffer
=
new_buffer
;
n
->
indices
.
insert
(
n
->
indices
.
begin
(),
version_index_
);
return
std
::
move
(
load
);
}
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
BufferStore
store
=
Downcast
<
BufferStore
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
auto
it
=
buffer_remap_
.
find
(
store
->
buffer
);
if
(
it
==
buffer_remap_
.
end
())
{
return
std
::
move
(
store
);
}
const
Buffer
&
new_buffer
=
(
*
it
).
second
;
auto
*
n
=
store
.
CopyOnWrite
();
const
Buffer
&
new_buffer
=
(
*
it
).
second
;
auto
*
n
=
store
.
CopyOnWrite
();
n
->
buffer
=
new_buffer
;
n
->
indices
.
insert
(
n
->
indices
.
begin
(),
version_index_
);
return
std
::
move
(
store
);
}
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
final
{
Call
call
=
Downcast
<
Call
>
(
StmtExprMutator
::
VisitExpr_
(
op
));
if
(
call
->
op
.
same_as
(
builtin
::
tvm_access_ptr
()))
{
return
RewriteBufferAccess
(
call
,
{
1
});
...
...
@@ -273,20 +285,23 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
return
call
;
}
PrimExpr
RewriteBufferAccess
(
const
Call
&
call
,
const
std
::
vector
<
int
>
arg_indices
)
{
auto
product
=
[](
const
Array
<
PrimExpr
>&
input
)
{
return
foldl
([](
PrimExpr
a
,
PrimExpr
b
,
Span
span
)
{
return
mul
(
a
,
b
,
span
);
},
make_const
(
DataType
::
Int
(
32
),
1
),
input
);
PrimExpr
RewriteBufferAccess
(
const
Call
&
call
,
const
std
::
vector
<
int
>
arg_indices
)
{
auto
product
=
[](
const
Array
<
PrimExpr
>
&
input
)
{
return
foldl
(
[](
PrimExpr
a
,
PrimExpr
b
,
Span
span
)
{
return
mul
(
a
,
b
,
span
);
},
make_const
(
DataType
::
Int
(
32
),
1
),
input
);
};
Array
<
PrimExpr
>
new_args
=
call
->
args
;
for
(
int
i
:
arg_indices
)
{
auto
buffer_var
=
Downcast
<
Var
>
(
call
->
args
[
i
]);
if
(
!
buffer_data_to_buffer_
.
count
(
buffer_var
))
continue
;
const
Buffer
&
buffer
=
buffer_data_to_buffer_
[
buffer_var
];
if
(
!
buffer_data_to_buffer_
.
count
(
buffer_var
))
continue
;
const
Buffer
&
buffer
=
buffer_data_to_buffer_
[
buffer_var
];
auto
it
=
buffer_remap_
.
find
(
buffer
);
if
(
it
!=
buffer_remap_
.
end
())
{
const
Buffer
&
new_buffer
=
(
*
it
).
second
;
const
PrimExpr
&
old_index
=
call
->
args
[
i
+
1
];
const
Buffer
&
new_buffer
=
(
*
it
).
second
;
const
PrimExpr
&
old_index
=
call
->
args
[
i
+
1
];
PrimExpr
offset
;
if
(
new_buffer
->
strides
.
empty
())
{
offset
=
product
(
buffer
->
shape
);
...
...
@@ -318,5 +333,5 @@ tvm::transform::Pass MultiVersionBuffer() {
TVM_REGISTER_GLOBAL
(
"tl.transform.MultiVersionBuffer"
)
.
set_body_typed
(
MultiVersionBuffer
);
}
// namespace tl
}
// namespace tvm
}
// namespace tl
}
// namespace tvm
src/transform/pipeline_planning.cc
View file @
549416f7
...
...
@@ -56,22 +56,23 @@ bool MayConflict(Region region1, Region region2) {
return
true
;
}
}
// namespace
}
// namespace
class
PipelinePlanner
:
public
StmtExprMutator
{
public:
static
Stmt
Substitute
(
const
PrimFunc
&
f
)
{
public:
static
Stmt
Substitute
(
const
PrimFunc
&
f
)
{
PipelinePlanner
substituter
;
for
(
const
auto
&
[
_
,
buffer
]
:
f
->
buffer_map
)
{
for
(
const
auto
&
[
_
,
buffer
]
:
f
->
buffer_map
)
{
substituter
.
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
}
auto
target
=
f
->
GetAttr
<
Target
>
(
tvm
::
attr
::
kTarget
);
ICHECK
(
target
.
defined
())
<<
"Pipeline_Planning: Require the target attribute"
;
ICHECK
(
target
.
defined
())
<<
"Pipeline_Planning: Require the target attribute"
;
substituter
.
target_
=
target
.
value
();
return
substituter
.
VisitStmt
(
f
->
body
);
}
private:
private:
PipelinePlanner
()
=
default
;
struct
PipelineStageInfo
{
...
...
@@ -83,8 +84,10 @@ class PipelinePlanner : public StmtExprMutator {
};
PipelineStageInfo
MakePipelineStageInfo
(
Stmt
stmt
,
int
idx
)
{
Block
block
(
/*iter_vars=*/
{},
/*reads=*/
{},
/*writes=*/
{},
/*name_hint=*/
""
,
/*body*/
stmt
);
Array
<
Array
<
BufferRegion
>>
access
=
GetBlockReadWriteRegion
(
block
,
buffer_data_to_buffer_
);
Block
block
(
/*iter_vars=*/
{},
/*reads=*/
{},
/*writes=*/
{},
/*name_hint=*/
""
,
/*body*/
stmt
);
Array
<
Array
<
BufferRegion
>>
access
=
GetBlockReadWriteRegion
(
block
,
buffer_data_to_buffer_
);
PipelineStageInfo
pinfo
;
pinfo
.
reads
=
std
::
move
(
access
[
0
]);
pinfo
.
writes
=
std
::
move
(
access
[
1
]);
...
...
@@ -93,22 +96,25 @@ class PipelinePlanner : public StmtExprMutator {
// copy stage should only have one reads and one writes
if
(
pinfo
.
reads
.
size
()
==
1
&&
pinfo
.
writes
.
size
()
==
1
)
{
for
(
auto
region
:
pinfo
.
reads
)
if
(
region
->
buffer
.
scope
()
==
"global"
)
pinfo
.
copy_stage
=
true
;
if
(
region
->
buffer
.
scope
()
==
"global"
)
pinfo
.
copy_stage
=
true
;
for
(
auto
region
:
pinfo
.
writes
)
if
(
region
->
buffer
.
scope
()
==
"global"
)
pinfo
.
copy_stage
=
true
;
if
(
region
->
buffer
.
scope
()
==
"global"
)
pinfo
.
copy_stage
=
true
;
}
return
std
::
move
(
pinfo
);
}
Stmt
VisitStmt_
(
const
ForNode
*
loop
)
final
{
Stmt
VisitStmt_
(
const
ForNode
*
loop
)
final
{
auto
num_stages_anno
=
loop
->
annotations
.
Get
(
"num_stages"
);
if
(
!
num_stages_anno
.
defined
())
return
StmtExprMutator
::
VisitStmt_
(
loop
);
if
(
!
num_stages_anno
.
defined
())
return
StmtExprMutator
::
VisitStmt_
(
loop
);
int
num_stages
=
num_stages_anno
.
as
<
IntImmNode
>
()
->
value
;
Stmt
pipeline_body
{
nullptr
};
if
(
const
auto
*
realize
=
loop
->
body
.
as
<
BlockRealizeNode
>
())
{
const
auto
&
block
=
realize
->
block
;
for
(
const
auto
&
buffer
:
block
->
alloc_buffers
)
{
if
(
const
auto
*
realize
=
loop
->
body
.
as
<
BlockRealizeNode
>
())
{
const
auto
&
block
=
realize
->
block
;
for
(
const
auto
&
buffer
:
block
->
alloc_buffers
)
{
ICHECK
(
buffer
->
IsInstance
<
BufferNode
>
());
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
}
...
...
@@ -116,10 +122,10 @@ class PipelinePlanner : public StmtExprMutator {
}
else
{
pipeline_body
=
loop
->
body
;
}
const
SeqStmtNode
*
pipeline_body_seq
=
pipeline_body
.
as
<
SeqStmtNode
>
();
CHECK
(
pipeline_body_seq
)
<<
"ValueError: The body of the software pipeline
should be SeqStmt, got "
<<
loop
->
body
->
GetTypeKey
();
const
SeqStmtNode
*
pipeline_body_seq
=
pipeline_body
.
as
<
SeqStmtNode
>
();
CHECK
(
pipeline_body_seq
)
<<
"ValueError: The body of the software pipeline "
"
should be SeqStmt, got "
<<
loop
->
body
->
GetTypeKey
();
CHECK
(
num_stages
>=
1
);
CHECK
(
loop
->
kind
==
ForKind
::
kSerial
);
...
...
@@ -130,21 +136,28 @@ class PipelinePlanner : public StmtExprMutator {
}
// analysis use-def chain
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
for
(
int
i
=
pinfo
.
original_order
+
1
;
i
<
static_cast
<
int
>
(
pipeline_body_seq
->
size
());
i
++
)
{
if
(
!
pinfo
.
copy_stage
)
continue
;
for
(
const
BufferRegion
&
read
:
pipeline_stage_infos
[
i
].
reads
)
{
if
(
std
::
find_if
(
pinfo
.
writes
.
begin
(),
pinfo
.
writes
.
end
(),
[
&
](
const
BufferRegion
&
r
)
{
return
r
->
buffer
==
read
->
buffer
&&
MayConflict
(
r
->
region
,
read
->
region
);
})
!=
pinfo
.
writes
.
end
())
{
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
for
(
int
i
=
pinfo
.
original_order
+
1
;
i
<
static_cast
<
int
>
(
pipeline_body_seq
->
size
());
i
++
)
{
if
(
!
pinfo
.
copy_stage
)
continue
;
for
(
const
BufferRegion
&
read
:
pipeline_stage_infos
[
i
].
reads
)
{
if
(
std
::
find_if
(
pinfo
.
writes
.
begin
(),
pinfo
.
writes
.
end
(),
[
&
](
const
BufferRegion
&
r
)
{
return
r
->
buffer
==
read
->
buffer
&&
MayConflict
(
r
->
region
,
read
->
region
);
})
!=
pinfo
.
writes
.
end
())
{
pinfo
.
last_use_stage
=
std
::
max
(
pinfo
.
last_use_stage
,
i
);
}
}
for
(
const
BufferRegion
&
write
:
pipeline_stage_infos
[
i
].
writes
)
{
if
(
std
::
find_if
(
pinfo
.
writes
.
begin
(),
pinfo
.
writes
.
end
(),
[
&
](
const
BufferRegion
&
r
)
{
return
r
->
buffer
==
write
->
buffer
&&
MayConflict
(
r
->
region
,
write
->
region
);
})
!=
pinfo
.
writes
.
end
())
{
CHECK
(
false
)
<<
"Can't handle multiple write on overlap buffer region in the pipeline "
for
(
const
BufferRegion
&
write
:
pipeline_stage_infos
[
i
].
writes
)
{
if
(
std
::
find_if
(
pinfo
.
writes
.
begin
(),
pinfo
.
writes
.
end
(),
[
&
](
const
BufferRegion
&
r
)
{
return
r
->
buffer
==
write
->
buffer
&&
MayConflict
(
r
->
region
,
write
->
region
);
})
!=
pinfo
.
writes
.
end
())
{
CHECK
(
false
)
<<
"Can't handle multiple write on overlap buffer "
"region in the pipeline "
"planning pass: "
<<
pipeline_body_seq
->
seq
[
pinfo
.
original_order
];
}
...
...
@@ -154,28 +167,32 @@ class PipelinePlanner : public StmtExprMutator {
// Making stages and orders
int
order_idx
=
0
;
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
if
(
pinfo
.
copy_stage
&&
pinfo
.
last_use_stage
!=
-
1
)
continue
;
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
if
(
pinfo
.
copy_stage
&&
pinfo
.
last_use_stage
!=
-
1
)
continue
;
pinfo
.
order
=
order_idx
++
;
pinfo
.
stage
=
num_stages
;
for
(
auto
&
pinfo_1
:
pipeline_stage_infos
)
{
if
(
pinfo_1
.
copy_stage
&&
pinfo_1
.
last_use_stage
==
pinfo
.
original_order
)
{
for
(
auto
&
pinfo_1
:
pipeline_stage_infos
)
{
if
(
pinfo_1
.
copy_stage
&&
pinfo_1
.
last_use_stage
==
pinfo
.
original_order
)
{
pinfo_1
.
order
=
order_idx
++
;
pinfo_1
.
stage
=
0
;
}
}
}
ICHECK
(
size_t
(
order_idx
)
==
pipeline_stage_infos
.
size
())
<<
"The number of stages should be equal to the number of pipeline stages. "
<<
"Got "
<<
order_idx
<<
" stages and "
<<
pipeline_stage_infos
.
size
()
<<
" pipeline stages."
;
ICHECK
(
size_t
(
order_idx
)
==
pipeline_stage_infos
.
size
())
<<
"The number of stages should be equal to the number of pipeline "
"stages. "
<<
"Got "
<<
order_idx
<<
" stages and "
<<
pipeline_stage_infos
.
size
()
<<
" pipeline stages."
;
// if all the copy is at the end of the order, we can move these copy to the
beginning of the
// order and shrink the stage offset by 1.
// if all the copy is at the end of the order, we can move these copy to the
//
beginning of the
order and shrink the stage offset by 1.
int
copy_stage_at_end
=
[
&
]()
{
int
copy_stage_cnt
=
0
;
int
copy_order_min
=
pipeline_stage_infos
.
size
();
int
non_copy_order_max
=
0
;
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
if
(
pinfo
.
copy_stage
)
{
copy_stage_cnt
++
;
copy_order_min
=
std
::
min
(
copy_order_min
,
pinfo
.
order
);
...
...
@@ -183,19 +200,22 @@ class PipelinePlanner : public StmtExprMutator {
non_copy_order_max
=
std
::
max
(
non_copy_order_max
,
pinfo
.
order
);
}
}
if
(
copy_order_min
>
non_copy_order_max
)
return
copy_stage_cnt
;
if
(
copy_order_min
>
non_copy_order_max
)
return
copy_stage_cnt
;
return
-
1
;
}();
if
(
copy_stage_at_end
>
0
&&
num_stages
>=
2
)
{
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
// move copy to the beginning
pinfo
.
order
=
(
pinfo
.
order
+
copy_stage_at_end
)
%
pipeline_stage_infos
.
size
();
if
(
!
pinfo
.
copy_stage
)
pinfo
.
stage
--
;
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
// move copy to the beginning
pinfo
.
order
=
(
pinfo
.
order
+
copy_stage_at_end
)
%
pipeline_stage_infos
.
size
();
if
(
!
pinfo
.
copy_stage
)
pinfo
.
stage
--
;
}
}
// Finally, make the pipeline annotation
Map
<
String
,
ObjectRef
>
annotations
;
for
(
const
auto
&
[
key
,
value
]
:
loop
->
annotations
)
{
for
(
const
auto
&
[
key
,
value
]
:
loop
->
annotations
)
{
if
(
key
!=
"num_stages"
)
{
annotations
.
Set
(
key
,
value
);
}
...
...
@@ -204,7 +224,7 @@ class PipelinePlanner : public StmtExprMutator {
std
::
vector
<
Integer
>
orders
,
stages
;
orders
.
reserve
(
pipeline_stage_infos
.
size
());
stages
.
reserve
(
pipeline_stage_infos
.
size
());
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
orders
.
push_back
(
pinfo
.
order
);
stages
.
push_back
(
pinfo
.
stage
);
}
...
...
@@ -212,18 +232,19 @@ class PipelinePlanner : public StmtExprMutator {
annotations
.
Set
(
tir
::
attr
::
software_pipeline_stage
,
Array
<
Integer
>
(
stages
));
annotations
.
Set
(
tir
::
attr
::
software_pipeline_order
,
Array
<
Integer
>
(
orders
));
if
(
TargetHasAsyncCopy
(
target_
))
annotations
.
Set
(
tir
::
attr
::
software_pipeline_async_stages
,
Array
<
Integer
>
{
0
});
annotations
.
Set
(
tir
::
attr
::
software_pipeline_async_stages
,
Array
<
Integer
>
{
0
});
return
For
(
loop
->
loop_var
,
loop
->
min
,
loop
->
extent
,
loop
->
kind
,
loop
->
body
,
loop
->
thread_binding
,
annotations
);
}
Stmt
VisitStmt_
(
const
BlockNode
*
op
)
final
{
for
(
const
auto
&
buffer
:
op
->
alloc_buffers
)
{
Stmt
VisitStmt_
(
const
BlockNode
*
op
)
final
{
for
(
const
auto
&
buffer
:
op
->
alloc_buffers
)
{
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
}
Block
block
=
Downcast
<
Block
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
for
(
const
auto
&
buffer
:
op
->
alloc_buffers
)
{
for
(
const
auto
&
buffer
:
op
->
alloc_buffers
)
{
buffer_data_to_buffer_
.
erase
(
buffer
->
data
);
}
return
std
::
move
(
block
);
...
...
@@ -236,14 +257,15 @@ class PipelinePlanner : public StmtExprMutator {
tvm
::
transform
::
Pass
PipelinePlanning
()
{
using
namespace
tir
::
transform
;
auto
pass_func
=
[
=
](
PrimFunc
f
,
IRModule
m
,
PassContext
ctx
)
{
PrimFuncNode
*
fptr
=
f
.
CopyOnWrite
();
PrimFuncNode
*
fptr
=
f
.
CopyOnWrite
();
fptr
->
body
=
PipelinePlanner
::
Substitute
(
f
);
return
f
;
};
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.PipelinePlanning"
,
{});
}
TVM_REGISTER_GLOBAL
(
"tl.transform.PipelinePlanning"
).
set_body_typed
(
PipelinePlanning
);
TVM_REGISTER_GLOBAL
(
"tl.transform.PipelinePlanning"
)
.
set_body_typed
(
PipelinePlanning
);
}
// namespace tl
}
// namespace tvm
}
// namespace tl
}
// namespace tvm
src/transform/simplify.cc
View file @
549416f7
...
...
@@ -6,15 +6,15 @@
* \brief Remove useless parameters of TL PrimFunc.
*/
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/utils.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include "arith/ir_mutator_with_analyzer.h"
#include "tir/analysis/var_use_def_analysis.h"
#include "tir/analysis/control_flow_graph.h"
#include "tir/analysis/var_use_def_analysis.h"
namespace
tvm
{
namespace
tl
{
...
...
@@ -31,19 +31,19 @@ struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
TVM_DECLARE_ATTRS
(
SimplifyConfigNode
,
"tl.transform.SimplifyConfig"
)
{
TVM_ATTR_FIELD
(
transitively_prove_inequalities
)
.
describe
(
"If true, simplify conditionals with transitive combinations
of scoped constraints"
)
.
describe
(
"If true, simplify conditionals with transitive combinations "
"
of scoped constraints"
)
.
set_default
(
false
);
TVM_ATTR_FIELD
(
propagate_knowns_to_prove_conditional
)
.
describe
(
"If true, known buffer values are propagated and used to
statically prove conditionals"
)
.
describe
(
"If true, known buffer values are propagated and used to "
"
statically prove conditionals"
)
.
set_default
(
false
);
TVM_ATTR_FIELD
(
propagate_knowns_to_simplify_expressions
)
.
describe
(
"If true, known buffer values are propagated and used to
replace BufferLoad wherever "
"possible"
)
.
describe
(
"If true, known buffer values are propagated and used to "
"
replace BufferLoad wherever "
"possible"
)
.
set_default
(
false
);
TVM_ATTR_FIELD
(
convert_boolean_to_and_of_ors
)
...
...
@@ -51,102 +51,103 @@ struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
.
set_default
(
false
);
TVM_ATTR_FIELD
(
apply_constraints_to_boolean_branches
)
.
describe
(
"If true, simplify each branch of AND/OR "
"under a constraints provided by the other branch"
)
.
describe
(
"If true, simplify each branch of AND/OR "
"under a constraints provided by the other branch"
)
.
set_default
(
false
);
}
RewriteSimplifier
::
Extension
GetEnabledExtensions
()
const
{
RewriteSimplifier
::
Extension
GetEnabledExtensions
()
const
{
RewriteSimplifier
::
Extension
flags
=
RewriteSimplifier
::
kNone
;
if
(
transitively_prove_inequalities
)
{
flags
=
RewriteSimplifier
::
Extension
(
flags
|
RewriteSimplifier
::
kTransitivelyProveInequalities
);
flags
=
RewriteSimplifier
::
Extension
(
flags
|
RewriteSimplifier
::
kTransitivelyProveInequalities
);
}
if
(
convert_boolean_to_and_of_ors
)
{
flags
=
RewriteSimplifier
::
Extension
(
flags
|
RewriteSimplifier
::
kConvertBooleanToAndOfOrs
);
flags
=
RewriteSimplifier
::
Extension
(
flags
|
RewriteSimplifier
::
kConvertBooleanToAndOfOrs
);
}
if
(
apply_constraints_to_boolean_branches
)
{
flags
=
RewriteSimplifier
::
Extension
(
flags
|
RewriteSimplifier
::
kApplyConstraintsToBooleanBranches
);
flags
=
RewriteSimplifier
::
Extension
(
flags
|
RewriteSimplifier
::
kApplyConstraintsToBooleanBranches
);
}
return
flags
;
}
};
std
::
unordered_set
<
const
BufferNode
*>
CollectUsedBuffers
(
const
PrimFunc
&
func
)
{
std
::
unordered_set
<
const
BufferNode
*>
CollectUsedBuffers
(
const
PrimFunc
&
func
)
{
struct
Visitor
:
StmtExprVisitor
{
using
StmtExprVisitor
::
VisitExpr_
;
using
StmtExprVisitor
::
VisitStmt_
;
Visitor
(
PrimFunc
func
)
:
func
(
func
)
{}
void
VisitExpr_
(
const
CallNode
*
op
)
override
{
for
(
const
auto
&
arg
:
op
->
args
)
{
for
(
const
auto
&
it
:
func
->
buffer_map
)
{
if
(
Downcast
<
PrimExpr
>
(
it
.
second
.
get
()
->
data
).
same_as
(
arg
))
{
used_in_buffer_def_
.
insert
(
it
.
second
.
get
());
}
}
void
VisitExpr_
(
const
CallNode
*
op
)
override
{
for
(
const
auto
&
arg
:
op
->
args
)
{
for
(
const
auto
&
it
:
func
->
buffer_map
)
{
if
(
Downcast
<
PrimExpr
>
(
it
.
second
.
get
()
->
data
).
same_as
(
arg
))
{
used_in_buffer_def_
.
insert
(
it
.
second
.
get
());
}
}
StmtExprVisitor
::
VisitExpr_
(
op
);
}
StmtExprVisitor
::
VisitExpr_
(
op
);
}
void
VisitExpr_
(
const
BufferLoadNode
*
op
)
override
{
void
VisitExpr_
(
const
BufferLoadNode
*
op
)
override
{
VisitBuffer
(
op
->
buffer
);
StmtExprVisitor
::
VisitExpr_
(
op
);
}
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
override
{
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
override
{
VisitBuffer
(
op
->
buffer
);
StmtExprVisitor
::
VisitStmt_
(
op
);
}
void
VisitStmt_
(
const
BlockNode
*
op
)
override
{
for
(
const
auto
&
buffer
:
op
->
alloc_buffers
)
{
for
(
const
auto
&
it
:
func
->
buffer_map
)
{
if
(
it
.
second
.
get
()
->
data
.
same_as
(
buffer
.
get
()
->
data
))
{
used_in_buffer_def_
.
insert
(
it
.
second
.
get
());
}
}
void
VisitStmt_
(
const
BlockNode
*
op
)
override
{
for
(
const
auto
&
buffer
:
op
->
alloc_buffers
)
{
for
(
const
auto
&
it
:
func
->
buffer_map
)
{
if
(
it
.
second
.
get
()
->
data
.
same_as
(
buffer
.
get
()
->
data
))
{
used_in_buffer_def_
.
insert
(
it
.
second
.
get
());
}
}
for
(
const
auto
&
buffer
:
op
->
reads
)
{
for
(
const
auto
&
it
:
func
->
buffer_map
)
{
if
(
it
.
second
.
get
()
->
data
.
same_as
(
buffer
->
buffer
.
get
()
->
data
)
)
{
used_in_buffer_def_
.
insert
(
it
.
second
.
get
());
}
}
}
for
(
const
auto
&
buffer
:
op
->
reads
)
{
for
(
const
auto
&
it
:
func
->
buffer_map
)
{
if
(
it
.
second
.
get
()
->
data
.
same_as
(
buffer
->
buffer
.
get
()
->
data
))
{
used_in_buffer_def_
.
insert
(
it
.
second
.
get
());
}
}
for
(
const
auto
&
buffer
:
op
->
writes
)
{
for
(
const
auto
&
it
:
func
->
buffer_map
)
{
if
(
it
.
second
.
get
()
->
data
.
same_as
(
buffer
->
buffer
.
get
()
->
data
)
)
{
used_in_buffer_def_
.
insert
(
it
.
second
.
get
());
}
}
}
for
(
const
auto
&
buffer
:
op
->
writes
)
{
for
(
const
auto
&
it
:
func
->
buffer_map
)
{
if
(
it
.
second
.
get
()
->
data
.
same_as
(
buffer
->
buffer
.
get
()
->
data
))
{
used_in_buffer_def_
.
insert
(
it
.
second
.
get
());
}
}
StmtExprVisitor
::
VisitStmt_
(
op
);
}
StmtExprVisitor
::
VisitStmt_
(
op
);
}
void
VisitBuffer
(
const
Buffer
&
buf
)
{
void
VisitBuffer
(
const
Buffer
&
buf
)
{
// Collect buffers that should remain defined
VarUseDefAnalyzer
usage
(
Array
<
Var
>
{});
usage
(
buf
->
data
);
for
(
const
auto
&
dim
:
buf
->
shape
)
{
for
(
const
auto
&
dim
:
buf
->
shape
)
{
usage
(
dim
);
}
for
(
const
auto
&
dim
:
buf
->
strides
)
{
for
(
const
auto
&
dim
:
buf
->
strides
)
{
usage
(
dim
);
}
usage
(
buf
->
elem_offset
);
for
(
const
auto
&
buffer
:
usage
.
buffer_use_count_
)
{
for
(
const
auto
&
buffer
:
usage
.
buffer_use_count_
)
{
if
(
buffer
.
second
>=
1
)
{
used_in_buffer_def_
.
insert
(
buffer
.
first
);
used_in_buffer_def_
.
insert
(
buffer
.
first
);
}
}
for
(
const
auto
&
buffer
:
usage
.
undefined_buffers_
)
{
for
(
const
auto
&
buffer
:
usage
.
undefined_buffers_
)
{
used_in_buffer_def_
.
insert
(
buffer
.
get
());
}
}
PrimFunc
func
;
std
::
unordered_set
<
const
BufferNode
*>
used_in_buffer_def_
;
std
::
unordered_set
<
const
BufferNode
*>
used_in_buffer_def_
;
};
Visitor
visitor
(
func
);
...
...
@@ -154,41 +155,42 @@ std::unordered_set<const BufferNode*> CollectUsedBuffers(const PrimFunc& func) {
return
visitor
.
used_in_buffer_def_
;
}
/* \brief Utility function to collect vars that should be retained. Used in Letstmt Only
*/
std
::
unordered_set
<
const
VarNode
*>
CollectVarsUsedInBufferDefinition
(
const
Stmt
&
stmt
)
{
/* \brief Utility function to collect vars that should be retained. Used in
* Letstmt Only
*/
std
::
unordered_set
<
const
VarNode
*>
CollectVarsUsedInBufferDefinition
(
const
Stmt
&
stmt
)
{
struct
Visitor
:
StmtExprVisitor
{
using
StmtExprVisitor
::
VisitExpr_
;
using
StmtExprVisitor
::
VisitStmt_
;
void
VisitExpr_
(
const
BufferLoadNode
*
op
)
override
{
void
VisitExpr_
(
const
BufferLoadNode
*
op
)
override
{
VisitBuffer
(
op
->
buffer
);
StmtExprVisitor
::
VisitExpr_
(
op
);
}
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
override
{
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
override
{
VisitBuffer
(
op
->
buffer
);
StmtExprVisitor
::
VisitStmt_
(
op
);
}
void
VisitBuffer
(
const
Buffer
&
buf
)
{
void
VisitBuffer
(
const
Buffer
&
buf
)
{
// Collect variables that should remain defined
VarUseDefAnalyzer
usage
(
Array
<
Var
>
{});
usage
(
buf
->
data
);
for
(
const
auto
&
dim
:
buf
->
shape
)
{
for
(
const
auto
&
dim
:
buf
->
shape
)
{
usage
(
dim
);
}
for
(
const
auto
&
dim
:
buf
->
strides
)
{
for
(
const
auto
&
dim
:
buf
->
strides
)
{
usage
(
dim
);
}
usage
(
buf
->
elem_offset
);
// Track for use in LetStmtNode mutator
for
(
const
auto
&
var
:
usage
.
undefined_
)
{
for
(
const
auto
&
var
:
usage
.
undefined_
)
{
used_in_buffer_def_
.
insert
(
var
.
get
());
}
}
std
::
unordered_set
<
const
VarNode
*>
used_in_buffer_def_
;
std
::
unordered_set
<
const
VarNode
*>
used_in_buffer_def_
;
};
Visitor
visitor
;
...
...
@@ -197,20 +199,21 @@ std::unordered_set<const VarNode*> CollectVarsUsedInBufferDefinition(const Stmt&
}
class
SimplifyConfig
:
public
Attrs
{
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS
(
SimplifyConfig
,
Attrs
,
SimplifyConfigNode
);
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS
(
SimplifyConfig
,
Attrs
,
SimplifyConfigNode
);
};
TVM_REGISTER_NODE_TYPE
(
SimplifyConfigNode
);
TVM_REGISTER_PASS_CONFIG_OPTION
(
"tl.Simplify"
,
SimplifyConfig
);
class
StmtSimplifier
:
public
IRMutatorWithAnalyzer
{
public:
static
PrimFunc
Apply
(
PrimFunc
func
,
Analyzer
*
analyzer
,
public:
static
PrimFunc
Apply
(
PrimFunc
func
,
Analyzer
*
analyzer
,
Optional
<
SimplifyConfig
>
config_opt
=
NullOpt
)
{
auto
config
=
config_opt
.
value_or
(
AttrsWithDefaultValues
<
SimplifyConfig
>
());
analyzer
->
rewrite_simplify
.
SetEnabledExtensions
(
config
->
GetEnabledExtensions
());
analyzer
->
rewrite_simplify
.
SetEnabledExtensions
(
config
->
GetEnabledExtensions
());
std
::
optional
<
ControlFlowGraph
>
touch_pattern
=
std
::
nullopt
;
if
(
config
->
propagate_knowns_to_prove_conditional
||
...
...
@@ -218,7 +221,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
touch_pattern
=
ControlFlowGraph
(
func
->
body
);
}
std
::
unordered_set
<
const
VarNode
*>
used_in_buffer_def
=
std
::
unordered_set
<
const
VarNode
*>
used_in_buffer_def
=
CollectVarsUsedInBufferDefinition
(
func
->
body
);
StmtSimplifier
simplifier
(
analyzer
,
config
,
std
::
move
(
touch_pattern
),
std
::
move
(
used_in_buffer_def
));
...
...
@@ -232,41 +235,44 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
Array
<
Var
>
new_params
;
Map
<
Var
,
Buffer
>
new_buffer_map
;
// Check whether each buffer is used
for
(
const
auto
&
var
:
func
->
params
)
{
if
(
func
->
buffer_map
.
find
(
var
)
!=
func
->
buffer_map
.
end
())
{
if
(
simplifier
.
used_buffers_
.
find
(
func
->
buffer_map
[
var
].
get
())
!=
simplifier
.
used_buffers_
.
end
())
{
new_params
.
push_back
(
var
);
new_buffer_map
.
Set
(
var
,
func
->
buffer_map
[
var
]
);
}
else
{
param_updated
=
true
;
}
for
(
const
auto
&
var
:
func
->
params
)
{
if
(
func
->
buffer_map
.
find
(
var
)
!=
func
->
buffer_map
.
end
())
{
if
(
simplifier
.
used_buffers_
.
find
(
func
->
buffer_map
[
var
].
get
())
!=
simplifier
.
used_buffers_
.
end
())
{
new_params
.
push_back
(
var
);
new_buffer_map
.
Set
(
var
,
func
->
buffer_map
[
var
]);
}
else
{
param_updated
=
true
;
}
}
}
// return func;
if
(
param_updated
)
{
return
PrimFunc
(
new_params
,
func
.
CopyOnWrite
()
->
body
,
func
->
ret_type
,
new_buffer_map
,
func
->
attrs
,
func
->
span
);
return
PrimFunc
(
new_params
,
func
.
CopyOnWrite
()
->
body
,
func
->
ret_type
,
new_buffer_map
,
func
->
attrs
,
func
->
span
);
}
else
{
return
func
;
return
func
;
}
}
private:
explicit
StmtSimplifier
(
Analyzer
*
analyzer
,
SimplifyConfig
config
,
std
::
optional
<
ControlFlowGraph
>
touch_pattern
,
std
::
unordered_set
<
const
VarNode
*>
used_in_buffer_def
)
:
IRMutatorWithAnalyzer
(
analyzer
),
config_
(
config
),
touch_pattern_
(
touch_pattern
),
used_in_buffer_def_
(
used_in_buffer_def
)
{
}
private:
explicit
StmtSimplifier
(
Analyzer
*
analyzer
,
SimplifyConfig
config
,
std
::
optional
<
ControlFlowGraph
>
touch_pattern
,
std
::
unordered_set
<
const
VarNode
*>
used_in_buffer_def
)
:
IRMutatorWithAnalyzer
(
analyzer
),
config_
(
config
),
touch_pattern_
(
touch_pattern
),
used_in_buffer_def_
(
used_in_buffer_def
)
{
}
using
Parent
=
IRMutatorWithAnalyzer
;
using
Parent
::
VisitExpr_
;
using
Parent
::
VisitStmt
;
using
Parent
::
VisitStmt_
;
PrimExpr
VisitExpr
(
const
PrimExpr
&
expr
)
final
{
PrimExpr
VisitExpr
(
const
PrimExpr
&
expr
)
final
{
if
(
config_
->
propagate_knowns_to_simplify_expressions
)
{
return
touch_pattern_
->
SimplifyInContext
(
expr
,
current_stmt_
.
value
(),
analyzer_
);
return
touch_pattern_
->
SimplifyInContext
(
expr
,
current_stmt_
.
value
(),
analyzer_
);
}
else
{
return
analyzer_
->
Simplify
(
expr
);
}
...
...
@@ -274,7 +280,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
Stmt
Simplify
(
Stmt
stmt
)
{
return
operator
()(
std
::
move
(
stmt
));
}
Stmt
VisitStmt
(
const
Stmt
&
stmt
)
override
{
Stmt
VisitStmt
(
const
Stmt
&
stmt
)
override
{
Optional
<
Stmt
>
cache
=
this
->
current_stmt_
;
this
->
current_stmt_
=
stmt
;
Stmt
output
=
Parent
::
VisitStmt
(
stmt
);
...
...
@@ -282,23 +288,28 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return
output
;
}
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
analyzer_
->
Bind
(
op
->
loop_var
,
Range
::
FromMinExtent
(
op
->
min
,
op
->
extent
));
With
<
ConstraintContext
>
ctx1
(
analyzer_
,
op
->
loop_var
>=
op
->
min
);
With
<
ConstraintContext
>
ctx2
(
analyzer_
,
op
->
loop_var
<
op
->
min
+
op
->
extent
);
With
<
ConstraintContext
>
ctx2
(
analyzer_
,
op
->
loop_var
<
op
->
min
+
op
->
extent
);
return
Parent
::
VisitStmt_
(
op
);
}
bool
CanInlineLetStmt
(
const
LetStmtNode
*
op
)
{
if
(
is_const_number
(
op
->
value
))
return
true
;
if
(
op
->
value
.
as
<
VarNode
>
())
return
true
;
bool
CanInlineLetStmt
(
const
LetStmtNode
*
op
)
{
if
(
is_const_number
(
op
->
value
))
return
true
;
if
(
op
->
value
.
as
<
VarNode
>
())
return
true
;
// Won't face the deep expression explosion problem as in Let expression.
// attempt to inline as much as possible if the value integer type(can be index).
if
(
!
op
->
value
.
dtype
().
is_int
())
return
false
;
// attempt to inline as much as possible if the value integer type(can be
// index).
if
(
!
op
->
value
.
dtype
().
is_int
())
return
false
;
return
SideEffect
(
op
->
value
)
<=
CallEffectKind
::
kPure
;
}
Stmt
VisitStmt_
(
const
LetStmtNode
*
op
)
override
{
Stmt
VisitStmt_
(
const
LetStmtNode
*
op
)
override
{
PrimExpr
value
=
this
->
VisitExpr
(
op
->
value
);
bool
can_inline
=
CanInlineLetStmt
(
op
);
if
(
can_inline
)
{
...
...
@@ -339,7 +350,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
}
}
Stmt
VisitStmt_
(
const
IfThenElseNode
*
op
)
override
{
Stmt
VisitStmt_
(
const
IfThenElseNode
*
op
)
override
{
if
(
Optional
<
Bool
>
cond
=
ProveCondition
(
op
->
condition
))
{
if
(
cond
.
value
()
->
value
)
{
return
this
->
VisitStmt
(
op
->
then_case
);
...
...
@@ -353,7 +364,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
}
}
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
override
{
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
override
{
if
(
op
->
op
.
same_as
(
builtin
::
if_then_else
()))
{
if
(
Optional
<
Bool
>
cond
=
ProveCondition
(
op
->
args
[
0
]))
{
if
(
cond
.
value
()
->
value
)
{
...
...
@@ -366,26 +377,27 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return
Parent
::
VisitExpr_
(
op
);
}
PrimExpr
VisitExpr_
(
const
VarNode
*
op
)
override
{
PrimExpr
VisitExpr_
(
const
VarNode
*
op
)
override
{
used_vars_
.
insert
(
op
);
return
Parent
::
VisitExpr_
(
op
);
}
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
override
{
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
override
{
auto
buffer
=
op
->
buffer
.
get
();
if
(
used_buffers_
.
find
(
buffer
)
==
used_buffers_
.
end
())
{
used_buffers_
.
insert
(
buffer
);
used_buffers_
.
insert
(
buffer
);
}
return
Parent
::
VisitExpr_
(
op
);
}
// eliminate useless stores
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
override
{
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
override
{
BufferStore
store
=
Downcast
<
BufferStore
>
(
Parent
::
VisitStmt_
(
op
));
if
(
const
BufferLoadNode
*
load
=
store
->
value
.
as
<
BufferLoadNode
>
())
{
if
(
const
BufferLoadNode
*
load
=
store
->
value
.
as
<
BufferLoadNode
>
())
{
if
(
load
->
buffer
->
data
.
same_as
(
store
->
buffer
->
data
)
&&
ArrayDeepEqual
(
load
->
indices
,
store
->
indices
)
&&
tir
::
ExprDeepEqual
()(
load
->
buffer
->
elem_offset
,
store
->
buffer
->
elem_offset
)
&&
tir
::
ExprDeepEqual
()(
load
->
buffer
->
elem_offset
,
store
->
buffer
->
elem_offset
)
&&
ArrayDeepEqual
(
load
->
buffer
->
shape
,
store
->
buffer
->
shape
)
&&
ArrayDeepEqual
(
load
->
buffer
->
strides
,
store
->
buffer
->
strides
))
{
return
Evaluate
(
0
);
...
...
@@ -393,13 +405,13 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
}
auto
buffer
=
op
->
buffer
.
get
();
if
(
used_buffers_
.
find
(
buffer
)
==
used_buffers_
.
end
())
{
used_buffers_
.
insert
(
buffer
);
used_buffers_
.
insert
(
buffer
);
}
return
std
::
move
(
store
);
}
private:
bool
ArrayDeepEqual
(
const
Array
<
PrimExpr
>
&
lhs
,
const
Array
<
PrimExpr
>
&
rhs
)
{
private:
bool
ArrayDeepEqual
(
const
Array
<
PrimExpr
>
&
lhs
,
const
Array
<
PrimExpr
>
&
rhs
)
{
if
(
lhs
.
size
()
!=
rhs
.
size
())
{
return
false
;
}
...
...
@@ -420,11 +432,12 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
condition
=
Substitute
(
condition
,
non_inlined_bindings_
);
if
(
config_
->
propagate_knowns_to_prove_conditional
)
{
ICHECK
(
touch_pattern_
.
has_value
());
condition
=
touch_pattern_
->
SimplifyInContext
(
condition
,
current_stmt_
.
value
(),
analyzer_
);
condition
=
touch_pattern_
->
SimplifyInContext
(
condition
,
current_stmt_
.
value
(),
analyzer_
);
}
else
{
condition
=
analyzer_
->
Simplify
(
condition
);
}
if
(
const
int64_t
*
as_int
=
as_const_int
(
condition
))
{
if
(
const
int64_t
*
as_int
=
as_const_int
(
condition
))
{
return
Bool
(
*
as_int
);
}
else
{
return
NullOpt
;
...
...
@@ -436,21 +449,20 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
Map
<
Var
,
PrimExpr
>
non_inlined_bindings_
;
Optional
<
Stmt
>
current_stmt_
{
NullOpt
};
std
::
unordered_set
<
const
VarNode
*>
used_in_buffer_def_
;
std
::
unordered_set
<
const
VarNode
*>
used_vars_
;
std
::
unordered_set
<
const
BufferNode
*>
used_buffers_
;
std
::
unordered_set
<
const
VarNode
*>
used_in_buffer_def_
;
std
::
unordered_set
<
const
VarNode
*>
used_vars_
;
std
::
unordered_set
<
const
BufferNode
*>
used_buffers_
;
};
using
namespace
tir
::
transform
;
tvm
::
transform
::
Pass
Simplify
()
{
auto
pass_func
=
[
=
](
PrimFunc
f
,
IRModule
m
,
PassContext
ctx
)
{
arith
::
Analyzer
analyzer
;
auto
cfg
=
ctx
->
GetConfig
<
SimplifyConfig
>
(
"tl.Simplify"
);
return
StmtSimplifier
::
Apply
(
f
,
&
analyzer
,
cfg
);
};
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.Simplify"
,
{});
auto
pass_func
=
[
=
](
PrimFunc
f
,
IRModule
m
,
PassContext
ctx
)
{
arith
::
Analyzer
analyzer
;
auto
cfg
=
ctx
->
GetConfig
<
SimplifyConfig
>
(
"tl.Simplify"
);
return
StmtSimplifier
::
Apply
(
f
,
&
analyzer
,
cfg
);
};
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.Simplify"
,
{});
}
TVM_REGISTER_GLOBAL
(
"tl.transform.Simplify"
).
set_body_typed
(
Simplify
);
...
...
src/transform/thread_partial_sync.cc
View file @
549416f7
...
...
@@ -25,26 +25,28 @@ namespace tl {
using
namespace
tir
;
class
ThreadPartialSyncPlanner
:
public
StorageAccessVisitor
{
public:
explicit
ThreadPartialSyncPlanner
(
StorageScope
sync_scope
)
:
sync_scope_
(
sync_scope
)
{}
public:
explicit
ThreadPartialSyncPlanner
(
StorageScope
sync_scope
)
:
sync_scope_
(
sync_scope
)
{}
// The syncs inserted before each statement
std
::
unordered_set
<
const
Object
*>
syncs_inserted_
;
std
::
unordered_map
<
const
Object
*
,
int
>
partial_syncs_inserted_
;
std
::
unordered_set
<
const
Object
*>
syncs_inserted_
;
std
::
unordered_map
<
const
Object
*
,
int
>
partial_syncs_inserted_
;
protected:
bool
Enabled
(
const
VarNode
*
buf
,
const
StorageScope
&
scope
)
const
final
{
protected:
bool
Enabled
(
const
VarNode
*
buf
,
const
StorageScope
&
scope
)
const
final
{
return
in_device_env
()
&&
scope
==
sync_scope_
;
}
// Plan the sync
std
::
vector
<
AccessEntry
>
Summarize
(
std
::
vector
<
StmtEntry
>
seq
,
const
ForNode
*
loop
)
final
{
std
::
vector
<
AccessEntry
>
Summarize
(
std
::
vector
<
StmtEntry
>
seq
,
const
ForNode
*
loop
)
final
{
// Redirect all "shared.dyn" buffer access to the same buffer var
// so that the accesses can be planned together.
Var
shared_dyn_buf
;
for
(
StmtEntry
&
entry
:
seq
)
{
for
(
AccessEntry
&
access
:
entry
.
access
)
{
if
(
access
.
scope
.
rank
==
StorageRank
::
kShared
&&
access
.
scope
.
tag
==
".dyn"
&&
access
.
buffer
.
defined
())
{
for
(
StmtEntry
&
entry
:
seq
)
{
for
(
AccessEntry
&
access
:
entry
.
access
)
{
if
(
access
.
scope
.
rank
==
StorageRank
::
kShared
&&
access
.
scope
.
tag
==
".dyn"
&&
access
.
buffer
.
defined
())
{
if
(
!
shared_dyn_buf
.
defined
())
{
shared_dyn_buf
=
access
.
buffer
;
}
else
{
...
...
@@ -60,7 +62,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
// if it is a loop, rotate two times to consider effect of loop.
// simulation based approach to find dependencies
for
(
size_t
i
=
0
;
i
<
seq
.
size
();
++
i
)
{
const
StmtEntry
&
s
=
seq
[
i
];
const
StmtEntry
&
s
=
seq
[
i
];
// check if sync before statement is needed.
bool
sync_before_stmt
=
(
syncs_inserted_
.
count
(
s
.
stmt
)
!=
0
);
// Apply the syncs added already.
...
...
@@ -68,7 +70,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
reads
.
clear
();
writes
.
clear
();
}
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
if
(
acc
.
type
==
kRead
)
{
if
(
FindConflict
(
writes
,
acc
,
false
))
{
sync_before_stmt
=
true
;
...
...
@@ -90,7 +92,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
writes
.
clear
();
}
// Add the read/write of current statement
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
if
(
acc
.
type
==
kRead
)
{
reads
.
push_back
(
acc
);
}
else
if
(
acc
.
type
==
kWrite
)
{
...
...
@@ -106,11 +108,13 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
}
if
(
loop
!=
nullptr
)
{
for
(
size_t
i
=
0
;
i
<
seq
.
size
();
++
i
)
{
const
StmtEntry
&
s
=
seq
[
i
];
if
(
syncs_inserted_
.
count
(
s
.
stmt
)
!=
0
)
break
;
if
(
reads
.
empty
()
&&
writes
.
empty
())
break
;
const
StmtEntry
&
s
=
seq
[
i
];
if
(
syncs_inserted_
.
count
(
s
.
stmt
)
!=
0
)
break
;
if
(
reads
.
empty
()
&&
writes
.
empty
())
break
;
bool
sync_before_stmt
=
false
;
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
if
(
acc
.
type
==
kRead
)
{
if
(
FindConflict
(
writes
,
acc
,
true
))
{
sync_before_stmt
=
true
;
...
...
@@ -141,7 +145,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
esync
.
type
=
kSync
;
esync
.
scope
=
sync_scope_
;
for
(
const
StmtEntry
&
s
:
seq
)
{
for
(
const
StmtEntry
&
s
:
seq
)
{
if
(
syncs_inserted_
.
count
(
s
.
stmt
))
{
if
(
sync_count
!=
0
)
{
tail
.
clear
();
...
...
@@ -150,7 +154,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
}
++
sync_count
;
}
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
if
(
acc
.
type
==
kSync
)
{
if
(
sync_count
!=
0
)
{
tail
.
clear
();
...
...
@@ -170,18 +174,18 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
head
.
insert
(
head
.
end
(),
tail
.
begin
(),
tail
.
end
());
if
(
loop
!=
nullptr
)
{
// clear double buffer flag after a loop is finished.
for
(
AccessEntry
&
e
:
head
)
{
for
(
AccessEntry
&
e
:
head
)
{
e
.
double_buffer_write
=
false
;
}
}
return
head
;
}
private:
private:
// find conflicting entry in vec.
bool
FindConflict
(
const
std
::
vector
<
AccessEntry
>
&
prev
,
const
AccessEntry
&
curr
,
bool
loop_carry
)
{
for
(
const
AccessEntry
&
x
:
prev
)
{
bool
FindConflict
(
const
std
::
vector
<
AccessEntry
>
&
prev
,
const
AccessEntry
&
curr
,
bool
loop_carry
)
{
for
(
const
AccessEntry
&
x
:
prev
)
{
if
(
FindConflict
(
x
,
curr
,
loop_carry
))
{
return
true
;
}
...
...
@@ -189,7 +193,8 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
return
false
;
}
bool
FindConflict
(
const
AccessEntry
&
prev
,
const
AccessEntry
&
curr
,
bool
loop_carry
)
{
bool
FindConflict
(
const
AccessEntry
&
prev
,
const
AccessEntry
&
curr
,
bool
loop_carry
)
{
// Access to different buffers does not conflict.
if
(
!
prev
.
buffer
.
same_as
(
curr
.
buffer
))
{
return
false
;
...
...
@@ -202,21 +207,21 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
// Even if access has the same index, those indices need to
// depend on the innermost thread id to avoid race condition
bool
depends_on_thread_index
=
true
;
const
VarNode
*
thread_index_var
=
nullptr
;
const
VarNode
*
thread_index_var
=
nullptr
;
if
(
!
curr
.
threads
.
empty
())
{
thread_index_var
=
curr
.
threads
.
back
()
->
var
.
get
();
}
for
(
size_t
i
=
0
;
i
<
prev
.
touched
.
size
();
i
++
)
{
const
auto
&
prev_intset
=
prev
.
touched
[
i
];
const
auto
&
curr_intset
=
curr
.
touched
[
i
];
const
auto
&
prev_intset
=
prev
.
touched
[
i
];
const
auto
&
curr_intset
=
curr
.
touched
[
i
];
if
(
prev_intset
.
IsSinglePoint
()
&&
curr_intset
.
IsSinglePoint
())
{
PrimExpr
prev_index
=
prev_intset
.
PointValue
();
PrimExpr
curr_index
=
curr_intset
.
PointValue
();
has_same_index
=
ExprDeepEqual
()(
prev_index
,
curr_index
);
if
(
thread_index_var
!=
nullptr
)
{
auto
f_uses_thread_index
=
[
=
](
const
tvm
::
tir
::
VarNode
*
parameter
)
{
auto
f_uses_thread_index
=
[
=
](
const
tvm
::
tir
::
VarNode
*
parameter
)
{
return
parameter
==
thread_index_var
;
};
depends_on_thread_index
=
depends_on_thread_index
&&
...
...
@@ -246,7 +251,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
return
true
;
}
void
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
void
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
if
(
op
->
attr_key
==
"kWarpSpecializationScope"
)
{
IfThenElse
body
=
Downcast
<
IfThenElse
>
(
op
->
body
);
auto
partitions
=
Downcast
<
Array
<
IntImm
>>
(
op
->
node
);
...
...
@@ -273,27 +278,31 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
}
}
void
insert_syncs
(
const
Object
*
obj
)
{
// ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition";
if
(
syncs_inserted_
.
count
(
obj
))
return
;
void
insert_syncs
(
const
Object
*
obj
)
{
// ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside
// condition";
if
(
syncs_inserted_
.
count
(
obj
))
return
;
if
(
num_partial_threads_
.
defined
())
{
syncs_inserted_
.
insert
(
obj
);
partial_syncs_inserted_
[
obj
]
=
static_cast
<
int
>
(
num_partial_threads_
.
value
()
->
value
);
partial_syncs_inserted_
[
obj
]
=
static_cast
<
int
>
(
num_partial_threads_
.
value
()
->
value
);
}
else
{
syncs_inserted_
.
insert
(
obj
);
}
}
private:
private:
Optional
<
IntImm
>
num_partial_threads_
;
// synchronization scope
StorageScope
sync_scope_
;
};
// There are cases where necessary syncthreads is not inserted by ThreadPartialSyncInserter.
// For example, syncthreads is needed after async_wait_queue in the second loop below,
// but since ThreadPartialSyncInserter is not aware of the asynchronous semantics, it cannot tell
// that the syncthreads is needed there.
// There are cases where necessary syncthreads is not inserted by
// ThreadPartialSyncInserter. For example, syncthreads is needed after
// async_wait_queue in the second loop below, but since
// ThreadPartialSyncInserter is not aware of the asynchronous semantics, it
// cannot tell that the syncthreads is needed there.
//
// // Pipeline prologue
// for i in range(125):
...
...
@@ -307,21 +316,23 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
// async_wait_queue(0, 2 - i):
// local[...] = shared[(i + 125) % 4]
class
ThreadPartialSyncInserter
:
public
StmtExprMutator
{
public:
ThreadPartialSyncInserter
(
StorageScope
sync_scope
,
const
std
::
unordered_set
<
const
Object
*>&
syncs
,
std
::
unordered_map
<
const
Object
*
,
int
>
partial_syncs
)
public:
ThreadPartialSyncInserter
(
StorageScope
sync_scope
,
const
std
::
unordered_set
<
const
Object
*>
&
syncs
,
std
::
unordered_map
<
const
Object
*
,
int
>
partial_syncs
)
:
sync_scope_
(
sync_scope
),
syncs_
(
syncs
),
partial_syncs_
(
partial_syncs
)
{}
Stmt
VisitStmt
(
const
Stmt
&
stmt
)
final
{
if
(
syncs_
.
size
()
==
0
)
return
stmt
;
Stmt
VisitStmt
(
const
Stmt
&
stmt
)
final
{
if
(
syncs_
.
size
()
==
0
)
return
stmt
;
if
(
syncs_
.
count
(
stmt
.
get
()))
{
Stmt
barrier
;
if
(
partial_syncs_
.
count
(
stmt
.
get
()))
{
auto
iter
=
partial_syncs_
.
find
(
stmt
.
get
());
ICHECK
(
sync_scope_
.
rank
==
StorageRank
::
kShared
);
barrier
=
Evaluate
(
Call
(
DataType
::
Int
(
32
),
tl
::
SyncThreadsPartialOp
(),
{
iter
->
second
}));
barrier
=
Evaluate
(
Call
(
DataType
::
Int
(
32
),
tl
::
SyncThreadsPartialOp
(),
{
iter
->
second
}));
}
else
{
return
StmtExprMutator
::
VisitStmt
(
stmt
);
}
...
...
@@ -334,11 +345,11 @@ class ThreadPartialSyncInserter : public StmtExprMutator {
}
}
private:
private:
// data structure.
StorageScope
sync_scope_
;
const
std
::
unordered_set
<
const
Object
*>
&
syncs_
;
const
std
::
unordered_map
<
const
Object
*
,
int
>
&
partial_syncs_
;
const
std
::
unordered_set
<
const
Object
*>
&
syncs_
;
const
std
::
unordered_map
<
const
Object
*
,
int
>
&
partial_syncs_
;
};
Stmt
ThreadPartialSync
(
Stmt
stmt
,
std
::
string
storage_scope
)
{
...
...
@@ -346,7 +357,8 @@ Stmt ThreadPartialSync(Stmt stmt, std::string storage_scope) {
ThreadPartialSyncPlanner
planner
(
sync_scope
);
planner
(
stmt
);
return
ThreadPartialSyncInserter
(
sync_scope
,
planner
.
syncs_inserted_
,
planner
.
partial_syncs_inserted_
)(
std
::
move
(
stmt
));
planner
.
partial_syncs_inserted_
)(
std
::
move
(
stmt
));
}
using
namespace
tir
::
transform
;
...
...
@@ -355,15 +367,16 @@ namespace transform {
Pass
ThreadPartialSync
(
String
storage_scope
)
{
auto
pass_func
=
[
storage_scope
](
PrimFunc
f
,
IRModule
m
,
PassContext
ctx
)
{
auto
*
n
=
f
.
CopyOnWrite
();
auto
*
n
=
f
.
CopyOnWrite
();
n
->
body
=
tl
::
ThreadPartialSync
(
std
::
move
(
n
->
body
),
storage_scope
);
return
f
;
};
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.ThreadPartialSync"
,
{});
}
TVM_REGISTER_GLOBAL
(
"tl.transform.ThreadPartialSync"
).
set_body_typed
(
ThreadPartialSync
);
TVM_REGISTER_GLOBAL
(
"tl.transform.ThreadPartialSync"
)
.
set_body_typed
(
ThreadPartialSync
);
}
// namespace transform
}
// namespace t
ir
}
// namespace tvm
}
// namespace transform
}
// namespace t
l
}
// namespace tvm
src/transform/warp_specialized_rewriter.cc
View file @
549416f7
...
...
@@ -38,22 +38,23 @@ using namespace tir;
enum
class
Role
{
kConsumer
,
kProducer
,
kBoth
};
class
WarpSpecializedRoleMarker
:
public
StmtVisitor
{
public:
public:
WarpSpecializedRoleMarker
(
Map
<
Var
,
Buffer
>
buffer_data_to_buffer
)
:
buffer_data_to_buffer_
(
buffer_data_to_buffer
)
{}
Role
GetRole
(
const
StmtNode
*
stmt
)
const
{
Role
GetRole
(
const
StmtNode
*
stmt
)
const
{
auto
it
=
map_
.
find
(
stmt
);
ICHECK
(
it
!=
map_
.
end
());
return
it
->
second
;
}
Role
GetRole
(
const
Stmt
&
stmt
)
const
{
return
GetRole
(
stmt
.
get
());
}
Role
GetRole
(
const
Stmt
&
stmt
)
const
{
return
GetRole
(
stmt
.
get
());
}
void
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
void
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
Role
role
=
Role
::
kConsumer
;
if
(
auto
call
=
op
->
value
.
as
<
CallNode
>
())
{
if
(
call
->
op
.
same_as
(
TMALoadOp
())
||
call
->
op
.
same_as
(
TMALoadIm2ColOp
()))
{
if
(
call
->
op
.
same_as
(
TMALoadOp
())
||
call
->
op
.
same_as
(
TMALoadIm2ColOp
()))
{
role
=
Role
::
kProducer
;
has_bulk_copy_
=
true
;
}
...
...
@@ -61,8 +62,9 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
SetRole
(
op
,
role
);
}
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
bool
is_shared_store
=
op
->
buffer
.
scope
()
==
"shared.dyn"
||
op
->
buffer
.
scope
()
==
"shared"
;
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
bool
is_shared_store
=
op
->
buffer
.
scope
()
==
"shared.dyn"
||
op
->
buffer
.
scope
()
==
"shared"
;
if
(
!
is_shared_store
)
{
SetRole
(
op
,
Role
::
kConsumer
);
return
;
...
...
@@ -80,11 +82,12 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
break
;
}
}
if
(
role
==
Role
::
kProducer
)
has_simt_copy_
=
true
;
if
(
role
==
Role
::
kProducer
)
has_simt_copy_
=
true
;
SetRole
(
op
,
role
);
}
void
VisitStmt_
(
const
SeqStmtNode
*
op
)
final
{
void
VisitStmt_
(
const
SeqStmtNode
*
op
)
final
{
StmtVisitor
::
VisitStmt_
(
op
);
auto
role
=
GetRole
(
op
->
seq
[
0
]);
for
(
auto
stmt
:
op
->
seq
)
{
...
...
@@ -96,41 +99,41 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
SetRole
(
op
,
role
);
}
void
VisitStmt_
(
const
IfThenElseNode
*
op
)
final
{
void
VisitStmt_
(
const
IfThenElseNode
*
op
)
final
{
StmtVisitor
::
VisitStmt_
(
op
);
auto
role
=
GetRole
(
op
->
then_case
);
if
(
op
->
else_case
.
defined
())
{
auto
role_else
=
GetRole
(
op
->
else_case
.
value
());
if
(
role
!=
role_else
)
role
=
Role
::
kBoth
;
if
(
role
!=
role_else
)
role
=
Role
::
kBoth
;
}
SetRole
(
op
,
role
);
}
void
VisitStmt_
(
const
BlockRealizeNode
*
op
)
final
{
void
VisitStmt_
(
const
BlockRealizeNode
*
op
)
final
{
StmtVisitor
::
VisitStmt_
(
op
);
SetRole
(
op
,
GetRole
(
op
->
block
));
}
template
<
class
NodeType
>
void
HandleBodyStmt
(
const
NodeType
*
op
)
{
template
<
class
NodeType
>
void
HandleBodyStmt
(
const
NodeType
*
op
)
{
StmtVisitor
::
VisitStmt_
(
op
);
SetRole
(
op
,
GetRole
(
op
->
body
));
}
void
VisitStmt_
(
const
ForNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
LetStmtNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
AssertStmtNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
BlockNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
ForNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
LetStmtNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
AssertStmtNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
BlockNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
bool
HasProducer
()
{
return
has_simt_copy_
||
has_bulk_copy_
;
}
bool
HasSimtCopy
()
{
return
has_simt_copy_
;
}
private:
void
SetRole
(
const
StmtNode
*
stmt
,
Role
role
)
{
map_
[
stmt
]
=
role
;
}
private:
void
SetRole
(
const
StmtNode
*
stmt
,
Role
role
)
{
map_
[
stmt
]
=
role
;
}
Map
<
Var
,
Buffer
>
buffer_data_to_buffer_
;
std
::
unordered_map
<
const
StmtNode
*
,
Role
>
map_
;
std
::
unordered_map
<
const
StmtNode
*
,
Role
>
map_
;
bool
has_simt_copy_
=
false
;
bool
has_bulk_copy_
=
false
;
};
...
...
@@ -140,23 +143,26 @@ static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
}
static
Stmt
makeExpectTX
(
PrimExpr
barrier_id
,
PrimExpr
bytes
)
{
auto
call
=
Call
(
DataType
::
Handle
(),
MBarrierExpectTX
(),
{
makeGetBarrier
(
barrier_id
),
bytes
});
auto
call
=
Call
(
DataType
::
Handle
(),
MBarrierExpectTX
(),
{
makeGetBarrier
(
barrier_id
),
bytes
});
return
Evaluate
(
call
);
}
static
Stmt
makeArriveBarrier
(
PrimExpr
barrier_id
)
{
auto
call
=
Call
(
DataType
::
Handle
(),
builtin
::
ptx_arrive_barrier
(),
{
makeGetBarrier
(
barrier_id
)});
auto
call
=
Call
(
DataType
::
Handle
(),
builtin
::
ptx_arrive_barrier
(),
{
makeGetBarrier
(
barrier_id
)});
return
Evaluate
(
call
);
}
static
Stmt
makeCpAsyncBarrier
(
PrimExpr
barrier_id
)
{
auto
call
=
Call
(
DataType
::
Handle
(),
builtin
::
ptx_cp_async_barrier
(),
{
makeGetBarrier
(
barrier_id
)});
auto
call
=
Call
(
DataType
::
Handle
(),
builtin
::
ptx_cp_async_barrier
(),
{
makeGetBarrier
(
barrier_id
)});
return
Evaluate
(
call
);
}
static
Stmt
makeParityWait
(
PrimExpr
barrier_id
,
PrimExpr
parity
)
{
auto
call
=
Call
(
DataType
::
Handle
(),
MBarrierWaitParity
(),
{
makeGetBarrier
(
barrier_id
),
parity
});
auto
call
=
Call
(
DataType
::
Handle
(),
MBarrierWaitParity
(),
{
makeGetBarrier
(
barrier_id
),
parity
});
return
Evaluate
(
call
);
}
...
...
@@ -177,7 +183,7 @@ static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
// }
class
ProducerTraitsCollector
:
public
StmtExprVisitor
{
public:
public:
ProducerTraitsCollector
()
{
Clear
();
}
void
Clear
()
{
...
...
@@ -192,8 +198,8 @@ class ProducerTraitsCollector : public StmtExprVisitor {
PrimExpr
BulkCopyBytes
()
{
return
bulk_copy_bytes
;
}
private:
void
VisitExpr_
(
const
CallNode
*
call
)
final
{
private:
void
VisitExpr_
(
const
CallNode
*
call
)
final
{
if
(
call
->
op
.
same_as
(
TMALoadOp
())
||
call
->
op
.
same_as
(
TMALoadIm2ColOp
()))
{
Call
access_ptr
=
Downcast
<
Call
>
(
call
->
args
[
2
]);
ICHECK
(
access_ptr
->
op
.
same_as
(
builtin
::
tvm_access_ptr
()));
...
...
@@ -203,14 +209,14 @@ class ProducerTraitsCollector : public StmtExprVisitor {
StmtExprVisitor
::
VisitExpr_
(
call
);
}
void
VisitStmt_
(
const
ForNode
*
op
)
final
{
void
VisitStmt_
(
const
ForNode
*
op
)
final
{
PrimExpr
old_loop_evtents
=
loop_extents
;
loop_extents
*=
op
->
extent
;
StmtExprVisitor
::
VisitStmt_
(
op
);
loop_extents
=
old_loop_evtents
;
}
void
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
void
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
has_simt_copy
=
true
;
StmtExprVisitor
::
VisitExpr_
(
op
);
}
...
...
@@ -222,15 +228,15 @@ class ProducerTraitsCollector : public StmtExprVisitor {
// Rewrite the producer Stmt to use the correct barrier index
class
MbarrierRewriter
:
public
StmtExprMutator
{
public:
public:
static
Stmt
Rewrite
(
Stmt
stmt
,
PrimExpr
barrier_id
)
{
MbarrierRewriter
rewriter
;
rewriter
.
producer_barrier_idx_
=
barrier_id
;
return
rewriter
(
stmt
);
}
private:
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
final
{
private:
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
final
{
auto
call
=
Downcast
<
Call
>
(
StmtExprMutator
::
VisitExpr_
(
op
));
if
(
call
->
op
.
same_as
(
TMALoadOp
())
||
call
->
op
.
same_as
(
TMALoadIm2ColOp
()))
{
Call
access_ptr
=
Downcast
<
Call
>
(
call
->
args
[
2
]);
...
...
@@ -242,19 +248,18 @@ class MbarrierRewriter : public StmtExprMutator {
PrimExpr
producer_barrier_idx_
;
};
class
ThreadIdxRewriter
:
public
StmtExprMutator
{
public:
public:
static
Stmt
Rewrite
(
Stmt
stmt
,
Var
thread_var
,
PrimExpr
replaced
)
{
auto
rewriter
=
ThreadIdxRewriter
(
thread_var
,
replaced
);
return
rewriter
(
stmt
);
}
private:
private:
ThreadIdxRewriter
(
Var
thread_var
,
PrimExpr
replaced
)
:
thread_var_
(
thread_var
),
replaced_
(
replaced
)
{}
PrimExpr
VisitExpr_
(
const
VarNode
*
var
)
final
{
PrimExpr
VisitExpr_
(
const
VarNode
*
var
)
final
{
if
(
var
==
thread_var_
.
get
())
{
return
replaced_
;
}
else
{
...
...
@@ -266,9 +271,12 @@ class ThreadIdxRewriter : public StmtExprMutator {
PrimExpr
replaced_
;
};
Block
MakeGroupBlock
(
const
Stmt
&
stmt
,
const
Map
<
String
,
ObjectRef
>&
annotations
)
{
Block
block
(
/*iter_vars=*/
{},
/*reads=*/
{},
/*writes=*/
{},
/*name_hint=*/
""
,
/*body*/
stmt
,
/*init=*/
{},
/*alloc_buffers=*/
{},
/*match_buffers=*/
{},
/*annotations=*/
annotations
);
Block
MakeGroupBlock
(
const
Stmt
&
stmt
,
const
Map
<
String
,
ObjectRef
>
&
annotations
)
{
Block
block
(
/*iter_vars=*/
{},
/*reads=*/
{},
/*writes=*/
{},
/*name_hint=*/
""
,
/*body*/
stmt
,
/*init=*/
{},
/*alloc_buffers=*/
{},
/*match_buffers=*/
{},
/*annotations=*/
annotations
);
return
block
;
}
...
...
@@ -280,11 +288,8 @@ struct PipelineInfo {
std
::
vector
<
OpInfo
>
op_infos
;
PipelineInfo
()
=
default
;
PipelineInfo
(
Array
<
Array
<
Integer
>>
group_info
,
Array
<
Integer
>
order_info
,
Array
<
Integer
>
stage_info
)
{
PipelineInfo
(
Array
<
Array
<
Integer
>>
group_info
,
Array
<
Integer
>
order_info
,
Array
<
Integer
>
stage_info
)
{
int
n
=
static_cast
<
int
>
(
group_info
.
size
());
ICHECK
(
n
==
static_cast
<
int
>
(
order_info
.
size
()));
ICHECK
(
n
==
static_cast
<
int
>
(
stage_info
.
size
()));
...
...
@@ -301,7 +306,7 @@ struct PipelineInfo {
}
}
PipelineInfo
(
const
PipelineInfo
&
other
)
{
PipelineInfo
(
const
PipelineInfo
&
other
)
{
for
(
auto
op_info
:
other
.
op_infos
)
{
op_infos
.
push_back
(
op_info
);
}
...
...
@@ -364,18 +369,19 @@ struct PipelineInfo {
void
PrintPipelineInfo
()
{
std
::
cout
<<
"Print op_infos:"
<<
std
::
endl
;
for
(
size_t
i
=
0
;
i
<
op_infos
.
size
();
i
++
)
{
std
::
cout
<<
i
<<
" "
<<
op_infos
[
i
].
group_size
<<
" "
<<
op_infos
[
i
].
order
<<
" "
<<
op_infos
[
i
].
stage
<<
std
::
endl
;
std
::
cout
<<
i
<<
" "
<<
op_infos
[
i
].
group_size
<<
" "
<<
op_infos
[
i
].
order
<<
" "
<<
op_infos
[
i
].
stage
<<
std
::
endl
;
}
std
::
cout
<<
"End of print"
<<
std
::
endl
;
}
};
class
GroupOpRewriter
:
public
StmtExprMutator
{
public:
public:
GroupOpRewriter
(
PipelineInfo
pipeline_info
)
:
pipeline_info_
(
pipeline_info
)
{}
private:
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
private:
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
Map
<
String
,
ObjectRef
>
annotations
;
annotations
.
Set
(
String
(
"stmt_group"
),
Integer
(
1
));
auto
original_node
=
(
op
->
body
).
as
<
SeqStmtNode
>
();
...
...
@@ -385,19 +391,24 @@ class GroupOpRewriter : public StmtExprMutator {
Array
<
Stmt
>
new_body
;
int
cur_id
=
0
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
pipeline_info_
.
op_infos
.
size
());
i
++
)
{
if
(
pipeline_info_
.
op_infos
[
i
].
group_size
==
0
)
continue
;
if
(
pipeline_info_
.
op_infos
[
i
].
group_size
==
0
)
continue
;
Array
<
Stmt
>
block_stmt
;
for
(
int
j
=
0
;
j
<
static_cast
<
int
>
(
pipeline_info_
.
op_infos
[
i
].
group_size
);
j
++
)
{
for
(
int
j
=
0
;
j
<
static_cast
<
int
>
(
pipeline_info_
.
op_infos
[
i
].
group_size
);
j
++
)
{
// ICHECK(group_info_[i][j].as<IntImmNode>());
// int index = static_cast<int>(group_info_[i][j].as<IntImmNode>()->value);
// int index =
// static_cast<int>(group_info_[i][j].as<IntImmNode>()->value);
ICHECK
(
original_node
->
seq
[
cur_id
].
as
<
BlockNode
>
());
auto
block
=
original_node
->
seq
[
cur_id
].
as
<
BlockNode
>
();
// TODO: handle nested seqstmt
block_stmt
.
push_back
(
block
->
body
);
cur_id
++
;
}
new_body
.
push_back
(
MakeGroupBlock
(
block_stmt
.
size
()
==
1
?
block_stmt
[
0
]
:
SeqStmt
(
std
::
move
(
block_stmt
)),
annotations
));
new_body
.
push_back
(
MakeGroupBlock
(
block_stmt
.
size
()
==
1
?
block_stmt
[
0
]
:
SeqStmt
(
std
::
move
(
block_stmt
)),
annotations
));
}
Array
<
Integer
>
order_anno
;
Array
<
Integer
>
stage_anno
;
...
...
@@ -409,24 +420,26 @@ class GroupOpRewriter : public StmtExprMutator {
for_annotations
.
erase
(
"tl_pipeline_group"
);
for_annotations
.
Set
(
"software_pipeline_order"
,
order_anno
);
for_annotations
.
Set
(
"software_pipeline_stage"
,
stage_anno
);
For
new_for
=
For
(
op
->
loop_var
,
op
->
min
,
op
->
extent
,
op
->
kind
,
new_body
.
size
()
==
1
?
new_body
[
0
]
:
SeqStmt
(
std
::
move
(
new_body
)),
op
->
thread_binding
,
for_annotations
);
For
new_for
=
For
(
op
->
loop_var
,
op
->
min
,
op
->
extent
,
op
->
kind
,
new_body
.
size
()
==
1
?
new_body
[
0
]
:
SeqStmt
(
std
::
move
(
new_body
)),
op
->
thread_binding
,
for_annotations
);
return
new_for
;
}
PipelineInfo
pipeline_info_
;
};
class
WSCodeEmitter
:
public
StmtMutator
{
public:
public:
WSCodeEmitter
(
bool
is_emitting_producer
,
IterVar
thread_iv
,
Map
<
Var
,
Buffer
>
buffer_data_to_buffer
,
const
WarpSpecializedRoleMarker
&
marker
)
Map
<
Var
,
Buffer
>
buffer_data_to_buffer
,
const
WarpSpecializedRoleMarker
&
marker
)
:
is_emitting_producer_
(
is_emitting_producer
),
buffer_data_to_buffer_
(
buffer_data_to_buffer
),
marker_
(
marker
),
buffer_data_to_buffer_
(
buffer_data_to_buffer
),
marker_
(
marker
),
thread_var_
(
thread_iv
->
var
)
{}
private:
template
<
typename
NodeType
>
Stmt
FilterByRole
(
const
NodeType
*
op
)
{
private:
template
<
typename
NodeType
>
Stmt
FilterByRole
(
const
NodeType
*
op
)
{
Role
role
=
marker_
.
GetRole
(
op
);
if
(
role
==
Role
::
kBoth
)
return
StmtMutator
::
VisitStmt_
(
op
);
...
...
@@ -437,7 +450,7 @@ class WSCodeEmitter : public StmtMutator {
}
// TODO: only need to add block for ops in the loop
Stmt
VisitStmt_
(
const
SeqStmtNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
SeqStmtNode
*
op
)
final
{
bool
has_producer
=
false
;
for
(
auto
stmt
:
op
->
seq
)
{
if
(
marker_
.
GetRole
(
stmt
)
==
Role
::
kProducer
)
{
...
...
@@ -445,19 +458,24 @@ class WSCodeEmitter : public StmtMutator {
break
;
}
}
bool
need_producer_sync
=
has_producer
&&
marker_
.
GetRole
(
op
)
==
Role
::
kBoth
;
if
(
!
need_producer_sync
)
return
FilterByRole
(
op
);
bool
need_producer_sync
=
has_producer
&&
marker_
.
GetRole
(
op
)
==
Role
::
kBoth
;
if
(
!
need_producer_sync
)
return
FilterByRole
(
op
);
auto
seq_transformed
=
op
->
seq
.
Map
([
&
](
Stmt
stmt
)
{
return
VisitStmt
(
stmt
);
});
auto
seq_transformed
=
op
->
seq
.
Map
([
&
](
Stmt
stmt
)
{
return
VisitStmt
(
stmt
);
});
auto
map
=
ExtractSyncPattern
(
op
->
seq
);
// std::cout << "Print ExtractSyncPattern" << std::endl;
// for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
// std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " " << map.release_after[i] << std::endl;
// std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " "
// << map.release_after[i] << std::endl;
// }
// std::cout << "Print sync pattern" << std::endl;
// for (auto pattern : map.patterns) {
// std::cout << pattern.release_idx << " " << pattern.acquire_idx << std::endl;
// std::cout << pattern.release_idx << " " << pattern.acquire_idx <<
// std::endl;
// }
// std::cout << "End of ExtractSyncPattern" << std::endl;
// pipeline_info_.PrintPipelineInfo();
...
...
@@ -465,29 +483,38 @@ class WSCodeEmitter : public StmtMutator {
Map
<
String
,
ObjectRef
>
annotations
;
annotations
.
Set
(
String
(
"stmt_group"
),
Integer
(
1
));
if
(
is_emitting_producer_
)
{
// producer case
if
(
is_emitting_producer_
)
{
// producer case
ProducerTraitsCollector
collector
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
op
->
seq
.
size
());
i
++
)
{
Array
<
Stmt
>
block_stmt
=
{};
if
(
marker_
.
GetRole
(
op
->
seq
[
i
])
==
Role
::
kConsumer
)
continue
;
if
(
marker_
.
GetRole
(
op
->
seq
[
i
])
==
Role
::
kConsumer
)
continue
;
if
(
marker_
.
GetRole
(
op
->
seq
[
i
])
==
Role
::
kBoth
)
{
block_stmt
.
push_back
(
seq_transformed
[
i
]);
new_body
.
push_back
(
MakeGroupBlock
(
block_stmt
.
size
()
==
1
?
block_stmt
[
0
]
:
SeqStmt
(
std
::
move
(
block_stmt
)),
annotations
));
new_body
.
push_back
(
MakeGroupBlock
(
block_stmt
.
size
()
==
1
?
block_stmt
[
0
]
:
SeqStmt
(
std
::
move
(
block_stmt
)),
annotations
));
continue
;
}
if
(
map
.
acquire
[
i
]
!=
-
1
)
{
PrimExpr
acquire_barrier_id
=
stage_
+
num_barriers_
+
num_stages_
*
map
.
acquire
[
i
];
PrimExpr
parity
=
map
.
is_loop_dependency
(
map
.
acquire
[
i
])
?
bitwise_xor
(
parity_
,
1
)
:
parity_
;
PrimExpr
acquire_barrier_id
=
stage_
+
num_barriers_
+
num_stages_
*
map
.
acquire
[
i
];
PrimExpr
parity
=
map
.
is_loop_dependency
(
map
.
acquire
[
i
])
?
bitwise_xor
(
parity_
,
1
)
:
parity_
;
block_stmt
.
push_back
(
makeParityWait
(
acquire_barrier_id
,
parity
));
}
ICHECK
(
map
.
release
[
i
]
>=
0
);
PrimExpr
release_barrier_id
=
stage_
+
num_barriers_
+
num_stages_
*
map
.
release
[
i
];
auto
stmt
=
MbarrierRewriter
::
Rewrite
(
seq_transformed
[
i
],
release_barrier_id
);
PrimExpr
release_barrier_id
=
stage_
+
num_barriers_
+
num_stages_
*
map
.
release
[
i
];
auto
stmt
=
MbarrierRewriter
::
Rewrite
(
seq_transformed
[
i
],
release_barrier_id
);
collector
.
Collect
(
stmt
);
if
(
!
is_zero
(
collector
.
BulkCopyBytes
()))
{
auto
expect_tx
=
IfThenElse
(
EQ
(
thread_var_
,
0
),
makeExpectTX
(
release_barrier_id
,
collector
.
BulkCopyBytes
()));
auto
expect_tx
=
IfThenElse
(
EQ
(
thread_var_
,
0
),
makeExpectTX
(
release_barrier_id
,
collector
.
BulkCopyBytes
()));
block_stmt
.
push_back
(
expect_tx
);
}
block_stmt
.
push_back
(
stmt
);
...
...
@@ -497,39 +524,53 @@ class WSCodeEmitter : public StmtMutator {
if
(
map
.
release_after
[
i
])
{
block_stmt
.
push_back
(
makeArriveBarrier
(
release_barrier_id
));
for
(
int
j
=
0
;
j
<
num_stages_
;
j
++
)
{
released_barrier_
.
insert
(
j
+
num_barriers_
+
num_stages_
*
map
.
release
[
i
]);
released_barrier_
.
insert
(
j
+
num_barriers_
+
num_stages_
*
map
.
release
[
i
]);
}
}
collector
.
Clear
();
new_body
.
push_back
(
MakeGroupBlock
(
block_stmt
.
size
()
==
1
?
block_stmt
[
0
]
:
SeqStmt
(
std
::
move
(
block_stmt
)),
annotations
));
new_body
.
push_back
(
MakeGroupBlock
(
block_stmt
.
size
()
==
1
?
block_stmt
[
0
]
:
SeqStmt
(
std
::
move
(
block_stmt
)),
annotations
));
}
}
else
{
// consumer case
}
else
{
// consumer case
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
op
->
seq
.
size
());
i
++
)
{
Array
<
Stmt
>
block_stmt
=
{};
if
(
marker_
.
GetRole
(
op
->
seq
[
i
])
==
Role
::
kProducer
)
continue
;
if
(
marker_
.
GetRole
(
op
->
seq
[
i
])
==
Role
::
kProducer
)
continue
;
if
(
map
.
acquire
[
i
]
!=
-
1
)
{
PrimExpr
acquire_barrier_id
=
stage_
+
num_barriers_
+
num_stages_
*
map
.
acquire
[
i
];
PrimExpr
parity
=
map
.
is_loop_dependency
(
map
.
acquire
[
i
])
?
bitwise_xor
(
parity_
,
1
)
:
parity_
;
PrimExpr
acquire_barrier_id
=
stage_
+
num_barriers_
+
num_stages_
*
map
.
acquire
[
i
];
PrimExpr
parity
=
map
.
is_loop_dependency
(
map
.
acquire
[
i
])
?
bitwise_xor
(
parity_
,
1
)
:
parity_
;
block_stmt
.
push_back
(
makeParityWait
(
acquire_barrier_id
,
parity
));
}
block_stmt
.
push_back
(
seq_transformed
[
i
]);
// new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations));
// new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ?
// block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations));
if
(
map
.
release_after
[
i
])
{
PrimExpr
release_barrier_id
=
stage_
+
num_barriers_
+
num_stages_
*
map
.
release
[
i
];
PrimExpr
release_barrier_id
=
stage_
+
num_barriers_
+
num_stages_
*
map
.
release
[
i
];
block_stmt
.
push_back
(
makeArriveBarrier
(
release_barrier_id
));
for
(
int
j
=
0
;
j
<
num_stages_
;
j
++
)
{
released_barrier_
.
insert
(
j
+
num_barriers_
+
num_stages_
*
map
.
release
[
i
]);
released_barrier_
.
insert
(
j
+
num_barriers_
+
num_stages_
*
map
.
release
[
i
]);
}
// Update the pipeline info
// Todo: handle sync
}
new_body
.
push_back
(
MakeGroupBlock
(
block_stmt
.
size
()
==
1
?
block_stmt
[
0
]
:
SeqStmt
(
std
::
move
(
block_stmt
)),
annotations
));
new_body
.
push_back
(
MakeGroupBlock
(
block_stmt
.
size
()
==
1
?
block_stmt
[
0
]
:
SeqStmt
(
std
::
move
(
block_stmt
)),
annotations
));
}
// Filter out the producer stmts
int
cur_id
=
0
;
PipelineInfo
new_pipeline_info
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
pipeline_info_
.
op_infos
.
size
());
i
++
)
{
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
pipeline_info_
.
op_infos
.
size
());
i
++
)
{
auto
op_info
=
pipeline_info_
.
op_infos
[
i
];
bool
is_producer
=
false
;
for
(
int
j
=
0
;
j
<
op_info
.
group_size
;
j
++
)
{
...
...
@@ -553,7 +594,7 @@ class WSCodeEmitter : public StmtMutator {
return
new_body
.
size
()
==
1
?
new_body
[
0
]
:
SeqStmt
(
std
::
move
(
new_body
));
}
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
int
num_stages
=
1
;
auto
num_stages_anno
=
op
->
annotations
.
Get
(
"num_stages"
);
if
(
num_stages_anno
.
defined
())
{
...
...
@@ -565,7 +606,7 @@ class WSCodeEmitter : public StmtMutator {
Array
<
Array
<
Integer
>>
group_info_array
;
Array
<
Integer
>
order_info_array
;
Array
<
Integer
>
stage_info_array
;
auto
group_anno
=
op
->
annotations
.
Get
(
"tl_pipeline_group"
);
if
(
group_anno
.
defined
())
{
group_info_array
=
Downcast
<
Array
<
Array
<
Integer
>>>
(
group_anno
);
...
...
@@ -579,9 +620,11 @@ class WSCodeEmitter : public StmtMutator {
stage_info_array
=
Downcast
<
Array
<
Integer
>>
(
stage_anno
);
}
PipelineInfo
pipeline_info
(
group_info_array
,
order_info_array
,
stage_info_array
);
PipelineInfo
pipeline_info
(
group_info_array
,
order_info_array
,
stage_info_array
);
if
(
pipeline_info
.
op_infos
.
size
()
>
0
)
{
ICHECK
(
pipeline_info_
.
op_infos
.
size
()
==
0
)
<<
"Nested pipeline not supported."
;
ICHECK
(
pipeline_info_
.
op_infos
.
size
()
==
0
)
<<
"Nested pipeline not supported."
;
}
PrimExpr
parity_before
=
std
::
move
(
parity_
);
...
...
@@ -592,13 +635,15 @@ class WSCodeEmitter : public StmtMutator {
num_stages_
=
num_stages
;
pipeline_info_
=
pipeline_info
;
stage_
=
FloorMod
(
op
->
loop_var
-
op
->
min
,
num_stages
);
parity_
=
FloorMod
(
parity_before
*
op
->
extent
+
FloorDiv
(
op
->
loop_var
-
op
->
min
,
num_stages
),
2
);
parity_
=
FloorMod
(
parity_before
*
op
->
extent
+
FloorDiv
(
op
->
loop_var
-
op
->
min
,
num_stages
),
2
);
auto
result
=
FilterByRole
(
op
);
Stmt
grouped_for_node
;
if
(
result
.
as
<
ForNode
>
()
&&
group_anno
.
defined
()
&&
group_info_array
.
size
()
>
0
&&
!
is_emitting_producer_
)
{
if
(
result
.
as
<
ForNode
>
()
&&
group_anno
.
defined
()
&&
group_info_array
.
size
()
>
0
&&
!
is_emitting_producer_
)
{
GroupOpRewriter
group_op_rewriter
(
pipeline_info_
);
auto
for_node
=
Downcast
<
For
>
(
result
);
grouped_for_node
=
group_op_rewriter
(
for_node
);
...
...
@@ -618,7 +663,8 @@ class WSCodeEmitter : public StmtMutator {
for_node
.
CopyOnWrite
()
->
annotations
.
erase
(
"tl_pipeline_order"
);
for_node
.
CopyOnWrite
()
->
annotations
.
erase
(
"tl_pipeline_stage"
);
}
if
(
is_emitting_producer_
||
!
group_anno
.
defined
()
||
group_info_array
.
size
()
==
0
)
{
if
(
is_emitting_producer_
||
!
group_anno
.
defined
()
||
group_info_array
.
size
()
==
0
)
{
return
for_node
;
}
return
grouped_for_node
;
...
...
@@ -626,17 +672,17 @@ class WSCodeEmitter : public StmtMutator {
return
result
;
}
Stmt
VisitStmt_
(
const
IfThenElseNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
LetStmtNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
AssertStmtNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
BlockNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
IfThenElseNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
LetStmtNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
AssertStmtNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
BlockNode
*
op
)
final
{
ICHECK
(
0
);
return
Stmt
();
}
Stmt
VisitStmt_
(
const
BlockRealizeNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
BlockRealizeNode
*
op
)
final
{
ICHECK
(
0
);
return
Stmt
();
}
...
...
@@ -656,27 +702,32 @@ class WSCodeEmitter : public StmtMutator {
}
};
std
::
vector
<
SyncPattern
>
CreateBaseSyncPairs
(
Array
<
Stmt
>
seq_stmt
,
const
std
::
vector
<
bool
>&
is_producer
)
{
std
::
vector
<
SyncPattern
>
CreateBaseSyncPairs
(
Array
<
Stmt
>
seq_stmt
,
const
std
::
vector
<
bool
>
&
is_producer
)
{
const
int
n
=
seq_stmt
.
size
();
std
::
vector
<
std
::
set
<
const
BufferNode
*>>
reads
,
writes
;
std
::
vector
<
std
::
set
<
const
BufferNode
*>>
reads
,
writes
;
reads
.
reserve
(
n
);
writes
.
reserve
(
n
);
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
Block
block
(
/*iter_vars=*/
{},
/*reads=*/
{},
/*writes=*/
{},
/*name_hint=*/
""
,
Block
block
(
/*iter_vars=*/
{},
/*reads=*/
{},
/*writes=*/
{},
/*name_hint=*/
""
,
/*body*/
seq_stmt
[
i
]);
auto
access
=
GetBlockAccessRegion
(
block
,
buffer_data_to_buffer_
);
std
::
set
<
const
BufferNode
*>
read_set
,
write_set
;
for
(
auto
region
:
access
[
0
])
read_set
.
insert
(
region
->
buffer
.
get
());
for
(
auto
region
:
access
[
1
])
write_set
.
insert
(
region
->
buffer
.
get
());
std
::
set
<
const
BufferNode
*>
read_set
,
write_set
;
for
(
auto
region
:
access
[
0
])
read_set
.
insert
(
region
->
buffer
.
get
());
for
(
auto
region
:
access
[
1
])
write_set
.
insert
(
region
->
buffer
.
get
());
reads
.
push_back
(
std
::
move
(
read_set
));
writes
.
push_back
(
std
::
move
(
write_set
));
}
auto
intersect_fn
=
[](
const
std
::
set
<
const
BufferNode
*>
&
lhs
,
const
std
::
set
<
const
BufferNode
*>
&
rhs
)
{
auto
intersect_fn
=
[](
const
std
::
set
<
const
BufferNode
*>
&
lhs
,
const
std
::
set
<
const
BufferNode
*>
&
rhs
)
{
for
(
auto
ptr
:
lhs
)
if
(
rhs
.
count
(
ptr
))
return
true
;
if
(
rhs
.
count
(
ptr
))
return
true
;
return
false
;
};
...
...
@@ -686,7 +737,8 @@ class WSCodeEmitter : public StmtMutator {
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int
j
=
i
+
1
;
j
<
n
;
j
++
)
{
if
(
is_producer
[
i
]
!=
is_producer
[
j
]
&&
(
intersect_fn
(
writes
[
i
],
reads
[
j
])
||
intersect_fn
(
reads
[
i
],
writes
[
j
])))
{
(
intersect_fn
(
writes
[
i
],
reads
[
j
])
||
intersect_fn
(
reads
[
i
],
writes
[
j
])))
{
sync_patterns
.
push_back
({
i
,
j
});
break
;
}
...
...
@@ -701,7 +753,8 @@ class WSCodeEmitter : public StmtMutator {
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int
j
=
0
;
j
<
i
;
j
++
)
{
if
(
is_producer
[
i
]
!=
is_producer
[
j
]
&&
(
intersect_fn
(
writes
[
i
],
reads
[
j
])
||
intersect_fn
(
reads
[
i
],
writes
[
j
])))
{
(
intersect_fn
(
writes
[
i
],
reads
[
j
])
||
intersect_fn
(
reads
[
i
],
writes
[
j
])))
{
sync_patterns
.
push_back
({
i
,
j
});
break
;
}
...
...
@@ -712,8 +765,9 @@ class WSCodeEmitter : public StmtMutator {
return
sync_patterns
;
}
static
std
::
vector
<
SyncPattern
>
RemoveUnusedSyncPatterns
(
const
std
::
vector
<
SyncPattern
>&
sync_patterns
,
const
std
::
vector
<
bool
>&
is_producer
)
{
static
std
::
vector
<
SyncPattern
>
RemoveUnusedSyncPatterns
(
const
std
::
vector
<
SyncPattern
>
&
sync_patterns
,
const
std
::
vector
<
bool
>
&
is_producer
)
{
/*
Simplify multiple release-acquire pairs into one
------------------
...
...
@@ -746,7 +800,8 @@ class WSCodeEmitter : public StmtMutator {
std
::
vector
<
SyncPattern
>
sync_pattern_cleaned
;
sync_pattern_cleaned
.
reserve
(
M
);
for
(
int
i
=
0
;
i
<
M
;
i
++
)
if
(
!
removed
[
i
])
sync_pattern_cleaned
.
push_back
(
sync_patterns
[
i
]);
if
(
!
removed
[
i
])
sync_pattern_cleaned
.
push_back
(
sync_patterns
[
i
]);
return
sync_pattern_cleaned
;
}
...
...
@@ -760,10 +815,12 @@ class WSCodeEmitter : public StmtMutator {
}
auto
sync_patterns_base
=
CreateBaseSyncPairs
(
seq_stmt
,
is_producer
);
auto
sync_patterns
=
RemoveUnusedSyncPatterns
(
sync_patterns_base
,
is_producer
);
auto
sync_patterns
=
RemoveUnusedSyncPatterns
(
sync_patterns_base
,
is_producer
);
// for (auto pattern : sync_patterns) {
// std::cout << pattern.release_idx << " " << pattern.acquire_idx << std::endl;
// std::cout << pattern.release_idx << " " << pattern.acquire_idx <<
// std::endl;
// }
SyncPatternMap
map
;
...
...
@@ -799,7 +856,7 @@ class WSCodeEmitter : public StmtMutator {
const
bool
is_emitting_producer_
;
Map
<
Var
,
Buffer
>
buffer_data_to_buffer_
;
std
::
unordered_set
<
int
>
released_barrier_
;
const
WarpSpecializedRoleMarker
&
marker_
;
const
WarpSpecializedRoleMarker
&
marker_
;
int
num_barriers_
=
0
;
PrimExpr
parity_
=
0
;
...
...
@@ -811,17 +868,18 @@ class WSCodeEmitter : public StmtMutator {
};
class
WarpSpecializedRewriter
:
public
StmtExprMutator
{
public:
public:
static
PrimFunc
Substitute
(
PrimFunc
f
)
{
auto
T
=
WarpSpecializedRewriter
();
T
.
buffer_lca_
=
DetectBufferAccessLCA
(
f
);
for
(
auto
[
buffer
,
_
]
:
T
.
buffer_lca_
)
T
.
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
for
(
auto
[
buffer
,
_
]
:
T
.
buffer_lca_
)
T
.
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
f
.
CopyOnWrite
()
->
body
=
T
(
f
->
body
);
return
f
;
}
private:
Stmt
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
private:
Stmt
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
if
(
op
->
attr_key
==
tir
::
attr
::
thread_extent
&&
Downcast
<
IterVar
>
(
op
->
node
)
->
thread_tag
==
"threadIdx.x"
)
{
thread_iv_
=
Downcast
<
IterVar
>
(
op
->
node
);
...
...
@@ -839,9 +897,10 @@ class WarpSpecializedRewriter : public StmtExprMutator {
}
}
// If users define a thread binding, we will replace the thread binding with threadIdx.x
// We require the thread binding is threadIdx.x, and the extent is the same as the thread extent
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
// If users define a thread binding, we will replace the thread binding with
// threadIdx.x We require the thread binding is threadIdx.x, and the extent is
// the same as the thread extent
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
ICHECK
(
thread_iv_
.
defined
());
For
for_node
=
Downcast
<
For
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
if
(
for_node
->
kind
==
ForKind
::
kThreadBinding
)
{
...
...
@@ -849,14 +908,16 @@ class WarpSpecializedRewriter : public StmtExprMutator {
String
thread_tag
=
for_node
->
thread_binding
.
value
()
->
thread_tag
;
ICHECK
(
thread_tag
==
"threadIdx.x"
)
<<
"Only support threadIdx.x"
;
Var
thread_iv
=
Downcast
<
Var
>
(
for_node
->
loop_var
);
Stmt
new_body
=
ThreadIdxRewriter
::
Rewrite
(
for_node
->
body
,
thread_iv
,
thread_iv_
);
Stmt
new_body
=
ThreadIdxRewriter
::
Rewrite
(
for_node
->
body
,
thread_iv
,
thread_iv_
);
return
new_body
;
}
return
for_node
;
}
Stmt
VisitStmt_
(
const
BlockRealizeNode
*
op
)
final
{
BlockRealize
block_realize
=
Downcast
<
BlockRealize
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
Stmt
VisitStmt_
(
const
BlockRealizeNode
*
op
)
final
{
BlockRealize
block_realize
=
Downcast
<
BlockRealize
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
if
(
!
thread_iv_
.
defined
())
{
return
block_realize
;
}
...
...
@@ -877,17 +938,21 @@ class WarpSpecializedRewriter : public StmtExprMutator {
PrimExpr
consumer_thread_extent
=
thread_iv_
->
dom
->
extent
;
PrimExpr
producer_thread_extent
=
thread_iv_
->
dom
->
extent
;
// Need one warp-group for bulk-copy only case
if
(
!
marker
.
HasSimtCopy
())
producer_thread_extent
=
128
;
if
(
!
marker
.
HasSimtCopy
())
producer_thread_extent
=
128
;
// TODO: estimate the correct reg usage.
auto
inc_reg_stmt
=
Evaluate
(
Call
(
DataType
::
Handle
(),
SetMaxNReg
(),
{
240
,
1
}));
auto
dec_reg_stmt
=
Evaluate
(
Call
(
DataType
::
Handle
(),
SetMaxNReg
(),
{
24
,
0
}));
auto
inc_reg_stmt
=
Evaluate
(
Call
(
DataType
::
Handle
(),
SetMaxNReg
(),
{
240
,
1
}));
auto
dec_reg_stmt
=
Evaluate
(
Call
(
DataType
::
Handle
(),
SetMaxNReg
(),
{
24
,
0
}));
producer_code
=
SeqStmt
({
dec_reg_stmt
,
producer_code
});
consumer_code
=
SeqStmt
({
inc_reg_stmt
,
consumer_code
});
producer_code
=
ThreadIdxRewriter
::
Rewrite
(
producer_code
,
thread_iv_
->
var
,
thread_iv_
->
var
-
consumer_thread_extent
);
producer_code
=
ThreadIdxRewriter
::
Rewrite
(
producer_code
,
thread_iv_
->
var
,
thread_iv_
->
var
-
consumer_thread_extent
);
updated_thread_extent_
=
consumer_thread_extent
+
producer_thread_extent
;
need_update_thread_extent_
=
true
;
...
...
@@ -897,15 +962,16 @@ class WarpSpecializedRewriter : public StmtExprMutator {
Array
<
PrimExpr
>
barrier_num_threads
;
barrier_num_threads
.
reserve
(
num_barriers
);
for
(
int
i
=
0
;
i
<
num_barriers
;
i
++
)
{
PrimExpr
arrive_thread_count
=
producer
.
released_barrier_
.
count
(
i
)
?
producer_thread_extent
:
consumer_thread_extent
;
PrimExpr
arrive_thread_count
=
producer
.
released_barrier_
.
count
(
i
)
?
producer_thread_extent
:
consumer_thread_extent
;
barrier_num_threads
.
push_back
(
arrive_thread_count
);
}
Stmt
init_barrier
=
Evaluate
(
Call
(
DataType
::
Handle
(),
CreateListofMBarrierOp
(),
barrier_num_threads
));
Stmt
body
=
IfThenElse
(
GE
(
thread_iv_
->
var
,
consumer_thread_extent
),
producer_code
,
consumer_code
);
Stmt
init_barrier
=
Evaluate
(
Call
(
DataType
::
Handle
(),
CreateListofMBarrierOp
(),
barrier_num_threads
));
Stmt
body
=
IfThenElse
(
GE
(
thread_iv_
->
var
,
consumer_thread_extent
),
producer_code
,
consumer_code
);
// Add an attr here to handle the partial thread count in THreadSync pass.
Array
<
IntImm
>
ws_partition
=
{
Downcast
<
IntImm
>
(
producer_thread_extent
),
Downcast
<
IntImm
>
(
consumer_thread_extent
)};
...
...
@@ -935,7 +1001,8 @@ tvm::transform::Pass WarpSpecialized() {
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.WarpSpecialized"
,
{});
}
TVM_REGISTER_GLOBAL
(
"tl.transform.WarpSpecialized"
).
set_body_typed
(
WarpSpecialized
);
TVM_REGISTER_GLOBAL
(
"tl.transform.WarpSpecialized"
)
.
set_body_typed
(
WarpSpecialized
);
}
// namespace tl
}
// namespace tvm
}
// namespace tl
}
// namespace tvm
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
View file @
549416f7
...
...
@@ -5,7 +5,6 @@ import torch
import
torch.backends
import
tilelang.testing
from
tilelang
import
tvm
as
tvm
from
tvm
import
DataType
import
tilelang
as
TL
import
tilelang.language
as
T
from
tilelang.intrinsics
import
make_mfma_swizzle_layout
as
make_swizzle_layout
...
...
testing/python/kernel/test_tilelang_gemm.py
View file @
549416f7
...
...
@@ -30,13 +30,11 @@ def matmul(
@
T
.
prim_func
def
main
(
A
:
T
.
Buffer
(
A_shape
,
in_dtype
),
B
:
T
.
Buffer
(
B_shape
,
in_dtype
),
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
A
:
T
.
Buffer
(
A_shape
,
in_dtype
),
B
:
T
.
Buffer
(
B_shape
,
in_dtype
),
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
...
@@ -169,9 +167,7 @@ def test_gemm_f32f32f32_nn():
def
test_gemm_i8i8i32_nn
():
run_gemm
(
512
,
1024
,
768
,
False
,
False
,
"int8"
,
"int8"
,
"int32"
,
128
,
128
,
64
)
run_gemm
(
512
,
1024
,
768
,
False
,
False
,
"int8"
,
"int8"
,
"int32"
,
128
,
128
,
64
)
def
test_gemm_f16f16f16_tn
():
...
...
@@ -217,9 +213,7 @@ def test_gemm_i8i8i32_tn():
def
test_gemm_f64f64f64_nt
():
run_gemm
(
512
,
512
,
512
,
False
,
True
,
"float64"
,
"float64"
,
"float64"
,
64
,
32
,
16
)
run_gemm
(
512
,
512
,
512
,
False
,
True
,
"float64"
,
"float64"
,
"float64"
,
64
,
32
,
16
)
def
test_gemm_f32f32f32_nt
():
...
...
testing/python/kernel/test_tilelang_gemm_mma_intrinsic.py
View file @
549416f7
...
...
@@ -10,8 +10,7 @@ import tilelang as TL
import
tilelang.language
as
T
from
tilelang.intrinsics
import
get_swizzle_layout
from
tilelang.intrinsics.mma_macro_generator
import
(
TensorCoreIntrinEmitter
,
)
TensorCoreIntrinEmitter
,)
from
tilelang.transform
import
simplify_prim_func
torch
.
manual_seed
(
0
)
...
...
testing/python/primitives/test_tilelang_primitives_mma.py
View file @
549416f7
...
...
@@ -6,6 +6,7 @@ import tilelang.testing
import
tilelang
as
tl
from
tilelang
import
primitives
as
P
def
matmul_ssr
(
M
,
N
,
...
...
@@ -30,13 +31,11 @@ def matmul_ssr(
@
T
.
prim_func
def
main
(
A
:
T
.
Buffer
(
A_shape
,
in_dtype
),
B
:
T
.
Buffer
(
B_shape
,
in_dtype
),
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
A
:
T
.
Buffer
(
A_shape
,
in_dtype
),
B
:
T
.
Buffer
(
B_shape
,
in_dtype
),
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
...
@@ -145,13 +144,11 @@ def matmul_rsr(
@
T
.
prim_func
def
main
(
A
:
T
.
Buffer
(
A_shape
,
in_dtype
),
B
:
T
.
Buffer
(
B_shape
,
in_dtype
),
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
A
:
T
.
Buffer
(
A_shape
,
in_dtype
),
B
:
T
.
Buffer
(
B_shape
,
in_dtype
),
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_local
=
T
.
alloc_fragment
(
A_local_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
...
...
@@ -264,13 +261,11 @@ def matmul_rrr(
@
T
.
prim_func
def
main
(
A
:
T
.
Buffer
(
A_shape
,
in_dtype
),
B
:
T
.
Buffer
(
B_shape
,
in_dtype
),
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
A
:
T
.
Buffer
(
A_shape
,
in_dtype
),
B
:
T
.
Buffer
(
B_shape
,
in_dtype
),
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_local
=
T
.
alloc_fragment
(
A_local_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
...
...
tilelang/intrinsics/__init__.py
View file @
549416f7
...
...
@@ -11,8 +11,7 @@ from .mma_macro_generator import (
TensorCoreIntrinEmitterWithLadderTransform
,
# noqa: F401
)
from
.mma_layout
import
get_swizzle_layout
# noqa: F401
from
.mma_layout
import
get_swizzle_layout
# noqa: F401
from
.mma_layout
import
make_mma_swizzle_layout
# noqa: F401
from
.mfma_layout
import
make_mfma_swizzle_layout
# noqa: F401
tilelang/intrinsics/mma_macro_generator.py
View file @
549416f7
...
...
@@ -14,6 +14,7 @@ from .utils import (
lift
=
convert
# TODO(lei): Add Typing for this file
class
TensorCoreIntrinEmitter
(
object
):
"""
...
...
@@ -75,9 +76,11 @@ class TensorCoreIntrinEmitter(object):
self
.
reduce_k
=
reduce_k
self
.
threads
=
self
.
WARP_SIZE
*
(
block_row_warps
*
block_col_warps
)
*
reduce_k
self
.
num_elems_per_byte
=
num_elems_per_byte
if
self
.
warp_rows
==
0
or
self
.
warp_cols
==
0
:
raise
ValueError
(
f
"Invalid threads configuration for this tile shape,
{
self
.
warp_rows
}
x
{
self
.
warp_cols
}
with threads
{
self
.
threads
}
"
)
raise
ValueError
(
f
"Invalid threads configuration for this tile shape,
{
self
.
warp_rows
}
x
{
self
.
warp_cols
}
with threads
{
self
.
threads
}
"
)
def
_initialize_k_dim
(
self
,
a_dtype
=
"float16"
):
if
isinstance
(
a_dtype
,
str
):
...
...
@@ -272,12 +275,9 @@ class TensorCoreIntrinEmitter(object):
A_local_buf
.
data
,
k_inner
*
warp_rows
*
local_size_a
+
i
*
local_size_a
,
B_local_buf
.
data
,
k_inner
*
warp_cols
*
local_size_b
+
j
*
local_size_b
+
lift
(
local_size_b
)
//
2
,
k_inner
*
warp_cols
*
local_size_b
+
j
*
local_size_b
+
lift
(
local_size_b
)
//
2
,
C_local_buf
.
data
,
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
lift
(
local_size_out
)
//
2
,
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
lift
(
local_size_out
)
//
2
,
T
.
bool
(
False
),
)
...
...
@@ -328,7 +328,9 @@ class TensorCoreIntrinEmitter(object):
return
(
_warp_stmatrix_global
(
C_local_buf
,
C_buf
,
thread_bindings
)
if
is_global
else
_warp_stmatrix_shared
(
C_local_buf
,
C_buf
,
thread_bindings
))
def
make_mma_load_layout
(
self
,
local_buf
:
Buffer
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
)
->
T
.
Fragment
:
def
make_mma_load_layout
(
self
,
local_buf
:
Buffer
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
)
->
T
.
Fragment
:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
...
...
@@ -372,12 +374,14 @@ class TensorCoreIntrinEmitter(object):
elif
matrix
==
"A"
and
not
transposed
:
transform_func
=
ldmatrix_16x32_to_shared_16x32_layout_a
else
:
raise
ValueError
(
"ldmatrix only supports B transposed and A non-transposed for int8"
)
raise
ValueError
(
"ldmatrix only supports B transposed and A non-transposed for int8"
)
else
:
raise
ValueError
(
f
"Unsupported dtype
{
dtype
}
"
)
shape
=
local_buf
.
shape
assert
is_fragment
(
local_buf
),
"local_buf must be a fragment, but got {}"
.
format
(
local_buf
.
scope
())
assert
is_fragment
(
local_buf
),
"local_buf must be a fragment, but got {}"
.
format
(
local_buf
.
scope
())
if
matrix
==
"A"
:
micro_size_x
,
micro_size_y
=
self
.
micro_size_x
,
self
.
micro_size_k
...
...
@@ -397,7 +401,8 @@ class TensorCoreIntrinEmitter(object):
transform_func
=
transform_func
if
not
transposed
else
transform_func_trans
warp_size
,
local_size_a
,
local_size_b
=
self
.
WARP_SIZE
,
self
.
local_size_a
,
self
.
local_size_b
local_size
=
local_size_a
if
matrix
==
"A"
else
local_size_b
inverse_mma_load_layout
=
IndexMap
.
from_func
(
transform_func
,
index_dtype
=
"int32"
).
inverse
([
warp_size
,
local_size
])
inverse_mma_load_layout
=
IndexMap
.
from_func
(
transform_func
,
index_dtype
=
"int32"
).
inverse
([
warp_size
,
local_size
])
def
forward_thread
(
i
:
int
,
j
:
int
)
->
int
:
"""
...
...
@@ -406,29 +411,19 @@ class TensorCoreIntrinEmitter(object):
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
block_i
,
block_j
=
(
i
//
micro_size_x
)
//
warp_rows
,
(
j
//
micro_size_y
)
//
warp_cols
block_i
,
block_j
=
(
i
//
micro_size_x
)
//
warp_rows
,
(
j
//
micro_size_y
)
//
warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i
,
warp_j
=
(
i
//
micro_size_x
)
%
warp_rows
,
(
j
//
micro_size_y
)
%
warp_cols
warp_i
,
warp_j
=
(
i
//
micro_size_x
)
%
warp_rows
,
(
j
//
micro_size_y
)
%
warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i
,
mma_j
=
i
%
micro_size_x
,
j
%
micro_size_y
lane_id
,
_
=
inverse_mma_load_layout
.
map_indices
([
mma_i
,
mma_j
])
if
is_m_first
:
thread_id
=
(
block_i
*
(
block_col_warps
*
warp_cols
)
+
block_j
*
warp_rows
+
warp_i
*
warp_cols
+
warp_j
)
block_i
*
(
block_col_warps
*
warp_cols
)
+
block_j
*
warp_rows
+
warp_i
*
warp_cols
+
warp_j
)
else
:
thread_id
=
(
block_j
*
(
block_row_warps
*
warp_size
)
+
block_i
*
warp_size
+
lane_id
)
block_j
*
(
block_row_warps
*
warp_size
)
+
block_i
*
warp_size
+
lane_id
)
return
thread_id
def
forward_index
(
i
:
int
,
j
:
int
)
->
int
:
...
...
@@ -439,21 +434,13 @@ class TensorCoreIntrinEmitter(object):
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
block_i
,
block_j
=
(
i
//
micro_size_x
)
//
warp_rows
,
(
j
//
micro_size_y
)
//
warp_cols
block_i
,
block_j
=
(
i
//
micro_size_x
)
//
warp_rows
,
(
j
//
micro_size_y
)
//
warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i
,
warp_j
=
(
i
//
micro_size_x
)
%
warp_rows
,
(
j
//
micro_size_y
)
%
warp_cols
warp_i
,
warp_j
=
(
i
//
micro_size_x
)
%
warp_rows
,
(
j
//
micro_size_y
)
%
warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i
,
mma_j
=
i
%
micro_size_x
,
j
%
micro_size_y
_
,
local_id
=
inverse_mma_load_layout
.
map_indices
([
mma_i
,
mma_j
])
return
(
warp_i
*
(
warp_cols
*
local_size_out
)
+
warp_j
*
local_size_out
+
local_id
)
return
(
warp_i
*
(
warp_cols
*
local_size_out
)
+
warp_j
*
local_size_out
+
local_id
)
fragment
=
T
.
Fragment
(
shape
,
...
...
@@ -465,9 +452,7 @@ class TensorCoreIntrinEmitter(object):
print
(
f
"fragment.index:
{
fragment
.
index
}
"
)
return
fragment
def
make_mma_store_layout
(
self
,
local_buf
:
Buffer
)
->
T
.
Fragment
:
def
make_mma_store_layout
(
self
,
local_buf
:
Buffer
)
->
T
.
Fragment
:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
...
...
@@ -500,6 +485,7 @@ class TensorCoreIntrinEmitter(object):
warp_rows
,
warp_cols
=
self
.
warp_rows
,
self
.
warp_cols
warp_size
=
self
.
WARP_SIZE
is_m_first
=
self
.
is_m_first
def
forward_thread
(
i
:
int
,
j
:
int
)
->
int
:
"""
Given the row index `i` and column index `j` in the fragment,
...
...
@@ -514,7 +500,8 @@ class TensorCoreIntrinEmitter(object):
mma_i
,
mma_j
=
i
%
micro_size_x
,
j
%
micro_size_y
lane_id
,
_
=
inverse_mma_store_layout
.
map_indices
([
mma_i
,
mma_j
])
if
is_m_first
:
thread_id
=
block_i
*
(
block_col_warps
*
warp_cols
)
+
block_j
*
warp_rows
+
warp_i
*
warp_cols
+
warp_j
thread_id
=
block_i
*
(
block_col_warps
*
warp_cols
)
+
block_j
*
warp_rows
+
warp_i
*
warp_cols
+
warp_j
else
:
thread_id
=
block_j
*
(
block_row_warps
*
warp_size
)
+
block_i
*
warp_size
+
lane_id
return
thread_id
...
...
@@ -527,13 +514,9 @@ class TensorCoreIntrinEmitter(object):
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
block_i
,
block_j
=
(
i
//
micro_size_x
)
//
warp_rows
,
(
j
//
micro_size_y
)
//
warp_cols
block_i
,
block_j
=
(
i
//
micro_size_x
)
//
warp_rows
,
(
j
//
micro_size_y
)
//
warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i
,
warp_j
=
(
i
//
micro_size_x
)
%
warp_rows
,
(
j
//
micro_size_y
)
%
warp_cols
warp_i
,
warp_j
=
(
i
//
micro_size_x
)
%
warp_rows
,
(
j
//
micro_size_y
)
%
warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i
,
mma_j
=
i
%
micro_size_x
,
j
%
micro_size_y
_
,
local_id
=
inverse_mma_store_layout
.
map_indices
([
mma_i
,
mma_j
])
...
...
@@ -545,6 +528,7 @@ class TensorCoreIntrinEmitter(object):
forward_index_fn
=
forward_index
,
)
class
TensorCoreIntrinEmitterWithLadderTransform
(
TensorCoreIntrinEmitter
):
"""
To eliminate Python syntax within TIR Macro.
...
...
tilelang/language/allocate.py
View file @
549416f7
...
...
@@ -4,6 +4,7 @@
from
tvm.script
import
tir
as
T
def
alloc_shared
(
shape
,
dtype
,
scope
=
"shared.dyn"
):
return
T
.
alloc_buffer
(
shape
,
dtype
,
scope
=
scope
)
...
...
tilelang/language/copy.py
View file @
549416f7
...
...
@@ -6,11 +6,10 @@ from typing import Union, List, Optional
from
tvm
import
tir
from
tvm.script
import
tir
as
T
def
region
(
buffer
:
tir
.
BufferLoad
,
access_type
:
str
,
*
args
:
tir
.
PrimExpr
):
access_type
=
{
"r"
:
1
,
"w"
:
2
,
"rw"
:
3
}[
access_type
]
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.region"
),
buffer
,
access_type
,
*
args
)
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.region"
),
buffer
,
access_type
,
*
args
)
def
buffer_to_tile_region
(
buffer
:
tir
.
Buffer
,
access_type
:
str
):
...
...
@@ -19,20 +18,14 @@ def buffer_to_tile_region(buffer: tir.Buffer, access_type: str):
return
region
(
T
.
BufferLoad
(
buffer
,
mins
),
access_type
,
*
extents
)
def
buffer_load_to_tile_region
(
load
:
tir
.
BufferLoad
,
access_type
:
str
,
extents
:
List
[
tir
.
PrimExpr
]
):
def
buffer_load_to_tile_region
(
load
:
tir
.
BufferLoad
,
access_type
:
str
,
extents
:
List
[
tir
.
PrimExpr
]):
return
region
(
load
,
access_type
,
*
extents
)
def
buffer_region_to_tile_region
(
buffer_region
:
tir
.
BufferRegion
,
access_type
:
str
):
def
buffer_region_to_tile_region
(
buffer_region
:
tir
.
BufferRegion
,
access_type
:
str
):
mins
=
[
x
.
min
for
x
in
buffer_region
.
region
]
extents
=
[
x
.
extent
for
x
in
buffer_region
.
region
]
return
region
(
T
.
BufferLoad
(
buffer_region
.
buffer
,
mins
),
access_type
,
*
extents
)
return
region
(
T
.
BufferLoad
(
buffer_region
.
buffer
,
mins
),
access_type
,
*
extents
)
def
copy
(
...
...
@@ -71,9 +64,7 @@ def copy(
src
=
_to_region
(
src
,
"r"
)
dst
=
_to_region
(
dst
,
"w"
)
if
coalesced_width
is
not
None
:
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.copy"
),
src
,
dst
,
coalesced_width
)
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.copy"
),
src
,
dst
,
coalesced_width
)
else
:
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.copy"
),
src
,
dst
)
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment