Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
549416f7
Commit
549416f7
authored
Jan 11, 2025
by
LeiWang1999
Browse files
Merge branch 'main' of
https://github.com/microsoft/TileLang
into main
parents
4d63633a
7fad4e88
Changes
90
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
853 additions
and
711 deletions
+853
-711
src/transform/legalize_vectorized_loop.cc
src/transform/legalize_vectorized_loop.cc
+9
-8
src/transform/loop_partition.cc
src/transform/loop_partition.cc
+29
-24
src/transform/loop_partition.h
src/transform/loop_partition.h
+5
-4
src/transform/loop_vectorize.cc
src/transform/loop_vectorize.cc
+60
-48
src/transform/loop_vectorize.h
src/transform/loop_vectorize.h
+7
-7
src/transform/lower_hopper_intrin.cc
src/transform/lower_hopper_intrin.cc
+41
-30
src/transform/lower_tile_op.cc
src/transform/lower_tile_op.cc
+66
-45
src/transform/multi_version_buffer_rewriter.cc
src/transform/multi_version_buffer_rewriter.cc
+74
-59
src/transform/pipeline_planning.cc
src/transform/pipeline_planning.cc
+77
-55
src/transform/simplify.cc
src/transform/simplify.cc
+139
-127
src/transform/thread_partial_sync.cc
src/transform/thread_partial_sync.cc
+69
-56
src/transform/warp_specialized_rewriter.cc
src/transform/warp_specialized_rewriter.cc
+218
-151
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
+0
-1
testing/python/kernel/test_tilelang_gemm.py
testing/python/kernel/test_tilelang_gemm.py
+6
-12
testing/python/kernel/test_tilelang_gemm_mma_intrinsic.py
testing/python/kernel/test_tilelang_gemm_mma_intrinsic.py
+1
-2
testing/python/primitives/test_tilelang_primitives_mma.py
testing/python/primitives/test_tilelang_primitives_mma.py
+13
-18
tilelang/intrinsics/__init__.py
tilelang/intrinsics/__init__.py
+1
-2
tilelang/intrinsics/mma_macro_generator.py
tilelang/intrinsics/mma_macro_generator.py
+31
-47
tilelang/language/allocate.py
tilelang/language/allocate.py
+1
-0
tilelang/language/copy.py
tilelang/language/copy.py
+6
-15
No files found.
src/transform/legalize_vectorized_loop.cc
View file @
549416f7
...
@@ -30,8 +30,8 @@
...
@@ -30,8 +30,8 @@
#include <queue>
#include <queue>
#include "arith/ir_mutator_with_analyzer.h"
#include "../op/parallel.h"
#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "loop_partition.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
#include "loop_vectorize.h"
...
@@ -43,25 +43,26 @@ using arith::IRMutatorWithAnalyzer;
...
@@ -43,25 +43,26 @@ using arith::IRMutatorWithAnalyzer;
// Class to legalize vectorized loops by transforming them appropriately
// Class to legalize vectorized loops by transforming them appropriately
class
LoopVectorizedLegalizer
:
IRMutatorWithAnalyzer
{
class
LoopVectorizedLegalizer
:
IRMutatorWithAnalyzer
{
public:
public:
// Static method to substitute and transform the given PrimFunc
// Static method to substitute and transform the given PrimFunc
static
PrimFunc
Substitute
(
PrimFunc
f
)
{
static
PrimFunc
Substitute
(
PrimFunc
f
)
{
arith
::
Analyzer
analyzer
;
arith
::
Analyzer
analyzer
;
// Create an instance of the legalizer with the analyzer
// Create an instance of the legalizer with the analyzer
LoopVectorizedLegalizer
substituter
(
&
analyzer
);
LoopVectorizedLegalizer
substituter
(
&
analyzer
);
// Get a mutable copy of the function node
// Get a mutable copy of the function node
PrimFuncNode
*
fptr
=
f
.
CopyOnWrite
();
PrimFuncNode
*
fptr
=
f
.
CopyOnWrite
();
// Apply the legalizer to the function body
// Apply the legalizer to the function body
fptr
->
body
=
substituter
.
VisitStmt
(
f
->
body
);
fptr
->
body
=
substituter
.
VisitStmt
(
f
->
body
);
return
f
;
return
f
;
}
}
private:
private:
// Constructor initializing the base class with the analyzer
// Constructor initializing the base class with the analyzer
LoopVectorizedLegalizer
(
arith
::
Analyzer
*
analyzer
)
:
arith
::
IRMutatorWithAnalyzer
(
analyzer
)
{}
LoopVectorizedLegalizer
(
arith
::
Analyzer
*
analyzer
)
:
arith
::
IRMutatorWithAnalyzer
(
analyzer
)
{}
// Override the VisitStmt_ method to handle ForNode (loop statements)
// Override the VisitStmt_ method to handle ForNode (loop statements)
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
// Visit and potentially modify the loop node
// Visit and potentially modify the loop node
For
for_node
=
Downcast
<
For
>
(
IRMutatorWithAnalyzer
::
VisitStmt_
(
op
));
For
for_node
=
Downcast
<
For
>
(
IRMutatorWithAnalyzer
::
VisitStmt_
(
op
));
// If the loop is not vectorized, proceed with the default behavior
// If the loop is not vectorized, proceed with the default behavior
...
@@ -90,5 +91,5 @@ tvm::transform::Pass LegalizeVectorizedLoop() {
...
@@ -90,5 +91,5 @@ tvm::transform::Pass LegalizeVectorizedLoop() {
TVM_REGISTER_GLOBAL
(
"tl.transform.LegalizeVectorizedLoop"
)
TVM_REGISTER_GLOBAL
(
"tl.transform.LegalizeVectorizedLoop"
)
.
set_body_typed
(
LegalizeVectorizedLoop
);
.
set_body_typed
(
LegalizeVectorizedLoop
);
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/transform/loop_partition.cc
View file @
549416f7
...
@@ -32,29 +32,32 @@ namespace tl {
...
@@ -32,29 +32,32 @@ namespace tl {
using
namespace
tir
;
using
namespace
tir
;
class
BufferIndiceSimplify
:
public
StmtExprMutator
{
class
BufferIndiceSimplify
:
public
StmtExprMutator
{
public:
public:
BufferIndiceSimplify
(
arith
::
Analyzer
*
analyzer
)
:
analyzer_
(
analyzer
)
{}
BufferIndiceSimplify
(
arith
::
Analyzer
*
analyzer
)
:
analyzer_
(
analyzer
)
{}
private:
private:
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
node
)
final
{
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
node
)
final
{
auto
visited
=
StmtExprMutator
::
VisitExpr_
(
node
);
auto
visited
=
StmtExprMutator
::
VisitExpr_
(
node
);
auto
n
=
visited
.
as
<
BufferLoad
>
().
value
();
auto
n
=
visited
.
as
<
BufferLoad
>
().
value
();
auto
nptr
=
n
.
CopyOnWrite
();
auto
nptr
=
n
.
CopyOnWrite
();
nptr
->
indices
=
nptr
->
indices
.
Map
([
&
](
const
auto
&
e
)
{
return
analyzer_
->
Simplify
(
e
);
});
nptr
->
indices
=
nptr
->
indices
.
Map
(
[
&
](
const
auto
&
e
)
{
return
analyzer_
->
Simplify
(
e
);
});
return
n
;
return
n
;
}
}
Stmt
VisitStmt_
(
const
BufferStoreNode
*
node
)
final
{
Stmt
VisitStmt_
(
const
BufferStoreNode
*
node
)
final
{
auto
visited
=
StmtExprMutator
::
VisitStmt_
(
node
);
auto
visited
=
StmtExprMutator
::
VisitStmt_
(
node
);
auto
n
=
visited
.
as
<
BufferStore
>
().
value
();
auto
n
=
visited
.
as
<
BufferStore
>
().
value
();
auto
nptr
=
n
.
CopyOnWrite
();
auto
nptr
=
n
.
CopyOnWrite
();
nptr
->
indices
=
nptr
->
indices
.
Map
([
&
](
const
auto
&
e
)
{
return
analyzer_
->
Simplify
(
e
);
});
nptr
->
indices
=
nptr
->
indices
.
Map
(
[
&
](
const
auto
&
e
)
{
return
analyzer_
->
Simplify
(
e
);
});
return
n
;
return
n
;
}
}
arith
::
Analyzer
*
analyzer_
;
arith
::
Analyzer
*
analyzer_
;
};
};
// Rewrite the parallel loop into a common loop, which is mapped to threads
// Rewrite the parallel loop into a common loop, which is mapped to threads
For
PartitionLoop
(
For
op
,
Var
thread_var
,
arith
::
Analyzer
*
analyzer
,
Fragment
loop_layout
)
{
For
PartitionLoop
(
For
op
,
Var
thread_var
,
arith
::
Analyzer
*
analyzer
,
Fragment
loop_layout
)
{
ICHECK
(
loop_layout
.
defined
());
ICHECK
(
loop_layout
.
defined
());
ICHECK
(
thread_var
.
defined
());
ICHECK
(
thread_var
.
defined
());
int
old_loop_depth
=
loop_layout
->
InputDim
();
int
old_loop_depth
=
loop_layout
->
InputDim
();
...
@@ -71,7 +74,8 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer* analyzer, Fragment lo
...
@@ -71,7 +74,8 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer* analyzer, Fragment lo
Map
<
Var
,
PrimExpr
>
vmap
;
Map
<
Var
,
PrimExpr
>
vmap
;
Stmt
body
=
op
;
Stmt
body
=
op
;
auto
inv_loop
=
loop_layout
->
Inverse
();
auto
inv_loop
=
loop_layout
->
Inverse
();
auto
indices
=
inv_loop
->
Forward
(
vars
.
Map
([](
const
Var
&
v
)
{
return
PrimExpr
(
v
);
}));
auto
indices
=
inv_loop
->
Forward
(
vars
.
Map
([](
const
Var
&
v
)
{
return
PrimExpr
(
v
);
}));
for
(
int
i
=
0
;
i
<
old_loop_depth
;
i
++
)
{
for
(
int
i
=
0
;
i
<
old_loop_depth
;
i
++
)
{
ICHECK
(
body
.
as
<
For
>
().
defined
());
ICHECK
(
body
.
as
<
For
>
().
defined
());
For
loop
=
body
.
as
<
For
>
().
value
();
For
loop
=
body
.
as
<
For
>
().
value
();
...
@@ -82,8 +86,8 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer* analyzer, Fragment lo
...
@@ -82,8 +86,8 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer* analyzer, Fragment lo
// substitute and re-construct the serial loop
// substitute and re-construct the serial loop
body
=
Substitute
(
body
,
vmap
);
body
=
Substitute
(
body
,
vmap
);
for
(
int
i
=
new_loop_depth
-
1
;
i
>=
0
;
i
--
)
{
for
(
int
i
=
new_loop_depth
-
1
;
i
>=
0
;
i
--
)
{
body
=
body
=
For
(
vars
[
i
],
make_zero
(
vars
[
i
]
->
dtype
),
inv_loop
->
InputShape
()[
i
],
For
(
vars
[
i
],
make_zero
(
vars
[
i
]
->
dtype
),
inv_loop
->
InputShape
()[
i
],
ForKind
::
kSerial
,
body
);
ForKind
::
kSerial
,
body
);
analyzer
->
Bind
(
vars
[
i
],
Range
(
0
,
inv_loop
->
InputShape
()[
i
]));
analyzer
->
Bind
(
vars
[
i
],
Range
(
0
,
inv_loop
->
InputShape
()[
i
]));
}
}
...
@@ -95,11 +99,11 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer* analyzer, Fragment lo
...
@@ -95,11 +99,11 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer* analyzer, Fragment lo
}
}
class
LoopPramaUnroller
:
public
StmtExprMutator
{
class
LoopPramaUnroller
:
public
StmtExprMutator
{
public:
public:
LoopPramaUnroller
()
=
default
;
LoopPramaUnroller
()
=
default
;
private:
private:
Stmt
VisitStmt_
(
const
ForNode
*
node
)
final
{
Stmt
VisitStmt_
(
const
ForNode
*
node
)
final
{
if
(
node
->
kind
==
ForKind
::
kSerial
)
{
if
(
node
->
kind
==
ForKind
::
kSerial
)
{
For
new_for
=
GetRef
<
For
>
(
node
);
For
new_for
=
GetRef
<
For
>
(
node
);
auto
for_ptr
=
new_for
.
CopyOnWrite
();
auto
for_ptr
=
new_for
.
CopyOnWrite
();
...
@@ -112,7 +116,7 @@ class LoopPramaUnroller : public StmtExprMutator {
...
@@ -112,7 +116,7 @@ class LoopPramaUnroller : public StmtExprMutator {
};
};
class
LoopPartitioner
:
public
StmtExprVisitor
{
class
LoopPartitioner
:
public
StmtExprVisitor
{
public:
public:
LoopPartitioner
()
=
default
;
LoopPartitioner
()
=
default
;
Fragment
Partition
(
For
op
,
int
num_thread
,
int
vectorize_size
)
{
Fragment
Partition
(
For
op
,
int
num_thread
,
int
vectorize_size
)
{
...
@@ -129,17 +133,18 @@ class LoopPartitioner : public StmtExprVisitor {
...
@@ -129,17 +133,18 @@ class LoopPartitioner : public StmtExprVisitor {
ICHECK
(
loop_size_full
%
vectorize_size
==
0
);
ICHECK
(
loop_size_full
%
vectorize_size
==
0
);
PrimExpr
access_idx
=
FloorDiv
(
flattened
,
vectorize_size
);
PrimExpr
access_idx
=
FloorDiv
(
flattened
,
vectorize_size
);
PrimExpr
thd
=
FloorMod
(
access_idx
,
num_thread
);
PrimExpr
thd
=
FloorMod
(
access_idx
,
num_thread
);
PrimExpr
idx
=
PrimExpr
idx
=
FloorDiv
(
access_idx
,
num_thread
)
*
vectorize_size
+
FloorDiv
(
access_idx
,
num_thread
)
*
vectorize_size
+
FloorMod
(
flattened
,
vectorize_size
);
FloorMod
(
flattened
,
vectorize_size
);
return
Fragment
(
loop_vars_
,
{
idx
},
{
thd
},
{});
return
Fragment
(
loop_vars_
,
{
idx
},
{
thd
},
{});
}
}
private:
private:
void
VisitStmt_
(
const
ForNode
*
node
)
final
{
void
VisitStmt_
(
const
ForNode
*
node
)
final
{
if
(
node
->
kind
==
ForKind
::
kParallel
)
{
if
(
node
->
kind
==
ForKind
::
kParallel
)
{
body_
=
node
->
body
;
body_
=
node
->
body
;
loop_vars_
.
push_back
(
IterVar
(
Range
::
FromMinExtent
(
node
->
min
,
node
->
extent
),
node
->
loop_var
,
loop_vars_
.
push_back
(
IterVarType
::
kDataPar
));
IterVar
(
Range
::
FromMinExtent
(
node
->
min
,
node
->
extent
),
node
->
loop_var
,
IterVarType
::
kDataPar
));
}
}
StmtExprVisitor
::
VisitStmt_
(
node
);
StmtExprVisitor
::
VisitStmt_
(
node
);
}
}
...
@@ -160,5 +165,5 @@ For LoopPragmaUnroll(For stmt) {
...
@@ -160,5 +165,5 @@ For LoopPragmaUnroll(For stmt) {
return
unrolled
;
return
unrolled
;
}
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/transform/loop_partition.h
View file @
549416f7
...
@@ -36,13 +36,14 @@ namespace tl {
...
@@ -36,13 +36,14 @@ namespace tl {
using
namespace
tir
;
using
namespace
tir
;
For
PartitionLoop
(
For
op
,
Var
thread_var
,
arith
::
Analyzer
*
analyzer
,
Fragment
loop_layout
);
For
PartitionLoop
(
For
op
,
Var
thread_var
,
arith
::
Analyzer
*
analyzer
,
Fragment
loop_layout
);
Fragment
PlanLoopPartition
(
For
op
,
size_t
num_thread
,
int
vectorize_size
);
Fragment
PlanLoopPartition
(
For
op
,
size_t
num_thread
,
int
vectorize_size
);
For
LoopPragmaUnroll
(
For
stmt
);
For
LoopPragmaUnroll
(
For
stmt
);
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
#endif
// TVM_TL_LOOP_PARTITION_H_
#endif // TVM_TL_LOOP_PARTITION_H_
src/transform/loop_vectorize.cc
View file @
549416f7
...
@@ -30,10 +30,10 @@
...
@@ -30,10 +30,10 @@
#include <numeric>
#include <numeric>
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "../layout/layout.h"
#include "../layout/layout.h"
#include "../layout/utils.h"
#include "../layout/utils.h"
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_vectorization_utils.h"
#include "common/loop_vectorization_utils.h"
namespace
tvm
{
namespace
tvm
{
...
@@ -48,10 +48,10 @@ struct VectorizePlanResult {
...
@@ -48,10 +48,10 @@ struct VectorizePlanResult {
};
};
class
VectorizePlanner
:
public
arith
::
IRVisitorWithAnalyzer
{
class
VectorizePlanner
:
public
arith
::
IRVisitorWithAnalyzer
{
public:
public:
VectorizePlanner
()
=
default
;
VectorizePlanner
()
=
default
;
int
Plan
(
const
For
&
node
)
{
int
Plan
(
const
For
&
node
)
{
this
->
operator
()(
node
);
this
->
operator
()(
node
);
// Always Enable vectorization
// Always Enable vectorization
// if (!has_nonlocal_memory_access_) return 1;
// if (!has_nonlocal_memory_access_) return 1;
...
@@ -62,18 +62,19 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
...
@@ -62,18 +62,19 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
PrimExpr
GetCondition
()
{
return
condition_
;
}
PrimExpr
GetCondition
()
{
return
condition_
;
}
private:
private:
void
VisitStmt_
(
const
ForNode
*
node
)
final
{
void
VisitStmt_
(
const
ForNode
*
node
)
final
{
inner_for_
=
node
;
inner_for_
=
node
;
iter_map_
.
Set
(
node
->
loop_var
,
Range
(
node
->
min
,
node
->
extent
));
iter_map_
.
Set
(
node
->
loop_var
,
Range
(
node
->
min
,
node
->
extent
));
arith
::
IRVisitorWithAnalyzer
::
VisitStmt_
(
node
);
arith
::
IRVisitorWithAnalyzer
::
VisitStmt_
(
node
);
}
}
void
VisitExpr_
(
const
BufferLoadNode
*
node
)
final
{
void
VisitExpr_
(
const
BufferLoadNode
*
node
)
final
{
if
(
node
->
buffer
.
scope
()
==
"shared"
||
node
->
buffer
.
scope
()
==
"global"
||
if
(
node
->
buffer
.
scope
()
==
"shared"
||
node
->
buffer
.
scope
()
==
"global"
||
node
->
buffer
.
scope
()
==
"shared.dyn"
)
node
->
buffer
.
scope
()
==
"shared.dyn"
)
has_nonlocal_memory_access_
=
true
;
has_nonlocal_memory_access_
=
true
;
if
(
node
->
buffer
->
shape
.
size
()
==
1
&&
node
->
buffer
->
shape
[
0
].
as
<
IntImmNode
>
()
->
value
==
1
)
{
if
(
node
->
buffer
->
shape
.
size
()
==
1
&&
node
->
buffer
->
shape
[
0
].
as
<
IntImmNode
>
()
->
value
==
1
)
{
// TODO(lei): This should be improved as
// TODO(lei): This should be improved as
// constant buffer that tl hack to use as local register.
// constant buffer that tl hack to use as local register.
return
arith
::
IRVisitorWithAnalyzer
::
VisitExpr_
(
node
);
return
arith
::
IRVisitorWithAnalyzer
::
VisitExpr_
(
node
);
...
@@ -82,7 +83,7 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
...
@@ -82,7 +83,7 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
return
arith
::
IRVisitorWithAnalyzer
::
VisitExpr_
(
node
);
return
arith
::
IRVisitorWithAnalyzer
::
VisitExpr_
(
node
);
}
}
void
VisitStmt_
(
const
BufferStoreNode
*
node
)
final
{
void
VisitStmt_
(
const
BufferStoreNode
*
node
)
final
{
if
(
node
->
buffer
.
scope
()
==
"shared"
||
node
->
buffer
.
scope
()
==
"global"
||
if
(
node
->
buffer
.
scope
()
==
"shared"
||
node
->
buffer
.
scope
()
==
"global"
||
node
->
buffer
.
scope
()
==
"shared.dyn"
)
node
->
buffer
.
scope
()
==
"shared.dyn"
)
has_nonlocal_memory_access_
=
true
;
has_nonlocal_memory_access_
=
true
;
...
@@ -90,12 +91,12 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
...
@@ -90,12 +91,12 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
return
arith
::
IRVisitorWithAnalyzer
::
VisitStmt_
(
node
);
return
arith
::
IRVisitorWithAnalyzer
::
VisitStmt_
(
node
);
}
}
void
VisitStmt_
(
const
IfThenElseNode
*
node
)
final
{
void
VisitStmt_
(
const
IfThenElseNode
*
node
)
final
{
CheckConditionVectorized
(
node
->
condition
);
CheckConditionVectorized
(
node
->
condition
);
return
arith
::
IRVisitorWithAnalyzer
::
VisitStmt_
(
node
);
return
arith
::
IRVisitorWithAnalyzer
::
VisitStmt_
(
node
);
}
}
void
VisitExpr_
(
const
CallNode
*
node
)
final
{
void
VisitExpr_
(
const
CallNode
*
node
)
final
{
if
(
node
->
op
==
builtin
::
if_then_else
())
{
if
(
node
->
op
==
builtin
::
if_then_else
())
{
CheckConditionVectorized
(
node
->
args
[
0
]);
CheckConditionVectorized
(
node
->
args
[
0
]);
}
else
if
(
node
->
op
==
builtin
::
call_extern
())
{
}
else
if
(
node
->
op
==
builtin
::
call_extern
())
{
...
@@ -105,16 +106,18 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
...
@@ -105,16 +106,18 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
return
arith
::
IRVisitorWithAnalyzer
::
VisitExpr_
(
node
);
return
arith
::
IRVisitorWithAnalyzer
::
VisitExpr_
(
node
);
}
}
void
CheckConditionVectorized
(
const
PrimExpr
&
cond
)
{
void
CheckConditionVectorized
(
const
PrimExpr
&
cond
)
{
// TODO: perform some checks here
// TODO: perform some checks here
}
}
void
UpdateVectorSize
(
const
Array
<
PrimExpr
>
indices
,
const
Buffer
&
buffer
)
{
void
UpdateVectorSize
(
const
Array
<
PrimExpr
>
indices
,
const
Buffer
&
buffer
)
{
if
(
!
inner_for_
)
return
;
if
(
!
inner_for_
)
return
;
auto
extent_ptr
=
inner_for_
->
extent
.
as
<
IntImmNode
>
();
auto
extent_ptr
=
inner_for_
->
extent
.
as
<
IntImmNode
>
();
if
(
!
extent_ptr
)
return
;
if
(
!
extent_ptr
)
return
;
const
DataType
&
access_type
=
buffer
->
dtype
;
const
DataType
&
access_type
=
buffer
->
dtype
;
// i // 2, i % 8 can also be vectorized as factor 16
// i // 2, i % 8 can also be vectorized as factor 16
int
max_vector_size
=
128
/
access_type
.
bits
();
int
max_vector_size
=
128
/
access_type
.
bits
();
// so we should disable this GCD optimization
// so we should disable this GCD optimization
...
@@ -122,7 +125,8 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
...
@@ -122,7 +125,8 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
auto
last_dim
=
buffer
->
shape
.
back
();
auto
last_dim
=
buffer
->
shape
.
back
();
auto
mod_set
=
analyzer_
.
modular_set
(
last_dim
);
auto
mod_set
=
analyzer_
.
modular_set
(
last_dim
);
// when dynamic shape like [m, k]: coeff=1, base=0, GCD will block conditionally tail vectorize
// when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
// conditionally tail vectorize
if
(
buffer
->
shape
.
back
().
as
<
IntImmNode
>
())
{
if
(
buffer
->
shape
.
back
().
as
<
IntImmNode
>
())
{
max_vector_size
=
arith
::
ZeroAwareGCD
(
max_vector_size
,
mod_set
->
coeff
);
max_vector_size
=
arith
::
ZeroAwareGCD
(
max_vector_size
,
mod_set
->
coeff
);
...
@@ -142,8 +146,9 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
...
@@ -142,8 +146,9 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
elem_offset
=
elem_offset
+
indices
[
i
]
*
stride
;
elem_offset
=
elem_offset
+
indices
[
i
]
*
stride
;
stride
=
stride
*
buffer
->
shape
[
i
];
stride
=
stride
*
buffer
->
shape
[
i
];
}
}
while
(
!
IndiceCanVectorize
(
elem_offset
,
inner_for_
->
loop_var
,
inner_for_
->
extent
,
while
(
!
IndiceCanVectorize
(
elem_offset
,
inner_for_
->
loop_var
,
vector_size_
,
&
analyzer_
))
{
inner_for_
->
extent
,
vector_size_
,
&
analyzer_
))
{
vector_size_
/=
2
;
vector_size_
/=
2
;
}
}
}
else
if
(
vector_size_
<=
vector_load_bits_max_
/
buffer
->
dtype
.
bits
())
{
}
else
if
(
vector_size_
<=
vector_load_bits_max_
/
buffer
->
dtype
.
bits
())
{
...
@@ -156,7 +161,7 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
...
@@ -156,7 +161,7 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
static
const
int
vector_load_bits_max_
=
128
;
static
const
int
vector_load_bits_max_
=
128
;
const
ForNode
*
inner_for_
;
const
ForNode
*
inner_for_
;
Map
<
Var
,
Range
>
iter_map_
;
Map
<
Var
,
Range
>
iter_map_
;
bool
has_nonlocal_memory_access_
=
false
;
bool
has_nonlocal_memory_access_
=
false
;
int
vector_size_
=
128
;
int
vector_size_
=
128
;
...
@@ -166,12 +171,12 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
...
@@ -166,12 +171,12 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
};
};
class
VectorizeDynamicCallRemover
:
public
StmtExprMutator
{
class
VectorizeDynamicCallRemover
:
public
StmtExprMutator
{
public:
public:
VectorizeDynamicCallRemover
(
Var
inner_var
,
int
vector_size
)
VectorizeDynamicCallRemover
(
Var
inner_var
,
int
vector_size
)
:
inner_var_
(
inner_var
),
vector_size_
(
vector_size
)
{}
:
inner_var_
(
inner_var
),
vector_size_
(
vector_size
)
{}
private:
private:
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
final
{
if
(
op
->
op
.
same_as
(
builtin
::
if_then_else
()))
{
if
(
op
->
op
.
same_as
(
builtin
::
if_then_else
()))
{
PrimExpr
cond
=
this
->
VisitExpr
(
op
->
args
[
0
]);
PrimExpr
cond
=
this
->
VisitExpr
(
op
->
args
[
0
]);
Map
<
Var
,
PrimExpr
>
vmap
;
Map
<
Var
,
PrimExpr
>
vmap
;
...
@@ -191,15 +196,16 @@ class VectorizeDynamicCallRemover : public StmtExprMutator {
...
@@ -191,15 +196,16 @@ class VectorizeDynamicCallRemover : public StmtExprMutator {
};
};
class
VectorizeRewriter
:
public
StmtExprMutator
{
class
VectorizeRewriter
:
public
StmtExprMutator
{
public:
public:
VectorizeRewriter
(
VectorizePlanResult
plan
)
VectorizeRewriter
(
VectorizePlanResult
plan
)
:
vector_size_
(
plan
.
vector_size
),
condition_
(
plan
.
condition
),
dynamic_
(
plan
.
dynamic
)
{}
:
vector_size_
(
plan
.
vector_size
),
condition_
(
plan
.
condition
),
dynamic_
(
plan
.
dynamic
)
{}
private:
private:
Stmt
VisitStmt_
(
const
ForNode
*
node
)
final
{
Stmt
VisitStmt_
(
const
ForNode
*
node
)
final
{
inner_for_
=
node
;
inner_for_
=
node
;
auto
ret
=
StmtExprMutator
::
VisitStmt_
(
node
);
auto
ret
=
StmtExprMutator
::
VisitStmt_
(
node
);
if
(
inner_for_
==
node
)
{
// rewrite the innermost loop
if
(
inner_for_
==
node
)
{
// rewrite the innermost loop
For
fnode
=
ret
.
as
<
For
>
().
value
();
For
fnode
=
ret
.
as
<
For
>
().
value
();
auto
old_var
=
fnode
->
loop_var
;
auto
old_var
=
fnode
->
loop_var
;
auto
extent_ptr
=
as_const_int
(
fnode
->
extent
);
auto
extent_ptr
=
as_const_int
(
fnode
->
extent
);
...
@@ -208,7 +214,7 @@ class VectorizeRewriter : public StmtExprMutator {
...
@@ -208,7 +214,7 @@ class VectorizeRewriter : public StmtExprMutator {
ICHECK
(
extent
%
vector_size_
==
0
)
ICHECK
(
extent
%
vector_size_
==
0
)
<<
"extent: "
<<
extent
<<
" vector_size_: "
<<
vector_size_
;
<<
"extent: "
<<
extent
<<
" vector_size_: "
<<
vector_size_
;
ICHECK
(
is_zero
(
fnode
->
min
));
ICHECK
(
is_zero
(
fnode
->
min
));
if
(
!
dynamic_
)
{
// check dynamic shape
if
(
!
dynamic_
)
{
// check dynamic shape
if
(
extent
==
vector_size_
)
{
if
(
extent
==
vector_size_
)
{
fnode
.
CopyOnWrite
()
->
kind
=
ForKind
::
kVectorized
;
fnode
.
CopyOnWrite
()
->
kind
=
ForKind
::
kVectorized
;
return
fnode
;
return
fnode
;
...
@@ -219,8 +225,8 @@ class VectorizeRewriter : public StmtExprMutator {
...
@@ -219,8 +225,8 @@ class VectorizeRewriter : public StmtExprMutator {
vmap
.
Set
(
fnode
->
loop_var
,
outer_var
*
vector_size_
+
inner_var
);
vmap
.
Set
(
fnode
->
loop_var
,
outer_var
*
vector_size_
+
inner_var
);
Stmt
body
=
Substitute
(
fnode
->
body
,
vmap
);
Stmt
body
=
Substitute
(
fnode
->
body
,
vmap
);
body
=
For
(
inner_var
,
0
,
vector_size_
,
ForKind
::
kVectorized
,
body
);
body
=
For
(
inner_var
,
0
,
vector_size_
,
ForKind
::
kVectorized
,
body
);
body
=
For
(
outer_var
,
0
,
extent
/
vector_size_
,
fnode
->
kind
,
body
,
fnode
->
thread_binding
,
body
=
For
(
outer_var
,
0
,
extent
/
vector_size_
,
fnode
->
kind
,
body
,
fnode
->
annotations
,
fnode
->
span
);
fnode
->
thread_binding
,
fnode
->
annotations
,
fnode
->
span
);
return
body
;
return
body
;
}
}
}
else
{
}
else
{
...
@@ -237,11 +243,13 @@ class VectorizeRewriter : public StmtExprMutator {
...
@@ -237,11 +243,13 @@ class VectorizeRewriter : public StmtExprMutator {
VectorizeDynamicCallRemover
remover
(
inner_var
,
vector_size_
);
VectorizeDynamicCallRemover
remover
(
inner_var
,
vector_size_
);
body
=
remover
(
body
);
body
=
remover
(
body
);
For
vectorize_for
=
For
(
inner_var
,
0
,
vector_size_
,
ForKind
::
kVectorized
,
body
);
For
vectorize_for
=
For
serial_for
=
For
(
inner_var
,
0
,
vector_size_
,
ForKind
::
kSerial
,
body
);
For
(
inner_var
,
0
,
vector_size_
,
ForKind
::
kVectorized
,
body
);
For
serial_for
=
For
(
inner_var
,
0
,
vector_size_
,
ForKind
::
kSerial
,
body
);
body
=
IfThenElse
(
condition
,
vectorize_for
,
serial_for
);
body
=
IfThenElse
(
condition
,
vectorize_for
,
serial_for
);
body
=
For
(
outer_var
,
0
,
extent
/
vector_size_
,
fnode
->
kind
,
body
,
fnode
->
thread_binding
,
body
=
For
(
outer_var
,
0
,
extent
/
vector_size_
,
fnode
->
kind
,
body
,
fnode
->
annotations
,
fnode
->
span
);
fnode
->
thread_binding
,
fnode
->
annotations
,
fnode
->
span
);
return
body
;
return
body
;
}
}
}
else
{
}
else
{
...
@@ -249,15 +257,15 @@ class VectorizeRewriter : public StmtExprMutator {
...
@@ -249,15 +257,15 @@ class VectorizeRewriter : public StmtExprMutator {
}
}
}
}
const
ForNode
*
inner_for_
;
const
ForNode
*
inner_for_
;
const
int
vector_size_
;
const
int
vector_size_
;
const
PrimExpr
condition_
;
const
PrimExpr
condition_
;
const
bool
dynamic_
;
const
bool
dynamic_
;
};
};
int
GetVectorizeSize
(
const
For
&
loop
)
{
return
VectorizePlanner
().
Plan
(
loop
);
}
int
GetVectorizeSize
(
const
For
&
loop
)
{
return
VectorizePlanner
().
Plan
(
loop
);
}
VectorizePlanResult
GetVectorizePlanResult
(
const
For
&
loop
)
{
VectorizePlanResult
GetVectorizePlanResult
(
const
For
&
loop
)
{
VectorizePlanner
planner
;
VectorizePlanner
planner
;
int
vector_size
=
planner
.
Plan
(
loop
);
int
vector_size
=
planner
.
Plan
(
loop
);
bool
dynamic
=
planner
.
GetDynamic
();
bool
dynamic
=
planner
.
GetDynamic
();
...
@@ -265,16 +273,19 @@ VectorizePlanResult GetVectorizePlanResult(const For& loop) {
...
@@ -265,16 +273,19 @@ VectorizePlanResult GetVectorizePlanResult(const For& loop) {
return
{
vector_size
,
dynamic
,
condition
};
return
{
vector_size
,
dynamic
,
condition
};
}
}
bool
IndiceCanVectorize
(
PrimExpr
expr
,
Var
var
,
PrimExpr
iter_var_size
,
int
target_vectorized_size
,
bool
IndiceCanVectorize
(
PrimExpr
expr
,
Var
var
,
PrimExpr
iter_var_size
,
arith
::
Analyzer
*
analyzer
)
{
int
target_vectorized_size
,
arith
::
Analyzer
*
analyzer
)
{
ICHECK
(
target_vectorized_size
>=
1
);
ICHECK
(
target_vectorized_size
>=
1
);
if
(
target_vectorized_size
==
1
)
return
true
;
if
(
target_vectorized_size
==
1
)
if
(
!
analyzer
->
CanProveEqual
(
FloorMod
(
iter_var_size
,
target_vectorized_size
),
0
))
return
false
;
return
true
;
if
(
!
analyzer
->
CanProveEqual
(
FloorMod
(
iter_var_size
,
target_vectorized_size
),
0
))
return
false
;
Var
v0
(
"v0"
),
v1
(
"v1"
);
Var
v0
(
"v0"
),
v1
(
"v1"
);
analyzer
->
Bind
(
v0
,
Range
(
0
,
target_vectorized_size
));
analyzer
->
Bind
(
v0
,
Range
(
0
,
target_vectorized_size
));
analyzer
->
Bind
(
v1
,
Range
(
0
,
FloorDiv
(
iter_var_size
,
target_vectorized_size
)));
analyzer
->
Bind
(
v1
,
Range
(
0
,
FloorDiv
(
iter_var_size
,
target_vectorized_size
)));
PrimExpr
expr_transformed
=
PrimExpr
expr_transformed
=
analyzer
->
Simplify
(
analyzer
->
Simplify
(
Substitute
(
expr
,
{{
var
,
v0
+
v1
*
target_vectorized_size
}}));
Substitute
(
expr
,
{{
var
,
v0
+
v1
*
target_vectorized_size
}}));
Vectorizer
vectorizer
(
v0
,
IntImm
(
v0
->
dtype
,
target_vectorized_size
));
Vectorizer
vectorizer
(
v0
,
IntImm
(
v0
->
dtype
,
target_vectorized_size
));
PrimExpr
expr_vectorized
=
vectorizer
.
VisitExpr
(
expr_transformed
);
PrimExpr
expr_vectorized
=
vectorizer
.
VisitExpr
(
expr_transformed
);
...
@@ -290,16 +301,17 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, int targ
...
@@ -290,16 +301,17 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, int targ
}
}
}
}
For
VectorizeLoop
(
const
For
&
loop
,
int
vectorize_hint
)
{
For
VectorizeLoop
(
const
For
&
loop
,
int
vectorize_hint
)
{
VectorizePlanResult
res
{
128
,
false
,
0
};
VectorizePlanResult
res
{
128
,
false
,
0
};
if
(
vectorize_hint
<=
0
)
{
if
(
vectorize_hint
<=
0
)
{
res
=
GetVectorizePlanResult
(
loop
);
res
=
GetVectorizePlanResult
(
loop
);
vectorize_hint
=
res
.
vector_size
;
vectorize_hint
=
res
.
vector_size
;
}
}
if
(
vectorize_hint
==
1
)
return
loop
;
if
(
vectorize_hint
==
1
)
return
loop
;
auto
rewriter
=
VectorizeRewriter
(
res
);
auto
rewriter
=
VectorizeRewriter
(
res
);
return
Downcast
<
For
>
(
rewriter
(
loop
));
return
Downcast
<
For
>
(
rewriter
(
loop
));
}
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/transform/loop_vectorize.h
View file @
549416f7
...
@@ -35,13 +35,13 @@ namespace tl {
...
@@ -35,13 +35,13 @@ namespace tl {
using
namespace
tir
;
using
namespace
tir
;
int
GetVectorizeSize
(
const
For
&
loop
);
int
GetVectorizeSize
(
const
For
&
loop
);
For
VectorizeLoop
(
const
For
&
loop
,
int
vectorize_hint
=
-
1
);
For
VectorizeLoop
(
const
For
&
loop
,
int
vectorize_hint
=
-
1
);
bool
IndiceCanVectorize
(
PrimExpr
expr
,
Var
var
,
PrimExpr
iter_var_size
,
int
target_vectorized_size
,
bool
IndiceCanVectorize
(
PrimExpr
expr
,
Var
var
,
PrimExpr
iter_var_size
,
arith
::
Analyzer
*
analyzer
);
int
target_vectorized_size
,
arith
::
Analyzer
*
analyzer
);
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
#endif
// TVM_TL_LOOP_VECTORIZE_H_
#endif // TVM_TL_LOOP_VECTORIZE_H_
src/transform/lower_hopper_intrin.cc
View file @
549416f7
...
@@ -37,15 +37,15 @@ namespace tl {
...
@@ -37,15 +37,15 @@ namespace tl {
using
namespace
tir
;
using
namespace
tir
;
class
LowerHopperIntrin
:
public
StmtExprMutator
{
class
LowerHopperIntrin
:
public
StmtExprMutator
{
public:
public:
static
PrimFunc
Substitute
(
PrimFunc
&
f
)
{
static
PrimFunc
Substitute
(
PrimFunc
&
f
)
{
PrimFuncNode
*
fptr
=
f
.
CopyOnWrite
();
PrimFuncNode
*
fptr
=
f
.
CopyOnWrite
();
LowerHopperIntrin
substituter
;
LowerHopperIntrin
substituter
;
fptr
->
body
=
substituter
.
VisitStmt
(
f
->
body
);
fptr
->
body
=
substituter
.
VisitStmt
(
f
->
body
);
for
(
auto
[
call
,
var
]
:
substituter
.
desc_map_
)
{
for
(
auto
[
call
,
var
]
:
substituter
.
desc_map_
)
{
// Should allocate 128 bytes for TensorMap on stack
// Should allocate 128 bytes for TensorMap on stack
Call
alloc_desc
=
Call
alloc_desc
=
Call
(
DataType
::
Handle
(),
builtin
::
tvm_stack_alloca
(),
Call
(
DataType
::
Handle
(),
builtin
::
tvm_stack_alloca
(),
{
StringImm
(
"arg_value"
),
16
});
{
StringImm
(
"arg_value"
),
16
});
Array
<
PrimExpr
>
init_desc_args
;
Array
<
PrimExpr
>
init_desc_args
;
if
(
call
->
op
.
same_as
(
CreateTMADescriptorOp
()))
{
if
(
call
->
op
.
same_as
(
CreateTMADescriptorOp
()))
{
init_desc_args
.
push_back
(
StringImm
(
tvm_tensormap_create_tiled
));
init_desc_args
.
push_back
(
StringImm
(
tvm_tensormap_create_tiled
));
...
@@ -55,15 +55,19 @@ class LowerHopperIntrin : public StmtExprMutator {
...
@@ -55,15 +55,19 @@ class LowerHopperIntrin : public StmtExprMutator {
CHECK
(
0
)
<<
call
->
op
;
CHECK
(
0
)
<<
call
->
op
;
}
}
init_desc_args
.
push_back
(
var
);
init_desc_args
.
push_back
(
var
);
init_desc_args
.
insert
(
init_desc_args
.
end
(),
call
->
args
.
begin
(),
call
->
args
.
end
());
init_desc_args
.
insert
(
init_desc_args
.
end
(),
call
->
args
.
begin
(),
Call
init_desc
=
Call
(
DataType
::
Handle
(),
builtin
::
tvm_call_packed
(),
init_desc_args
);
call
->
args
.
end
());
fptr
->
body
=
LetStmt
(
var
,
alloc_desc
,
SeqStmt
({
Evaluate
(
init_desc
),
fptr
->
body
}));
Call
init_desc
=
Call
(
DataType
::
Handle
(),
builtin
::
tvm_call_packed
(),
init_desc_args
);
fptr
->
body
=
LetStmt
(
var
,
alloc_desc
,
SeqStmt
({
Evaluate
(
init_desc
),
fptr
->
body
}));
}
}
return
f
;
return
f
;
}
}
Stmt
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
// Insert the prefetch TMA descriptor statement TO the beginning of the kernel
// Insert the prefetch TMA descriptor statement TO the beginning of the
// kernel
if
(
op
->
attr_key
==
tir
::
attr
::
thread_extent
)
{
if
(
op
->
attr_key
==
tir
::
attr
::
thread_extent
)
{
IterVar
iv
=
Downcast
<
IterVar
>
(
op
->
node
);
IterVar
iv
=
Downcast
<
IterVar
>
(
op
->
node
);
if
(
iv
->
thread_tag
==
"threadIdx.x"
)
{
if
(
iv
->
thread_tag
==
"threadIdx.x"
)
{
...
@@ -73,18 +77,22 @@ class LowerHopperIntrin : public StmtExprMutator {
...
@@ -73,18 +77,22 @@ class LowerHopperIntrin : public StmtExprMutator {
}
else
{
}
else
{
Array
<
Stmt
>
stmt_seq
;
Array
<
Stmt
>
stmt_seq
;
if
(
!
init_mbarrier_calls_
.
empty
())
{
if
(
!
init_mbarrier_calls_
.
empty
())
{
auto
alloc_mbarrier
=
Evaluate
(
Call
(
DataType
::
Handle
(),
builtin
::
create_barriers
(),
auto
alloc_mbarrier
=
{
static_cast
<
int
>
(
init_mbarrier_calls_
.
size
())}));
Evaluate
(
Call
(
DataType
::
Handle
(),
builtin
::
create_barriers
(),
{
static_cast
<
int
>
(
init_mbarrier_calls_
.
size
())}));
stmt_seq
.
push_back
(
alloc_mbarrier
);
stmt_seq
.
push_back
(
alloc_mbarrier
);
}
}
auto
stmts
=
prefetch_calls_
;
auto
stmts
=
prefetch_calls_
;
stmts
.
insert
(
stmts
.
end
(),
init_mbarrier_calls_
.
begin
(),
init_mbarrier_calls_
.
end
());
stmts
.
insert
(
stmts
.
end
(),
init_mbarrier_calls_
.
begin
(),
auto
init_stmt
=
IfThenElse
(
EQ
(
iv
->
var
,
0
),
stmts
.
size
()
>
1
?
SeqStmt
(
stmts
)
:
stmts
[
0
]);
init_mbarrier_calls_
.
end
());
auto
init_stmt
=
IfThenElse
(
EQ
(
iv
->
var
,
0
),
stmts
.
size
()
>
1
?
SeqStmt
(
stmts
)
:
stmts
[
0
]);
stmt_seq
.
push_back
(
init_stmt
);
stmt_seq
.
push_back
(
init_stmt
);
if
(
!
init_mbarrier_calls_
.
empty
())
{
if
(
!
init_mbarrier_calls_
.
empty
())
{
Stmt
mem_sync
=
Evaluate
(
Stmt
mem_sync
=
Call
(
DataType
::
Handle
(),
builtin
::
tvm_storage_sync
(),
{
StringImm
(
"shared"
)}));
Evaluate
(
Call
(
DataType
::
Handle
(),
builtin
::
tvm_storage_sync
(),
{
StringImm
(
"shared"
)}));
stmt_seq
.
push_back
(
mem_sync
);
stmt_seq
.
push_back
(
mem_sync
);
}
}
stmt_seq
.
push_back
(
body
);
stmt_seq
.
push_back
(
body
);
...
@@ -98,7 +106,7 @@ class LowerHopperIntrin : public StmtExprMutator {
...
@@ -98,7 +106,7 @@ class LowerHopperIntrin : public StmtExprMutator {
return
StmtExprMutator
::
VisitStmt_
(
op
);
return
StmtExprMutator
::
VisitStmt_
(
op
);
}
}
PrimExpr
VisitExpr_
(
const
CallNode
*
call
)
final
{
PrimExpr
VisitExpr_
(
const
CallNode
*
call
)
final
{
if
(
call
->
op
.
same_as
(
CreateTMADescriptorOp
())
||
if
(
call
->
op
.
same_as
(
CreateTMADescriptorOp
())
||
call
->
op
.
same_as
(
CreateTMAIm2ColDescriptorOp
()))
{
call
->
op
.
same_as
(
CreateTMAIm2ColDescriptorOp
()))
{
Var
var
;
Var
var
;
...
@@ -107,10 +115,12 @@ class LowerHopperIntrin : public StmtExprMutator {
...
@@ -107,10 +115,12 @@ class LowerHopperIntrin : public StmtExprMutator {
var
=
iter
->
second
;
var
=
iter
->
second
;
}
else
{
}
else
{
String
name
=
call
->
args
[
2
].
as
<
Var
>
().
value
()
->
name_hint
;
String
name
=
call
->
args
[
2
].
as
<
Var
>
().
value
()
->
name_hint
;
var
=
Var
(
name
+
"_desc"
,
PointerType
(
PrimType
(
cuTensorMapType
()),
"grid_constant"
));
var
=
Var
(
name
+
"_desc"
,
PointerType
(
PrimType
(
cuTensorMapType
()),
"grid_constant"
));
desc_map_
[
GetRef
<
Call
>
(
call
)]
=
var
;
desc_map_
[
GetRef
<
Call
>
(
call
)]
=
var
;
prefetch_calls_
.
push_back
(
Evaluate
(
Call
(
DataType
::
Handle
(),
builtin
::
call_extern
(),
prefetch_calls_
.
push_back
(
{
StringImm
(
"tl::prefetch_tma_descriptor"
),
var
})));
Evaluate
(
Call
(
DataType
::
Handle
(),
builtin
::
call_extern
(),
{
StringImm
(
"tl::prefetch_tma_descriptor"
),
var
})));
}
}
return
var
;
return
var
;
}
else
if
(
call
->
op
.
same_as
(
CreateListofMBarrierOp
()))
{
}
else
if
(
call
->
op
.
same_as
(
CreateListofMBarrierOp
()))
{
...
@@ -118,24 +128,25 @@ class LowerHopperIntrin : public StmtExprMutator {
...
@@ -118,24 +128,25 @@ class LowerHopperIntrin : public StmtExprMutator {
int
num_barriers
=
static_cast
<
int
>
(
call
->
args
.
size
());
int
num_barriers
=
static_cast
<
int
>
(
call
->
args
.
size
());
for
(
int
i
=
0
;
i
<
num_barriers
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_barriers
;
i
++
)
{
PrimExpr
mbarrier
=
Call
(
DataType
::
Handle
(),
GetMBarrierOp
(),
{
i
});
PrimExpr
mbarrier
=
Call
(
DataType
::
Handle
(),
GetMBarrierOp
(),
{
i
});
init_mbarrier_calls_
.
push_back
(
init_mbarrier_calls_
.
push_back
(
Evaluate
(
Evaluate
(
Call
(
DataType
::
Handle
(),
builtin
::
ptx_init_barrier_thread_count
(),
Call
(
DataType
::
Handle
(),
builtin
::
ptx_init_barrier_thread_count
(),
{
mbarrier
,
call
->
args
[
i
]})));
{
mbarrier
,
call
->
args
[
i
]})));
}
}
return
0
;
return
0
;
}
else
if
(
call
->
op
.
same_as
(
SyncThreadsPartialOp
()))
{
}
else
if
(
call
->
op
.
same_as
(
SyncThreadsPartialOp
()))
{
int
barrier_id
=
init_mbarrier_calls_
.
size
();
int
barrier_id
=
init_mbarrier_calls_
.
size
();
PrimExpr
mbarrier
=
Call
(
DataType
::
Handle
(),
GetMBarrierOp
(),
{
barrier_id
});
PrimExpr
mbarrier
=
init_mbarrier_calls_
.
push_back
(
Call
(
DataType
::
Handle
(),
GetMBarrierOp
(),
{
barrier_id
});
Evaluate
(
Call
(
DataType
::
Handle
(),
builtin
::
ptx_init_barrier_thread_count
(),
init_mbarrier_calls_
.
push_back
(
Evaluate
(
{
mbarrier
,
call
->
args
[
0
]})));
Call
(
DataType
::
Handle
(),
builtin
::
ptx_init_barrier_thread_count
(),
{
mbarrier
,
call
->
args
[
0
]})));
return
Call
(
DataType
::
Handle
(),
SyncThreadsPartialOp
(),
{
mbarrier
});
return
Call
(
DataType
::
Handle
(),
SyncThreadsPartialOp
(),
{
mbarrier
});
}
else
{
}
else
{
return
StmtExprMutator
::
VisitExpr_
(
call
);
return
StmtExprMutator
::
VisitExpr_
(
call
);
}
}
}
}
private:
private:
Array
<
Stmt
>
prefetch_calls_
;
Array
<
Stmt
>
prefetch_calls_
;
Array
<
Stmt
>
init_mbarrier_calls_
;
Array
<
Stmt
>
init_mbarrier_calls_
;
std
::
unordered_map
<
Call
,
Var
,
StructuralHash
,
ExprDeepEqual
>
desc_map_
;
std
::
unordered_map
<
Call
,
Var
,
StructuralHash
,
ExprDeepEqual
>
desc_map_
;
...
@@ -154,5 +165,5 @@ tvm::transform::Pass LowerHopperIntrin() {
...
@@ -154,5 +165,5 @@ tvm::transform::Pass LowerHopperIntrin() {
TVM_REGISTER_GLOBAL
(
"tl.transform.LowerHopperIntrin"
)
TVM_REGISTER_GLOBAL
(
"tl.transform.LowerHopperIntrin"
)
.
set_body_typed
(
LowerHopperIntrin
);
.
set_body_typed
(
LowerHopperIntrin
);
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/transform/lower_tile_op.cc
View file @
549416f7
...
@@ -27,10 +27,10 @@
...
@@ -27,10 +27,10 @@
#include <tvm/tir/transform.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include <tvm/tir/utils.h>
#include "arith/ir_mutator_with_analyzer.h"
#include "../layout/layout.h"
#include "../layout/layout.h"
#include "../layout/utils.h"
#include "../layout/utils.h"
#include "../op/op.h"
#include "../op/op.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "loop_partition.h"
#include "loop_partition.h"
namespace
tvm
{
namespace
tvm
{
...
@@ -38,8 +38,9 @@ namespace tl {
...
@@ -38,8 +38,9 @@ namespace tl {
using
namespace
tir
;
using
namespace
tir
;
static
Buffer
makeBufferWithLayout
(
const
Buffer
&
buffer
,
const
Layout
&
layout
)
{
static
Buffer
makeBufferWithLayout
(
const
Buffer
&
buffer
,
const
Layout
&
layout
)
{
const
auto
*
ptr_type
=
TVM_TYPE_AS
(
buffer
->
data
->
type_annotation
,
PointerTypeNode
);
const
auto
*
ptr_type
=
TVM_TYPE_AS
(
buffer
->
data
->
type_annotation
,
PointerTypeNode
);
Type
new_type
;
Type
new_type
;
// convert fragments to normal local buffer
// convert fragments to normal local buffer
if
(
ptr_type
->
storage_scope
==
"local.fragment"
)
{
if
(
ptr_type
->
storage_scope
==
"local.fragment"
)
{
...
@@ -53,32 +54,33 @@ static Buffer makeBufferWithLayout(const Buffer& buffer, const Layout& layout) {
...
@@ -53,32 +54,33 @@ static Buffer makeBufferWithLayout(const Buffer& buffer, const Layout& layout) {
}
else
{
}
else
{
new_var
=
Var
(
buffer
->
data
->
name_hint
,
new_type
);
new_var
=
Var
(
buffer
->
data
->
name_hint
,
new_type
);
}
}
return
Buffer
(
new_var
,
buffer
->
dtype
,
layout
->
OutputShape
(),
{},
buffer
->
elem_offset
,
return
Buffer
(
new_var
,
buffer
->
dtype
,
layout
->
OutputShape
(),
{},
buffer
->
name
,
buffer
->
data_alignment
,
buffer
->
offset_factor
,
buffer
->
buffer_type
);
buffer
->
elem_offset
,
buffer
->
name
,
buffer
->
data_alignment
,
buffer
->
offset_factor
,
buffer
->
buffer_type
);
}
}
class
LowerTileOpPass
:
arith
::
IRMutatorWithAnalyzer
{
class
LowerTileOpPass
:
arith
::
IRMutatorWithAnalyzer
{
public:
public:
static
PrimFunc
Substitute
(
PrimFunc
f
)
{
static
PrimFunc
Substitute
(
PrimFunc
f
)
{
arith
::
Analyzer
analyzer
;
arith
::
Analyzer
analyzer
;
LowerTileOpPass
substituter
(
&
analyzer
);
LowerTileOpPass
substituter
(
&
analyzer
);
// Trace the buffer map for tvm_access_ptr
// Trace the buffer map for tvm_access_ptr
substituter
.
buffer_map_
.
insert
(
f
->
buffer_map
.
begin
(),
f
->
buffer_map
.
end
());
substituter
.
buffer_map_
.
insert
(
f
->
buffer_map
.
begin
(),
f
->
buffer_map
.
end
());
for
(
const
auto
&
[
_
,
buffer
]
:
f
->
buffer_map
)
{
for
(
const
auto
&
[
_
,
buffer
]
:
f
->
buffer_map
)
{
substituter
.
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
substituter
.
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
}
}
auto
target
=
f
->
GetAttr
<
Target
>
(
tvm
::
attr
::
kTarget
);
auto
target
=
f
->
GetAttr
<
Target
>
(
tvm
::
attr
::
kTarget
);
ICHECK
(
target
.
defined
())
<<
"LowerTileOpPass: Require the target attribute"
;
ICHECK
(
target
.
defined
())
<<
"LowerTileOpPass: Require the target attribute"
;
substituter
.
target_
=
target
.
value
();
substituter
.
target_
=
target
.
value
();
PrimFuncNode
*
fptr
=
f
.
CopyOnWrite
();
PrimFuncNode
*
fptr
=
f
.
CopyOnWrite
();
fptr
->
body
=
substituter
.
VisitStmt
(
f
->
body
);
fptr
->
body
=
substituter
.
VisitStmt
(
f
->
body
);
return
f
;
return
f
;
}
}
private:
private:
using
arith
::
IRMutatorWithAnalyzer
::
IRMutatorWithAnalyzer
;
using
arith
::
IRMutatorWithAnalyzer
::
IRMutatorWithAnalyzer
;
Stmt
VisitStmt_
(
const
BlockNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
BlockNode
*
op
)
final
{
// Record the mapping from buffer data var to buffer for later lookup
// Record the mapping from buffer data var to buffer for later lookup
for
(
auto
buffer
:
op
->
alloc_buffers
)
{
for
(
auto
buffer
:
op
->
alloc_buffers
)
{
buffer_map_
.
insert
({
buffer
->
data
,
buffer
});
buffer_map_
.
insert
({
buffer
->
data
,
buffer
});
...
@@ -91,7 +93,9 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
...
@@ -91,7 +93,9 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
}
}
Map
<
Var
,
Layout
>
vmap
;
Map
<
Var
,
Layout
>
vmap
;
if
(
op
->
annotations
.
count
(
attr
::
kLayoutMap
))
{
if
(
op
->
annotations
.
count
(
attr
::
kLayoutMap
))
{
auto
layout_map
=
op
->
annotations
.
at
(
attr
::
kLayoutMap
).
as
<
Map
<
Buffer
,
Layout
>>
().
value
();
auto
layout_map
=
op
->
annotations
.
at
(
attr
::
kLayoutMap
)
.
as
<
Map
<
Buffer
,
Layout
>>
()
.
value
();
for
(
auto
[
buffer
,
layout
]
:
layout_map
)
{
for
(
auto
[
buffer
,
layout
]
:
layout_map
)
{
buffer_remap_
.
Set
(
buffer
,
makeBufferWithLayout
(
buffer
,
layout
));
buffer_remap_
.
Set
(
buffer
,
makeBufferWithLayout
(
buffer
,
layout
));
layout_map_
.
Set
(
buffer
,
layout
);
layout_map_
.
Set
(
buffer
,
layout
);
...
@@ -105,7 +109,8 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
...
@@ -105,7 +109,8 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
block_ptr
->
alloc_buffers
.
Set
(
i
,
buffer_remap_
[
buffer
]);
block_ptr
->
alloc_buffers
.
Set
(
i
,
buffer_remap_
[
buffer
]);
}
}
}
}
for
(
const
auto
&
buffer
:
workspaces_
)
block_ptr
->
alloc_buffers
.
push_back
(
buffer
);
for
(
const
auto
&
buffer
:
workspaces_
)
block_ptr
->
alloc_buffers
.
push_back
(
buffer
);
workspaces_
.
clear
();
workspaces_
.
clear
();
block_ptr
->
annotations
.
erase
(
attr
::
kLayoutMap
);
block_ptr
->
annotations
.
erase
(
attr
::
kLayoutMap
);
return
block
;
return
block
;
...
@@ -113,18 +118,19 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
...
@@ -113,18 +118,19 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
int
CheckAndGetBufferRowSize
(
Buffer
buffer
)
{
int
CheckAndGetBufferRowSize
(
Buffer
buffer
)
{
CHECK
(
buffer
->
shape
.
size
()
>=
2
)
CHECK
(
buffer
->
shape
.
size
()
>=
2
)
<<
"The dimension of Buffer
\"
"
<<
buffer
->
name
<<
"
\"
with shape "
<<
buffer
->
shape
<<
"The dimension of Buffer
\"
"
<<
buffer
->
name
<<
"
\"
with shape "
<<
" should be at least 2"
;
<<
buffer
->
shape
<<
" should be at least 2"
;
auto
dim
=
buffer
->
shape
.
size
();
auto
dim
=
buffer
->
shape
.
size
();
auto
buffer_row_size
=
buffer
->
shape
[
dim
-
1
].
as
<
IntImmNode
>
()
->
value
;
auto
buffer_row_size
=
buffer
->
shape
[
dim
-
1
].
as
<
IntImmNode
>
()
->
value
;
return
buffer_row_size
;
return
buffer_row_size
;
}
}
PrimExpr
HandleAccessPtrAndOffset
(
PrimExpr
access_ptr
,
Optional
<
PrimExpr
>
offset
=
NullOpt
,
PrimExpr
HandleAccessPtrAndOffset
(
PrimExpr
access_ptr
,
Optional
<
PrimExpr
>
offset
=
NullOpt
,
DataType
dtype
=
DataType
::
Int
(
32
))
{
DataType
dtype
=
DataType
::
Int
(
32
))
{
// The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and
accumulate it to
// The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and
// smem_offset
//
accumulate it to
smem_offset
CHECK
(
access_ptr
->
IsInstance
<
CallNode
>
())
CHECK
(
access_ptr
->
IsInstance
<
CallNode
>
())
<<
"Invalid access ptr for permuted layout: "
<<
access_ptr
;
<<
"Invalid access ptr for permuted layout: "
<<
access_ptr
;
auto
access_ptr_call
=
Downcast
<
Call
>
(
access_ptr
);
auto
access_ptr_call
=
Downcast
<
Call
>
(
access_ptr
);
...
@@ -136,8 +142,10 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
...
@@ -136,8 +142,10 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
Array
<
PrimExpr
>
shape
=
load
->
buffer
->
shape
;
Array
<
PrimExpr
>
shape
=
load
->
buffer
->
shape
;
CHECK_EQ
(
indices
.
size
(),
shape
.
size
())
CHECK_EQ
(
indices
.
size
(),
shape
.
size
())
<<
"Indices size and shape size must match for general N-dimensional buffer "
<<
"Indices size and shape size must match for general N-dimensional "
<<
"but got indices size: "
<<
indices
.
size
()
<<
" and shape size: "
<<
shape
.
size
();
"buffer "
<<
"but got indices size: "
<<
indices
.
size
()
<<
" and shape size: "
<<
shape
.
size
();
PrimExpr
elem_offset
=
0
;
PrimExpr
elem_offset
=
0
;
PrimExpr
stride
=
1
;
PrimExpr
stride
=
1
;
...
@@ -147,13 +155,16 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
...
@@ -147,13 +155,16 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
stride
*=
shape
[
i
];
stride
*=
shape
[
i
];
}
}
PrimExpr
smem_offset
=
elem_offset
+
(
offset
.
defined
()
?
offset
.
value
()
:
0
);
PrimExpr
smem_offset
=
elem_offset
+
(
offset
.
defined
()
?
offset
.
value
()
:
0
);
auto
new_buffer
=
buffer_remap_
[
load
->
buffer
];
auto
new_buffer
=
buffer_remap_
[
load
->
buffer
];
auto
buffer_map_iter
=
buffer_map_
.
find
(
Downcast
<
Var
>
(
load
->
buffer
->
data
));
auto
buffer_map_iter
=
buffer_map_
.
find
(
Downcast
<
Var
>
(
load
->
buffer
->
data
));
CHECK
(
buffer_map_iter
!=
buffer_map_
.
end
())
CHECK
(
buffer_map_iter
!=
buffer_map_
.
end
())
<<
"The buffer corresponding to data Var "
<<
access_ptr_call
->
args
[
0
]
<<
" is not found"
;
<<
"The buffer corresponding to data Var "
<<
access_ptr_call
->
args
[
0
]
<<
" is not found"
;
int
buffer_row_size
=
CheckAndGetBufferRowSize
(
buffer_map_iter
->
second
);
int
buffer_row_size
=
CheckAndGetBufferRowSize
(
buffer_map_iter
->
second
);
(
void
)
buffer_row_size
;
(
void
)
buffer_row_size
;
...
@@ -163,11 +174,13 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
...
@@ -163,11 +174,13 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
PrimExpr
remaining_offset
=
smem_offset
;
PrimExpr
remaining_offset
=
smem_offset
;
for
(
int
i
=
static_cast
<
int
>
(
shape
.
size
())
-
1
;
i
>=
0
;
--
i
)
{
for
(
int
i
=
static_cast
<
int
>
(
shape
.
size
())
-
1
;
i
>=
0
;
--
i
)
{
multi_dim_indices
.
insert
(
multi_dim_indices
.
begin
(),
floormod
(
remaining_offset
,
shape
[
i
]));
multi_dim_indices
.
insert
(
multi_dim_indices
.
begin
(),
floormod
(
remaining_offset
,
shape
[
i
]));
remaining_offset
=
floordiv
(
remaining_offset
,
shape
[
i
]);
remaining_offset
=
floordiv
(
remaining_offset
,
shape
[
i
]);
}
}
auto
forward_indices
=
layout_map_
[
load
->
buffer
]
->
Forward
(
multi_dim_indices
);
auto
forward_indices
=
layout_map_
[
load
->
buffer
]
->
Forward
(
multi_dim_indices
);
PrimExpr
new_offset
=
0
;
PrimExpr
new_offset
=
0
;
PrimExpr
stride_offset
=
1
;
PrimExpr
stride_offset
=
1
;
for
(
int
i
=
static_cast
<
int
>
(
shape
.
size
())
-
1
;
i
>=
0
;
--
i
)
{
for
(
int
i
=
static_cast
<
int
>
(
shape
.
size
())
-
1
;
i
>=
0
;
--
i
)
{
...
@@ -191,8 +204,9 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
...
@@ -191,8 +204,9 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
return
access_ptr_call
;
return
access_ptr_call
;
}
}
PrimExpr
VisitExpr_
(
const
tir
::
CallNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
tir
::
CallNode
*
op
)
final
{
if
(
!
op
->
op
.
same_as
(
builtin
::
ptx_ldmatrix
())
&&
!
op
->
op
.
same_as
(
builtin
::
mma_store
()))
{
if
(
!
op
->
op
.
same_as
(
builtin
::
ptx_ldmatrix
())
&&
!
op
->
op
.
same_as
(
builtin
::
mma_store
()))
{
return
Downcast
<
Call
>
(
IRMutatorWithAnalyzer
::
VisitExpr_
(
op
));
return
Downcast
<
Call
>
(
IRMutatorWithAnalyzer
::
VisitExpr_
(
op
));
}
else
{
}
else
{
is_ptx_
=
true
;
is_ptx_
=
true
;
...
@@ -212,15 +226,18 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
...
@@ -212,15 +226,18 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
BufferLoad
load
=
Downcast
<
BufferLoad
>
(
address_of_call
->
args
[
0
]);
BufferLoad
load
=
Downcast
<
BufferLoad
>
(
address_of_call
->
args
[
0
]);
if
(
buffer_remap_
.
count
(
load
->
buffer
))
{
if
(
buffer_remap_
.
count
(
load
->
buffer
))
{
auto
new_access_ptr
=
HandleAccessPtrAndOffset
(
access_ptr
,
smem_offset
,
call
->
dtype
);
auto
new_access_ptr
=
HandleAccessPtrAndOffset
(
access_ptr
,
smem_offset
,
call
->
dtype
);
auto
new_call
=
call
.
CopyOnWrite
();
auto
new_call
=
call
.
CopyOnWrite
();
new_call
->
args
.
Set
(
5
,
new_access_ptr
);
new_call
->
args
.
Set
(
5
,
new_access_ptr
);
new_call
->
args
.
Set
(
6
,
IntImm
(
smem_offset
->
dtype
,
0
));
new_call
->
args
.
Set
(
6
,
IntImm
(
smem_offset
->
dtype
,
0
));
}
}
}
else
if
(
call
->
op
.
same_as
(
builtin
::
mma_store
()))
{
}
else
if
(
call
->
op
.
same_as
(
builtin
::
mma_store
()))
{
// because we will directly store result to Buffer instead of calling mma_store now
// because we will directly store result to Buffer instead of calling
// mma_store now
auto
access_ptr
=
call
->
args
[
2
];
auto
access_ptr
=
call
->
args
[
2
];
auto
new_access_ptr
=
HandleAccessPtrAndOffset
(
access_ptr
,
NullOpt
,
call
->
dtype
);
auto
new_access_ptr
=
HandleAccessPtrAndOffset
(
access_ptr
,
NullOpt
,
call
->
dtype
);
auto
new_call
=
call
.
CopyOnWrite
();
auto
new_call
=
call
.
CopyOnWrite
();
new_call
->
args
.
Set
(
2
,
new_access_ptr
);
new_call
->
args
.
Set
(
2
,
new_access_ptr
);
}
else
{
}
else
{
...
@@ -230,7 +247,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
...
@@ -230,7 +247,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
return
call
;
return
call
;
}
}
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
auto
load
=
Downcast
<
BufferLoad
>
(
IRMutatorWithAnalyzer
::
VisitExpr_
(
op
));
auto
load
=
Downcast
<
BufferLoad
>
(
IRMutatorWithAnalyzer
::
VisitExpr_
(
op
));
if
(
is_ptx_
)
{
if
(
is_ptx_
)
{
return
load
;
return
load
;
...
@@ -243,7 +260,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
...
@@ -243,7 +260,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
return
load
;
return
load
;
}
}
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
auto
store
=
Downcast
<
BufferStore
>
(
IRMutatorWithAnalyzer
::
VisitStmt_
(
op
));
auto
store
=
Downcast
<
BufferStore
>
(
IRMutatorWithAnalyzer
::
VisitStmt_
(
op
));
if
(
buffer_remap_
.
count
(
store
->
buffer
))
{
if
(
buffer_remap_
.
count
(
store
->
buffer
))
{
auto
new_indices
=
layout_map_
[
store
->
buffer
]
->
Forward
(
store
->
indices
);
auto
new_indices
=
layout_map_
[
store
->
buffer
]
->
Forward
(
store
->
indices
);
...
@@ -253,36 +270,40 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
...
@@ -253,36 +270,40 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
return
store
;
return
store
;
}
}
PrimExpr
VisitExpr_
(
const
VarNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
VarNode
*
op
)
final
{
auto
var
=
Downcast
<
Var
>
(
IRMutatorWithAnalyzer
::
VisitExpr_
(
op
));
auto
var
=
Downcast
<
Var
>
(
IRMutatorWithAnalyzer
::
VisitExpr_
(
op
));
if
(
buffer_data_to_buffer_
.
count
(
var
))
{
if
(
buffer_data_to_buffer_
.
count
(
var
))
{
auto
buffer
=
buffer_data_to_buffer_
[
var
];
auto
buffer
=
buffer_data_to_buffer_
[
var
];
if
(
buffer_remap_
.
count
(
buffer
))
return
buffer_remap_
[
buffer
]
->
data
;
if
(
buffer_remap_
.
count
(
buffer
))
return
buffer_remap_
[
buffer
]
->
data
;
}
}
return
var
;
return
var
;
}
}
Stmt
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
const
CallNode
*
call
=
op
->
value
.
as
<
CallNode
>
();
const
CallNode
*
call
=
op
->
value
.
as
<
CallNode
>
();
// Do not analysis the call node to the global function.
// Do not analysis the call node to the global function.
if
(
call
&&
call
->
op
.
as
<
GlobalVarNode
>
())
if
(
call
&&
call
->
op
.
as
<
GlobalVarNode
>
())
return
Downcast
<
Evaluate
>
(
IRMutatorWithAnalyzer
::
VisitStmt_
(
op
));
return
Downcast
<
Evaluate
>
(
IRMutatorWithAnalyzer
::
VisitStmt_
(
op
));
auto
tile_op
=
ParseOperator
(
GetRef
<
Stmt
>
(
op
),
buffer_data_to_buffer_
);
auto
tile_op
=
ParseOperator
(
GetRef
<
Stmt
>
(
op
),
buffer_data_to_buffer_
);
if
(
tile_op
==
nullptr
)
return
IRMutatorWithAnalyzer
::
VisitStmt_
(
op
);
if
(
tile_op
==
nullptr
)
return
IRMutatorWithAnalyzer
::
VisitStmt_
(
op
);
AddWorkspaceCallback
callback
=
[
this
](
int
num_elem
,
DataType
dtype
)
{
AddWorkspaceCallback
callback
=
[
this
](
int
num_elem
,
DataType
dtype
)
{
auto
workspace
=
decl_buffer
({
PrimExpr
(
num_elem
)},
dtype
,
"workspace"
,
"shared.dyn"
);
auto
workspace
=
decl_buffer
({
PrimExpr
(
num_elem
)},
dtype
,
"workspace"
,
"shared.dyn"
);
workspaces_
.
push_back
(
workspace
);
workspaces_
.
push_back
(
workspace
);
return
workspace
.
access_ptr
(
2
);
// write
return
workspace
.
access_ptr
(
2
);
// write
};
};
auto
lowered
=
tile_op
->
Lower
(
auto
lowered
=
LowerArgs
{
target_
,
thread_block_size_
,
thread_var_
,
callback
,
layout_map_
,
buffer_remap_
},
tile_op
->
Lower
(
LowerArgs
{
target_
,
thread_block_size_
,
thread_var_
,
analyzer_
);
callback
,
layout_map_
,
buffer_remap_
},
analyzer_
);
return
IRMutatorWithAnalyzer
::
VisitStmt
(
lowered
);
return
IRMutatorWithAnalyzer
::
VisitStmt
(
lowered
);
}
}
Stmt
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
if
(
op
->
attr_key
==
tir
::
attr
::
thread_extent
)
{
if
(
op
->
attr_key
==
tir
::
attr
::
thread_extent
)
{
IterVar
iv
=
Downcast
<
IterVar
>
(
op
->
node
);
IterVar
iv
=
Downcast
<
IterVar
>
(
op
->
node
);
ICHECK_NE
(
iv
->
thread_tag
.
length
(),
0U
);
ICHECK_NE
(
iv
->
thread_tag
.
length
(),
0U
);
...
@@ -321,7 +342,7 @@ tvm::transform::Pass LowerTileOp() {
...
@@ -321,7 +342,7 @@ tvm::transform::Pass LowerTileOp() {
}
}
TVM_REGISTER_GLOBAL
(
"tl.transform.LowerTileOp"
).
set_body_typed
(
LowerTileOp
);
TVM_REGISTER_GLOBAL
(
"tl.transform.LowerTileOp"
).
set_body_typed
(
LowerTileOp
);
}
// namespace transform
}
// namespace transform
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/transform/multi_version_buffer_rewriter.cc
View file @
549416f7
...
@@ -38,22 +38,23 @@ using namespace tir;
...
@@ -38,22 +38,23 @@ using namespace tir;
enum
class
Role
{
kConsumer
,
kProducer
,
kBoth
};
enum
class
Role
{
kConsumer
,
kProducer
,
kBoth
};
class
WarpSpecializedRoleMarker_
:
public
StmtVisitor
{
class
WarpSpecializedRoleMarker_
:
public
StmtVisitor
{
public:
public:
WarpSpecializedRoleMarker_
(
Map
<
Var
,
Buffer
>
buffer_data_to_buffer
)
WarpSpecializedRoleMarker_
(
Map
<
Var
,
Buffer
>
buffer_data_to_buffer
)
:
buffer_data_to_buffer_
(
buffer_data_to_buffer
)
{}
:
buffer_data_to_buffer_
(
buffer_data_to_buffer
)
{}
Role
GetRole
(
const
StmtNode
*
stmt
)
const
{
Role
GetRole
(
const
StmtNode
*
stmt
)
const
{
auto
it
=
map_
.
find
(
stmt
);
auto
it
=
map_
.
find
(
stmt
);
ICHECK
(
it
!=
map_
.
end
());
ICHECK
(
it
!=
map_
.
end
());
return
it
->
second
;
return
it
->
second
;
}
}
Role
GetRole
(
const
Stmt
&
stmt
)
const
{
return
GetRole
(
stmt
.
get
());
}
Role
GetRole
(
const
Stmt
&
stmt
)
const
{
return
GetRole
(
stmt
.
get
());
}
void
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
void
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
Role
role
=
Role
::
kConsumer
;
Role
role
=
Role
::
kConsumer
;
if
(
auto
call
=
op
->
value
.
as
<
CallNode
>
())
{
if
(
auto
call
=
op
->
value
.
as
<
CallNode
>
())
{
if
(
call
->
op
.
same_as
(
TMALoadOp
())
||
call
->
op
.
same_as
(
TMALoadIm2ColOp
()))
{
if
(
call
->
op
.
same_as
(
TMALoadOp
())
||
call
->
op
.
same_as
(
TMALoadIm2ColOp
()))
{
role
=
Role
::
kProducer
;
role
=
Role
::
kProducer
;
has_bulk_copy_
=
true
;
has_bulk_copy_
=
true
;
}
}
...
@@ -61,8 +62,9 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor {
...
@@ -61,8 +62,9 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor {
SetRole
(
op
,
role
);
SetRole
(
op
,
role
);
}
}
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
bool
is_shared_store
=
op
->
buffer
.
scope
()
==
"shared.dyn"
||
op
->
buffer
.
scope
()
==
"shared"
;
bool
is_shared_store
=
op
->
buffer
.
scope
()
==
"shared.dyn"
||
op
->
buffer
.
scope
()
==
"shared"
;
if
(
!
is_shared_store
)
{
if
(
!
is_shared_store
)
{
SetRole
(
op
,
Role
::
kConsumer
);
SetRole
(
op
,
Role
::
kConsumer
);
return
;
return
;
...
@@ -80,11 +82,12 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor {
...
@@ -80,11 +82,12 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor {
break
;
break
;
}
}
}
}
if
(
role
==
Role
::
kProducer
)
has_simt_copy_
=
true
;
if
(
role
==
Role
::
kProducer
)
has_simt_copy_
=
true
;
SetRole
(
op
,
role
);
SetRole
(
op
,
role
);
}
}
void
VisitStmt_
(
const
SeqStmtNode
*
op
)
final
{
void
VisitStmt_
(
const
SeqStmtNode
*
op
)
final
{
StmtVisitor
::
VisitStmt_
(
op
);
StmtVisitor
::
VisitStmt_
(
op
);
auto
role
=
GetRole
(
op
->
seq
[
0
]);
auto
role
=
GetRole
(
op
->
seq
[
0
]);
for
(
auto
stmt
:
op
->
seq
)
{
for
(
auto
stmt
:
op
->
seq
)
{
...
@@ -96,48 +99,48 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor {
...
@@ -96,48 +99,48 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor {
SetRole
(
op
,
role
);
SetRole
(
op
,
role
);
}
}
void
VisitStmt_
(
const
IfThenElseNode
*
op
)
final
{
void
VisitStmt_
(
const
IfThenElseNode
*
op
)
final
{
StmtVisitor
::
VisitStmt_
(
op
);
StmtVisitor
::
VisitStmt_
(
op
);
auto
role
=
GetRole
(
op
->
then_case
);
auto
role
=
GetRole
(
op
->
then_case
);
if
(
op
->
else_case
.
defined
())
{
if
(
op
->
else_case
.
defined
())
{
auto
role_else
=
GetRole
(
op
->
else_case
.
value
());
auto
role_else
=
GetRole
(
op
->
else_case
.
value
());
if
(
role
!=
role_else
)
role
=
Role
::
kBoth
;
if
(
role
!=
role_else
)
role
=
Role
::
kBoth
;
}
}
SetRole
(
op
,
role
);
SetRole
(
op
,
role
);
}
}
void
VisitStmt_
(
const
BlockRealizeNode
*
op
)
final
{
void
VisitStmt_
(
const
BlockRealizeNode
*
op
)
final
{
StmtVisitor
::
VisitStmt_
(
op
);
StmtVisitor
::
VisitStmt_
(
op
);
SetRole
(
op
,
GetRole
(
op
->
block
));
SetRole
(
op
,
GetRole
(
op
->
block
));
}
}
template
<
class
NodeType
>
template
<
class
NodeType
>
void
HandleBodyStmt
(
const
NodeType
*
op
)
{
void
HandleBodyStmt
(
const
NodeType
*
op
)
{
StmtVisitor
::
VisitStmt_
(
op
);
StmtVisitor
::
VisitStmt_
(
op
);
SetRole
(
op
,
GetRole
(
op
->
body
));
SetRole
(
op
,
GetRole
(
op
->
body
));
}
}
void
VisitStmt_
(
const
ForNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
ForNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
LetStmtNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
LetStmtNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
AssertStmtNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
AssertStmtNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
BlockNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
void
VisitStmt_
(
const
BlockNode
*
op
)
final
{
HandleBodyStmt
(
op
);
}
bool
HasProducer
()
{
return
has_simt_copy_
||
has_bulk_copy_
;
}
bool
HasProducer
()
{
return
has_simt_copy_
||
has_bulk_copy_
;
}
bool
HasSimtCopy
()
{
return
has_simt_copy_
;
}
bool
HasSimtCopy
()
{
return
has_simt_copy_
;
}
private:
private:
void
SetRole
(
const
StmtNode
*
stmt
,
Role
role
)
{
map_
[
stmt
]
=
role
;
}
void
SetRole
(
const
StmtNode
*
stmt
,
Role
role
)
{
map_
[
stmt
]
=
role
;
}
Map
<
Var
,
Buffer
>
buffer_data_to_buffer_
;
Map
<
Var
,
Buffer
>
buffer_data_to_buffer_
;
std
::
unordered_map
<
const
StmtNode
*
,
Role
>
map_
;
std
::
unordered_map
<
const
StmtNode
*
,
Role
>
map_
;
bool
has_simt_copy_
=
false
;
bool
has_simt_copy_
=
false
;
bool
has_bulk_copy_
=
false
;
bool
has_bulk_copy_
=
false
;
};
};
class
MultiVersionBufferRewriter
:
public
StmtExprMutator
{
class
MultiVersionBufferRewriter
:
public
StmtExprMutator
{
public:
public:
static
PrimFunc
Substitute
(
PrimFunc
&
f
)
{
static
PrimFunc
Substitute
(
PrimFunc
&
f
)
{
auto
rewriter
=
MultiVersionBufferRewriter
();
auto
rewriter
=
MultiVersionBufferRewriter
();
rewriter
.
buffer_lca_
=
DetectBufferAccessLCA
(
f
);
rewriter
.
buffer_lca_
=
DetectBufferAccessLCA
(
f
);
for
(
auto
[
buffer
,
_
]
:
rewriter
.
buffer_lca_
)
{
for
(
auto
[
buffer
,
_
]
:
rewriter
.
buffer_lca_
)
{
...
@@ -148,40 +151,45 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
...
@@ -148,40 +151,45 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
return
f
;
return
f
;
}
}
private:
private:
MultiVersionBufferRewriter
()
=
default
;
MultiVersionBufferRewriter
()
=
default
;
Array
<
Buffer
>
GetVersionedBuffers
(
Array
<
Stmt
>
seq_stmt
,
Array
<
Buffer
>
scoped_buffers
)
{
Array
<
Buffer
>
GetVersionedBuffers
(
Array
<
Stmt
>
seq_stmt
,
Array
<
Buffer
>
scoped_buffers
)
{
std
::
vector
<
Role
>
roles
;
std
::
vector
<
Role
>
roles
;
Array
<
Array
<
BufferRegion
>>
reads
,
writes
;
Array
<
Array
<
BufferRegion
>>
reads
,
writes
;
auto
marker
=
WarpSpecializedRoleMarker_
(
buffer_data_to_buffer_
);
auto
marker
=
WarpSpecializedRoleMarker_
(
buffer_data_to_buffer_
);
for
(
auto
stmt
:
seq_stmt
)
{
for
(
auto
stmt
:
seq_stmt
)
{
marker
(
stmt
);
marker
(
stmt
);
Block
block
(
/*iter_vars=*/
{},
/*reads=*/
{},
/*writes=*/
{},
/*name_hint=*/
""
,
/*body*/
stmt
);
Block
block
(
/*iter_vars=*/
{},
/*reads=*/
{},
/*writes=*/
{},
/*name_hint=*/
""
,
/*body*/
stmt
);
auto
access
=
GetBlockAccessRegion
(
block
,
buffer_data_to_buffer_
);
auto
access
=
GetBlockAccessRegion
(
block
,
buffer_data_to_buffer_
);
reads
.
push_back
(
std
::
move
(
access
[
0
]));
reads
.
push_back
(
std
::
move
(
access
[
0
]));
writes
.
push_back
(
std
::
move
(
access
[
1
]));
writes
.
push_back
(
std
::
move
(
access
[
1
]));
roles
.
push_back
(
marker
.
GetRole
(
stmt
));
roles
.
push_back
(
marker
.
GetRole
(
stmt
));
}
}
std
::
unordered_set
<
const
BufferNode
*>
consumer_used
,
producer_used
;
std
::
unordered_set
<
const
BufferNode
*>
consumer_used
,
producer_used
;
for
(
size_t
i
=
0
;
i
<
seq_stmt
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
seq_stmt
.
size
();
i
++
)
{
if
(
roles
[
i
]
==
Role
::
kProducer
)
{
if
(
roles
[
i
]
==
Role
::
kProducer
)
{
for
(
BufferRegion
br
:
writes
[
i
])
producer_used
.
insert
(
br
->
buffer
.
get
());
for
(
BufferRegion
br
:
writes
[
i
])
producer_used
.
insert
(
br
->
buffer
.
get
());
}
else
{
}
else
{
for
(
BufferRegion
br
:
reads
[
i
])
consumer_used
.
insert
(
br
->
buffer
.
get
());
for
(
BufferRegion
br
:
reads
[
i
])
consumer_used
.
insert
(
br
->
buffer
.
get
());
}
}
}
}
Array
<
Buffer
>
versioned_buffers
;
Array
<
Buffer
>
versioned_buffers
;
for
(
Buffer
buffer
:
scoped_buffers
)
{
for
(
Buffer
buffer
:
scoped_buffers
)
{
if
(
consumer_used
.
count
(
buffer
.
get
())
&&
producer_used
.
count
(
buffer
.
get
()))
{
if
(
consumer_used
.
count
(
buffer
.
get
())
&&
producer_used
.
count
(
buffer
.
get
()))
{
versioned_buffers
.
push_back
(
buffer
);
versioned_buffers
.
push_back
(
buffer
);
}
}
}
}
return
versioned_buffers
;
return
versioned_buffers
;
}
}
static
Buffer
RewriteAllocBuffer
(
const
Buffer
&
buffer
,
int
num_versions
)
{
static
Buffer
RewriteAllocBuffer
(
const
Buffer
&
buffer
,
int
num_versions
)
{
ObjectPtr
<
BufferNode
>
new_buffer
=
make_object
<
BufferNode
>
(
*
(
buffer
.
get
()));
ObjectPtr
<
BufferNode
>
new_buffer
=
make_object
<
BufferNode
>
(
*
(
buffer
.
get
()));
new_buffer
->
shape
.
insert
(
new_buffer
->
shape
.
begin
(),
PrimExpr
(
num_versions
));
new_buffer
->
shape
.
insert
(
new_buffer
->
shape
.
begin
(),
PrimExpr
(
num_versions
));
if
(
new_buffer
->
strides
.
size
())
{
if
(
new_buffer
->
strides
.
size
())
{
...
@@ -192,8 +200,9 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
...
@@ -192,8 +200,9 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
return
Buffer
(
new_buffer
);
return
Buffer
(
new_buffer
);
}
}
Stmt
VisitStmt_
(
const
BlockRealizeNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
BlockRealizeNode
*
op
)
final
{
BlockRealize
block_realize
=
Downcast
<
BlockRealize
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
BlockRealize
block_realize
=
Downcast
<
BlockRealize
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
Block
block
=
block_realize
->
block
;
Block
block
=
block_realize
->
block
;
Array
<
Buffer
>
alloc_buffers
;
Array
<
Buffer
>
alloc_buffers
;
for
(
auto
buffer
:
block
->
alloc_buffers
)
{
for
(
auto
buffer
:
block
->
alloc_buffers
)
{
...
@@ -209,24 +218,27 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
...
@@ -209,24 +218,27 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
return
block_realize
;
return
block_realize
;
}
}
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
auto
num_stages_anno
=
op
->
annotations
.
Get
(
"num_stages"
);
auto
num_stages_anno
=
op
->
annotations
.
Get
(
"num_stages"
);
if
(
!
num_stages_anno
.
defined
())
return
StmtExprMutator
::
VisitStmt_
(
op
);
if
(
!
num_stages_anno
.
defined
())
return
StmtExprMutator
::
VisitStmt_
(
op
);
ICHECK
(
num_stages_anno
.
as
<
IntImmNode
>
());
ICHECK
(
num_stages_anno
.
as
<
IntImmNode
>
());
int
num_stages
=
static_cast
<
int
>
(
num_stages_anno
.
as
<
IntImmNode
>
()
->
value
);
int
num_stages
=
static_cast
<
int
>
(
num_stages_anno
.
as
<
IntImmNode
>
()
->
value
);
const
SeqStmtNode
*
pipeline_body_seq
=
op
->
body
.
as
<
SeqStmtNode
>
();
const
SeqStmtNode
*
pipeline_body_seq
=
op
->
body
.
as
<
SeqStmtNode
>
();
CHECK
(
pipeline_body_seq
)
CHECK
(
pipeline_body_seq
)
<<
"ValueError: The body of the software pipeline "
<<
"ValueError: The body of the software pipeline
should be SeqStmt, got "
"
should be SeqStmt, got "
<<
op
->
body
->
GetTypeKey
();
<<
op
->
body
->
GetTypeKey
();
Array
<
Buffer
>
scoped_buffers
=
{};
Array
<
Buffer
>
scoped_buffers
=
{};
for
(
auto
[
buffer
,
stmt
]
:
buffer_lca_
)
{
for
(
auto
[
buffer
,
stmt
]
:
buffer_lca_
)
{
if
(
stmt
.
defined
()
&&
stmt
.
value
().
get
()
==
op
)
scoped_buffers
.
push_back
(
buffer
);
if
(
stmt
.
defined
()
&&
stmt
.
value
().
get
()
==
op
)
scoped_buffers
.
push_back
(
buffer
);
}
}
Array
<
Buffer
>
versioned_buffers
=
GetVersionedBuffers
(
pipeline_body_seq
->
seq
,
scoped_buffers
);
Array
<
Buffer
>
versioned_buffers
=
GetVersionedBuffers
(
pipeline_body_seq
->
seq
,
scoped_buffers
);
for
(
auto
buffer
:
versioned_buffers
)
{
for
(
auto
buffer
:
versioned_buffers
)
{
Var
buffer_var
=
buffer
->
data
;
Var
buffer_var
=
buffer
->
data
;
...
@@ -239,33 +251,33 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
...
@@ -239,33 +251,33 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
return
for_node
;
return
for_node
;
}
}
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
BufferLoad
load
=
Downcast
<
BufferLoad
>
(
StmtExprMutator
::
VisitExpr_
(
op
));
BufferLoad
load
=
Downcast
<
BufferLoad
>
(
StmtExprMutator
::
VisitExpr_
(
op
));
auto
it
=
buffer_remap_
.
find
(
load
->
buffer
);
auto
it
=
buffer_remap_
.
find
(
load
->
buffer
);
if
(
it
==
buffer_remap_
.
end
())
{
if
(
it
==
buffer_remap_
.
end
())
{
return
std
::
move
(
load
);
return
std
::
move
(
load
);
}
}
const
Buffer
&
new_buffer
=
(
*
it
).
second
;
const
Buffer
&
new_buffer
=
(
*
it
).
second
;
auto
*
n
=
load
.
CopyOnWrite
();
auto
*
n
=
load
.
CopyOnWrite
();
n
->
buffer
=
new_buffer
;
n
->
buffer
=
new_buffer
;
n
->
indices
.
insert
(
n
->
indices
.
begin
(),
version_index_
);
n
->
indices
.
insert
(
n
->
indices
.
begin
(),
version_index_
);
return
std
::
move
(
load
);
return
std
::
move
(
load
);
}
}
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
BufferStore
store
=
Downcast
<
BufferStore
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
BufferStore
store
=
Downcast
<
BufferStore
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
auto
it
=
buffer_remap_
.
find
(
store
->
buffer
);
auto
it
=
buffer_remap_
.
find
(
store
->
buffer
);
if
(
it
==
buffer_remap_
.
end
())
{
if
(
it
==
buffer_remap_
.
end
())
{
return
std
::
move
(
store
);
return
std
::
move
(
store
);
}
}
const
Buffer
&
new_buffer
=
(
*
it
).
second
;
const
Buffer
&
new_buffer
=
(
*
it
).
second
;
auto
*
n
=
store
.
CopyOnWrite
();
auto
*
n
=
store
.
CopyOnWrite
();
n
->
buffer
=
new_buffer
;
n
->
buffer
=
new_buffer
;
n
->
indices
.
insert
(
n
->
indices
.
begin
(),
version_index_
);
n
->
indices
.
insert
(
n
->
indices
.
begin
(),
version_index_
);
return
std
::
move
(
store
);
return
std
::
move
(
store
);
}
}
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
final
{
Call
call
=
Downcast
<
Call
>
(
StmtExprMutator
::
VisitExpr_
(
op
));
Call
call
=
Downcast
<
Call
>
(
StmtExprMutator
::
VisitExpr_
(
op
));
if
(
call
->
op
.
same_as
(
builtin
::
tvm_access_ptr
()))
{
if
(
call
->
op
.
same_as
(
builtin
::
tvm_access_ptr
()))
{
return
RewriteBufferAccess
(
call
,
{
1
});
return
RewriteBufferAccess
(
call
,
{
1
});
...
@@ -273,20 +285,23 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
...
@@ -273,20 +285,23 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
return
call
;
return
call
;
}
}
PrimExpr
RewriteBufferAccess
(
const
Call
&
call
,
const
std
::
vector
<
int
>
arg_indices
)
{
PrimExpr
RewriteBufferAccess
(
const
Call
&
call
,
auto
product
=
[](
const
Array
<
PrimExpr
>&
input
)
{
const
std
::
vector
<
int
>
arg_indices
)
{
return
foldl
([](
PrimExpr
a
,
PrimExpr
b
,
Span
span
)
{
return
mul
(
a
,
b
,
span
);
},
auto
product
=
[](
const
Array
<
PrimExpr
>
&
input
)
{
make_const
(
DataType
::
Int
(
32
),
1
),
input
);
return
foldl
(
[](
PrimExpr
a
,
PrimExpr
b
,
Span
span
)
{
return
mul
(
a
,
b
,
span
);
},
make_const
(
DataType
::
Int
(
32
),
1
),
input
);
};
};
Array
<
PrimExpr
>
new_args
=
call
->
args
;
Array
<
PrimExpr
>
new_args
=
call
->
args
;
for
(
int
i
:
arg_indices
)
{
for
(
int
i
:
arg_indices
)
{
auto
buffer_var
=
Downcast
<
Var
>
(
call
->
args
[
i
]);
auto
buffer_var
=
Downcast
<
Var
>
(
call
->
args
[
i
]);
if
(
!
buffer_data_to_buffer_
.
count
(
buffer_var
))
continue
;
if
(
!
buffer_data_to_buffer_
.
count
(
buffer_var
))
const
Buffer
&
buffer
=
buffer_data_to_buffer_
[
buffer_var
];
continue
;
const
Buffer
&
buffer
=
buffer_data_to_buffer_
[
buffer_var
];
auto
it
=
buffer_remap_
.
find
(
buffer
);
auto
it
=
buffer_remap_
.
find
(
buffer
);
if
(
it
!=
buffer_remap_
.
end
())
{
if
(
it
!=
buffer_remap_
.
end
())
{
const
Buffer
&
new_buffer
=
(
*
it
).
second
;
const
Buffer
&
new_buffer
=
(
*
it
).
second
;
const
PrimExpr
&
old_index
=
call
->
args
[
i
+
1
];
const
PrimExpr
&
old_index
=
call
->
args
[
i
+
1
];
PrimExpr
offset
;
PrimExpr
offset
;
if
(
new_buffer
->
strides
.
empty
())
{
if
(
new_buffer
->
strides
.
empty
())
{
offset
=
product
(
buffer
->
shape
);
offset
=
product
(
buffer
->
shape
);
...
@@ -318,5 +333,5 @@ tvm::transform::Pass MultiVersionBuffer() {
...
@@ -318,5 +333,5 @@ tvm::transform::Pass MultiVersionBuffer() {
TVM_REGISTER_GLOBAL
(
"tl.transform.MultiVersionBuffer"
)
TVM_REGISTER_GLOBAL
(
"tl.transform.MultiVersionBuffer"
)
.
set_body_typed
(
MultiVersionBuffer
);
.
set_body_typed
(
MultiVersionBuffer
);
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/transform/pipeline_planning.cc
View file @
549416f7
...
@@ -56,22 +56,23 @@ bool MayConflict(Region region1, Region region2) {
...
@@ -56,22 +56,23 @@ bool MayConflict(Region region1, Region region2) {
return
true
;
return
true
;
}
}
}
// namespace
}
// namespace
class
PipelinePlanner
:
public
StmtExprMutator
{
class
PipelinePlanner
:
public
StmtExprMutator
{
public:
public:
static
Stmt
Substitute
(
const
PrimFunc
&
f
)
{
static
Stmt
Substitute
(
const
PrimFunc
&
f
)
{
PipelinePlanner
substituter
;
PipelinePlanner
substituter
;
for
(
const
auto
&
[
_
,
buffer
]
:
f
->
buffer_map
)
{
for
(
const
auto
&
[
_
,
buffer
]
:
f
->
buffer_map
)
{
substituter
.
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
substituter
.
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
}
}
auto
target
=
f
->
GetAttr
<
Target
>
(
tvm
::
attr
::
kTarget
);
auto
target
=
f
->
GetAttr
<
Target
>
(
tvm
::
attr
::
kTarget
);
ICHECK
(
target
.
defined
())
<<
"Pipeline_Planning: Require the target attribute"
;
ICHECK
(
target
.
defined
())
<<
"Pipeline_Planning: Require the target attribute"
;
substituter
.
target_
=
target
.
value
();
substituter
.
target_
=
target
.
value
();
return
substituter
.
VisitStmt
(
f
->
body
);
return
substituter
.
VisitStmt
(
f
->
body
);
}
}
private:
private:
PipelinePlanner
()
=
default
;
PipelinePlanner
()
=
default
;
struct
PipelineStageInfo
{
struct
PipelineStageInfo
{
...
@@ -83,8 +84,10 @@ class PipelinePlanner : public StmtExprMutator {
...
@@ -83,8 +84,10 @@ class PipelinePlanner : public StmtExprMutator {
};
};
PipelineStageInfo
MakePipelineStageInfo
(
Stmt
stmt
,
int
idx
)
{
PipelineStageInfo
MakePipelineStageInfo
(
Stmt
stmt
,
int
idx
)
{
Block
block
(
/*iter_vars=*/
{},
/*reads=*/
{},
/*writes=*/
{},
/*name_hint=*/
""
,
/*body*/
stmt
);
Block
block
(
/*iter_vars=*/
{},
/*reads=*/
{},
/*writes=*/
{},
/*name_hint=*/
""
,
Array
<
Array
<
BufferRegion
>>
access
=
GetBlockReadWriteRegion
(
block
,
buffer_data_to_buffer_
);
/*body*/
stmt
);
Array
<
Array
<
BufferRegion
>>
access
=
GetBlockReadWriteRegion
(
block
,
buffer_data_to_buffer_
);
PipelineStageInfo
pinfo
;
PipelineStageInfo
pinfo
;
pinfo
.
reads
=
std
::
move
(
access
[
0
]);
pinfo
.
reads
=
std
::
move
(
access
[
0
]);
pinfo
.
writes
=
std
::
move
(
access
[
1
]);
pinfo
.
writes
=
std
::
move
(
access
[
1
]);
...
@@ -93,22 +96,25 @@ class PipelinePlanner : public StmtExprMutator {
...
@@ -93,22 +96,25 @@ class PipelinePlanner : public StmtExprMutator {
// copy stage should only have one reads and one writes
// copy stage should only have one reads and one writes
if
(
pinfo
.
reads
.
size
()
==
1
&&
pinfo
.
writes
.
size
()
==
1
)
{
if
(
pinfo
.
reads
.
size
()
==
1
&&
pinfo
.
writes
.
size
()
==
1
)
{
for
(
auto
region
:
pinfo
.
reads
)
for
(
auto
region
:
pinfo
.
reads
)
if
(
region
->
buffer
.
scope
()
==
"global"
)
pinfo
.
copy_stage
=
true
;
if
(
region
->
buffer
.
scope
()
==
"global"
)
pinfo
.
copy_stage
=
true
;
for
(
auto
region
:
pinfo
.
writes
)
for
(
auto
region
:
pinfo
.
writes
)
if
(
region
->
buffer
.
scope
()
==
"global"
)
pinfo
.
copy_stage
=
true
;
if
(
region
->
buffer
.
scope
()
==
"global"
)
pinfo
.
copy_stage
=
true
;
}
}
return
std
::
move
(
pinfo
);
return
std
::
move
(
pinfo
);
}
}
Stmt
VisitStmt_
(
const
ForNode
*
loop
)
final
{
Stmt
VisitStmt_
(
const
ForNode
*
loop
)
final
{
auto
num_stages_anno
=
loop
->
annotations
.
Get
(
"num_stages"
);
auto
num_stages_anno
=
loop
->
annotations
.
Get
(
"num_stages"
);
if
(
!
num_stages_anno
.
defined
())
return
StmtExprMutator
::
VisitStmt_
(
loop
);
if
(
!
num_stages_anno
.
defined
())
return
StmtExprMutator
::
VisitStmt_
(
loop
);
int
num_stages
=
num_stages_anno
.
as
<
IntImmNode
>
()
->
value
;
int
num_stages
=
num_stages_anno
.
as
<
IntImmNode
>
()
->
value
;
Stmt
pipeline_body
{
nullptr
};
Stmt
pipeline_body
{
nullptr
};
if
(
const
auto
*
realize
=
loop
->
body
.
as
<
BlockRealizeNode
>
())
{
if
(
const
auto
*
realize
=
loop
->
body
.
as
<
BlockRealizeNode
>
())
{
const
auto
&
block
=
realize
->
block
;
const
auto
&
block
=
realize
->
block
;
for
(
const
auto
&
buffer
:
block
->
alloc_buffers
)
{
for
(
const
auto
&
buffer
:
block
->
alloc_buffers
)
{
ICHECK
(
buffer
->
IsInstance
<
BufferNode
>
());
ICHECK
(
buffer
->
IsInstance
<
BufferNode
>
());
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
}
}
...
@@ -116,10 +122,10 @@ class PipelinePlanner : public StmtExprMutator {
...
@@ -116,10 +122,10 @@ class PipelinePlanner : public StmtExprMutator {
}
else
{
}
else
{
pipeline_body
=
loop
->
body
;
pipeline_body
=
loop
->
body
;
}
}
const
SeqStmtNode
*
pipeline_body_seq
=
pipeline_body
.
as
<
SeqStmtNode
>
();
const
SeqStmtNode
*
pipeline_body_seq
=
pipeline_body
.
as
<
SeqStmtNode
>
();
CHECK
(
pipeline_body_seq
)
CHECK
(
pipeline_body_seq
)
<<
"ValueError: The body of the software pipeline "
<<
"ValueError: The body of the software pipeline
should be SeqStmt, got "
"
should be SeqStmt, got "
<<
loop
->
body
->
GetTypeKey
();
<<
loop
->
body
->
GetTypeKey
();
CHECK
(
num_stages
>=
1
);
CHECK
(
num_stages
>=
1
);
CHECK
(
loop
->
kind
==
ForKind
::
kSerial
);
CHECK
(
loop
->
kind
==
ForKind
::
kSerial
);
...
@@ -130,21 +136,28 @@ class PipelinePlanner : public StmtExprMutator {
...
@@ -130,21 +136,28 @@ class PipelinePlanner : public StmtExprMutator {
}
}
// analysis use-def chain
// analysis use-def chain
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
for
(
int
i
=
pinfo
.
original_order
+
1
;
i
<
static_cast
<
int
>
(
pipeline_body_seq
->
size
());
i
++
)
{
for
(
int
i
=
pinfo
.
original_order
+
1
;
if
(
!
pinfo
.
copy_stage
)
continue
;
i
<
static_cast
<
int
>
(
pipeline_body_seq
->
size
());
i
++
)
{
for
(
const
BufferRegion
&
read
:
pipeline_stage_infos
[
i
].
reads
)
{
if
(
!
pinfo
.
copy_stage
)
if
(
std
::
find_if
(
pinfo
.
writes
.
begin
(),
pinfo
.
writes
.
end
(),
[
&
](
const
BufferRegion
&
r
)
{
continue
;
return
r
->
buffer
==
read
->
buffer
&&
MayConflict
(
r
->
region
,
read
->
region
);
for
(
const
BufferRegion
&
read
:
pipeline_stage_infos
[
i
].
reads
)
{
})
!=
pinfo
.
writes
.
end
())
{
if
(
std
::
find_if
(
pinfo
.
writes
.
begin
(),
pinfo
.
writes
.
end
(),
[
&
](
const
BufferRegion
&
r
)
{
return
r
->
buffer
==
read
->
buffer
&&
MayConflict
(
r
->
region
,
read
->
region
);
})
!=
pinfo
.
writes
.
end
())
{
pinfo
.
last_use_stage
=
std
::
max
(
pinfo
.
last_use_stage
,
i
);
pinfo
.
last_use_stage
=
std
::
max
(
pinfo
.
last_use_stage
,
i
);
}
}
}
}
for
(
const
BufferRegion
&
write
:
pipeline_stage_infos
[
i
].
writes
)
{
for
(
const
BufferRegion
&
write
:
pipeline_stage_infos
[
i
].
writes
)
{
if
(
std
::
find_if
(
pinfo
.
writes
.
begin
(),
pinfo
.
writes
.
end
(),
[
&
](
const
BufferRegion
&
r
)
{
if
(
std
::
find_if
(
pinfo
.
writes
.
begin
(),
pinfo
.
writes
.
end
(),
return
r
->
buffer
==
write
->
buffer
&&
MayConflict
(
r
->
region
,
write
->
region
);
[
&
](
const
BufferRegion
&
r
)
{
})
!=
pinfo
.
writes
.
end
())
{
return
r
->
buffer
==
write
->
buffer
&&
CHECK
(
false
)
<<
"Can't handle multiple write on overlap buffer region in the pipeline "
MayConflict
(
r
->
region
,
write
->
region
);
})
!=
pinfo
.
writes
.
end
())
{
CHECK
(
false
)
<<
"Can't handle multiple write on overlap buffer "
"region in the pipeline "
"planning pass: "
"planning pass: "
<<
pipeline_body_seq
->
seq
[
pinfo
.
original_order
];
<<
pipeline_body_seq
->
seq
[
pinfo
.
original_order
];
}
}
...
@@ -154,28 +167,32 @@ class PipelinePlanner : public StmtExprMutator {
...
@@ -154,28 +167,32 @@ class PipelinePlanner : public StmtExprMutator {
// Making stages and orders
// Making stages and orders
int
order_idx
=
0
;
int
order_idx
=
0
;
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
if
(
pinfo
.
copy_stage
&&
pinfo
.
last_use_stage
!=
-
1
)
continue
;
if
(
pinfo
.
copy_stage
&&
pinfo
.
last_use_stage
!=
-
1
)
continue
;
pinfo
.
order
=
order_idx
++
;
pinfo
.
order
=
order_idx
++
;
pinfo
.
stage
=
num_stages
;
pinfo
.
stage
=
num_stages
;
for
(
auto
&
pinfo_1
:
pipeline_stage_infos
)
{
for
(
auto
&
pinfo_1
:
pipeline_stage_infos
)
{
if
(
pinfo_1
.
copy_stage
&&
pinfo_1
.
last_use_stage
==
pinfo
.
original_order
)
{
if
(
pinfo_1
.
copy_stage
&&
pinfo_1
.
last_use_stage
==
pinfo
.
original_order
)
{
pinfo_1
.
order
=
order_idx
++
;
pinfo_1
.
order
=
order_idx
++
;
pinfo_1
.
stage
=
0
;
pinfo_1
.
stage
=
0
;
}
}
}
}
}
}
ICHECK
(
size_t
(
order_idx
)
==
pipeline_stage_infos
.
size
())
<<
ICHECK
(
size_t
(
order_idx
)
==
pipeline_stage_infos
.
size
())
"The number of stages should be equal to the number of pipeline stages. "
<<
<<
"The number of stages should be equal to the number of pipeline "
"Got "
<<
order_idx
<<
" stages and "
<<
pipeline_stage_infos
.
size
()
<<
" pipeline stages."
;
"stages. "
<<
"Got "
<<
order_idx
<<
" stages and "
<<
pipeline_stage_infos
.
size
()
<<
" pipeline stages."
;
// if all the copy is at the end of the order, we can move these copy to the
beginning of the
// if all the copy is at the end of the order, we can move these copy to the
// order and shrink the stage offset by 1.
//
beginning of the
order and shrink the stage offset by 1.
int
copy_stage_at_end
=
[
&
]()
{
int
copy_stage_at_end
=
[
&
]()
{
int
copy_stage_cnt
=
0
;
int
copy_stage_cnt
=
0
;
int
copy_order_min
=
pipeline_stage_infos
.
size
();
int
copy_order_min
=
pipeline_stage_infos
.
size
();
int
non_copy_order_max
=
0
;
int
non_copy_order_max
=
0
;
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
if
(
pinfo
.
copy_stage
)
{
if
(
pinfo
.
copy_stage
)
{
copy_stage_cnt
++
;
copy_stage_cnt
++
;
copy_order_min
=
std
::
min
(
copy_order_min
,
pinfo
.
order
);
copy_order_min
=
std
::
min
(
copy_order_min
,
pinfo
.
order
);
...
@@ -183,19 +200,22 @@ class PipelinePlanner : public StmtExprMutator {
...
@@ -183,19 +200,22 @@ class PipelinePlanner : public StmtExprMutator {
non_copy_order_max
=
std
::
max
(
non_copy_order_max
,
pinfo
.
order
);
non_copy_order_max
=
std
::
max
(
non_copy_order_max
,
pinfo
.
order
);
}
}
}
}
if
(
copy_order_min
>
non_copy_order_max
)
return
copy_stage_cnt
;
if
(
copy_order_min
>
non_copy_order_max
)
return
copy_stage_cnt
;
return
-
1
;
return
-
1
;
}();
}();
if
(
copy_stage_at_end
>
0
&&
num_stages
>=
2
)
{
if
(
copy_stage_at_end
>
0
&&
num_stages
>=
2
)
{
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
// move copy to the beginning
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
// move copy to the beginning
pinfo
.
order
=
(
pinfo
.
order
+
copy_stage_at_end
)
%
pipeline_stage_infos
.
size
();
pinfo
.
order
=
if
(
!
pinfo
.
copy_stage
)
pinfo
.
stage
--
;
(
pinfo
.
order
+
copy_stage_at_end
)
%
pipeline_stage_infos
.
size
();
if
(
!
pinfo
.
copy_stage
)
pinfo
.
stage
--
;
}
}
}
}
// Finally, make the pipeline annotation
// Finally, make the pipeline annotation
Map
<
String
,
ObjectRef
>
annotations
;
Map
<
String
,
ObjectRef
>
annotations
;
for
(
const
auto
&
[
key
,
value
]
:
loop
->
annotations
)
{
for
(
const
auto
&
[
key
,
value
]
:
loop
->
annotations
)
{
if
(
key
!=
"num_stages"
)
{
if
(
key
!=
"num_stages"
)
{
annotations
.
Set
(
key
,
value
);
annotations
.
Set
(
key
,
value
);
}
}
...
@@ -204,7 +224,7 @@ class PipelinePlanner : public StmtExprMutator {
...
@@ -204,7 +224,7 @@ class PipelinePlanner : public StmtExprMutator {
std
::
vector
<
Integer
>
orders
,
stages
;
std
::
vector
<
Integer
>
orders
,
stages
;
orders
.
reserve
(
pipeline_stage_infos
.
size
());
orders
.
reserve
(
pipeline_stage_infos
.
size
());
stages
.
reserve
(
pipeline_stage_infos
.
size
());
stages
.
reserve
(
pipeline_stage_infos
.
size
());
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
for
(
auto
&
pinfo
:
pipeline_stage_infos
)
{
orders
.
push_back
(
pinfo
.
order
);
orders
.
push_back
(
pinfo
.
order
);
stages
.
push_back
(
pinfo
.
stage
);
stages
.
push_back
(
pinfo
.
stage
);
}
}
...
@@ -212,18 +232,19 @@ class PipelinePlanner : public StmtExprMutator {
...
@@ -212,18 +232,19 @@ class PipelinePlanner : public StmtExprMutator {
annotations
.
Set
(
tir
::
attr
::
software_pipeline_stage
,
Array
<
Integer
>
(
stages
));
annotations
.
Set
(
tir
::
attr
::
software_pipeline_stage
,
Array
<
Integer
>
(
stages
));
annotations
.
Set
(
tir
::
attr
::
software_pipeline_order
,
Array
<
Integer
>
(
orders
));
annotations
.
Set
(
tir
::
attr
::
software_pipeline_order
,
Array
<
Integer
>
(
orders
));
if
(
TargetHasAsyncCopy
(
target_
))
if
(
TargetHasAsyncCopy
(
target_
))
annotations
.
Set
(
tir
::
attr
::
software_pipeline_async_stages
,
Array
<
Integer
>
{
0
});
annotations
.
Set
(
tir
::
attr
::
software_pipeline_async_stages
,
Array
<
Integer
>
{
0
});
return
For
(
loop
->
loop_var
,
loop
->
min
,
loop
->
extent
,
loop
->
kind
,
loop
->
body
,
return
For
(
loop
->
loop_var
,
loop
->
min
,
loop
->
extent
,
loop
->
kind
,
loop
->
body
,
loop
->
thread_binding
,
annotations
);
loop
->
thread_binding
,
annotations
);
}
}
Stmt
VisitStmt_
(
const
BlockNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
BlockNode
*
op
)
final
{
for
(
const
auto
&
buffer
:
op
->
alloc_buffers
)
{
for
(
const
auto
&
buffer
:
op
->
alloc_buffers
)
{
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
}
}
Block
block
=
Downcast
<
Block
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
Block
block
=
Downcast
<
Block
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
for
(
const
auto
&
buffer
:
op
->
alloc_buffers
)
{
for
(
const
auto
&
buffer
:
op
->
alloc_buffers
)
{
buffer_data_to_buffer_
.
erase
(
buffer
->
data
);
buffer_data_to_buffer_
.
erase
(
buffer
->
data
);
}
}
return
std
::
move
(
block
);
return
std
::
move
(
block
);
...
@@ -236,14 +257,15 @@ class PipelinePlanner : public StmtExprMutator {
...
@@ -236,14 +257,15 @@ class PipelinePlanner : public StmtExprMutator {
tvm
::
transform
::
Pass
PipelinePlanning
()
{
tvm
::
transform
::
Pass
PipelinePlanning
()
{
using
namespace
tir
::
transform
;
using
namespace
tir
::
transform
;
auto
pass_func
=
[
=
](
PrimFunc
f
,
IRModule
m
,
PassContext
ctx
)
{
auto
pass_func
=
[
=
](
PrimFunc
f
,
IRModule
m
,
PassContext
ctx
)
{
PrimFuncNode
*
fptr
=
f
.
CopyOnWrite
();
PrimFuncNode
*
fptr
=
f
.
CopyOnWrite
();
fptr
->
body
=
PipelinePlanner
::
Substitute
(
f
);
fptr
->
body
=
PipelinePlanner
::
Substitute
(
f
);
return
f
;
return
f
;
};
};
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.PipelinePlanning"
,
{});
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.PipelinePlanning"
,
{});
}
}
TVM_REGISTER_GLOBAL
(
"tl.transform.PipelinePlanning"
).
set_body_typed
(
PipelinePlanning
);
TVM_REGISTER_GLOBAL
(
"tl.transform.PipelinePlanning"
)
.
set_body_typed
(
PipelinePlanning
);
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/transform/simplify.cc
View file @
549416f7
...
@@ -6,15 +6,15 @@
...
@@ -6,15 +6,15 @@
* \brief Remove useless parameters of TL PrimFunc.
* \brief Remove useless parameters of TL PrimFunc.
*/
*/
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/utils.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "tir/analysis/var_use_def_analysis.h"
#include "tir/analysis/control_flow_graph.h"
#include "tir/analysis/control_flow_graph.h"
#include "tir/analysis/var_use_def_analysis.h"
namespace
tvm
{
namespace
tvm
{
namespace
tl
{
namespace
tl
{
...
@@ -31,19 +31,19 @@ struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
...
@@ -31,19 +31,19 @@ struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
TVM_DECLARE_ATTRS
(
SimplifyConfigNode
,
"tl.transform.SimplifyConfig"
)
{
TVM_DECLARE_ATTRS
(
SimplifyConfigNode
,
"tl.transform.SimplifyConfig"
)
{
TVM_ATTR_FIELD
(
transitively_prove_inequalities
)
TVM_ATTR_FIELD
(
transitively_prove_inequalities
)
.
describe
(
.
describe
(
"If true, simplify conditionals with transitive combinations "
"If true, simplify conditionals with transitive combinations
of scoped constraints"
)
"
of scoped constraints"
)
.
set_default
(
false
);
.
set_default
(
false
);
TVM_ATTR_FIELD
(
propagate_knowns_to_prove_conditional
)
TVM_ATTR_FIELD
(
propagate_knowns_to_prove_conditional
)
.
describe
(
.
describe
(
"If true, known buffer values are propagated and used to "
"If true, known buffer values are propagated and used to
statically prove conditionals"
)
"
statically prove conditionals"
)
.
set_default
(
false
);
.
set_default
(
false
);
TVM_ATTR_FIELD
(
propagate_knowns_to_simplify_expressions
)
TVM_ATTR_FIELD
(
propagate_knowns_to_simplify_expressions
)
.
describe
(
.
describe
(
"If true, known buffer values are propagated and used to "
"If true, known buffer values are propagated and used to
replace BufferLoad wherever "
"
replace BufferLoad wherever "
"possible"
)
"possible"
)
.
set_default
(
false
);
.
set_default
(
false
);
TVM_ATTR_FIELD
(
convert_boolean_to_and_of_ors
)
TVM_ATTR_FIELD
(
convert_boolean_to_and_of_ors
)
...
@@ -51,102 +51,103 @@ struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
...
@@ -51,102 +51,103 @@ struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
.
set_default
(
false
);
.
set_default
(
false
);
TVM_ATTR_FIELD
(
apply_constraints_to_boolean_branches
)
TVM_ATTR_FIELD
(
apply_constraints_to_boolean_branches
)
.
describe
(
.
describe
(
"If true, simplify each branch of AND/OR "
"If true, simplify each branch of AND/OR "
"under a constraints provided by the other branch"
)
"under a constraints provided by the other branch"
)
.
set_default
(
false
);
.
set_default
(
false
);
}
}
RewriteSimplifier
::
Extension
GetEnabledExtensions
()
const
{
RewriteSimplifier
::
Extension
GetEnabledExtensions
()
const
{
RewriteSimplifier
::
Extension
flags
=
RewriteSimplifier
::
kNone
;
RewriteSimplifier
::
Extension
flags
=
RewriteSimplifier
::
kNone
;
if
(
transitively_prove_inequalities
)
{
if
(
transitively_prove_inequalities
)
{
flags
=
flags
=
RewriteSimplifier
::
Extension
(
RewriteSimplifier
::
Extension
(
flags
|
RewriteSimplifier
::
kTransitivelyProveInequalities
);
flags
|
RewriteSimplifier
::
kTransitivelyProveInequalities
);
}
}
if
(
convert_boolean_to_and_of_ors
)
{
if
(
convert_boolean_to_and_of_ors
)
{
flags
=
RewriteSimplifier
::
Extension
(
flags
|
RewriteSimplifier
::
kConvertBooleanToAndOfOrs
);
flags
=
RewriteSimplifier
::
Extension
(
flags
|
RewriteSimplifier
::
kConvertBooleanToAndOfOrs
);
}
}
if
(
apply_constraints_to_boolean_branches
)
{
if
(
apply_constraints_to_boolean_branches
)
{
flags
=
RewriteSimplifier
::
Extension
(
flags
|
flags
=
RewriteSimplifier
::
Extension
(
RewriteSimplifier
::
kApplyConstraintsToBooleanBranches
);
flags
|
RewriteSimplifier
::
kApplyConstraintsToBooleanBranches
);
}
}
return
flags
;
return
flags
;
}
}
};
};
std
::
unordered_set
<
const
BufferNode
*>
CollectUsedBuffers
(
const
PrimFunc
&
func
)
{
std
::
unordered_set
<
const
BufferNode
*>
CollectUsedBuffers
(
const
PrimFunc
&
func
)
{
struct
Visitor
:
StmtExprVisitor
{
struct
Visitor
:
StmtExprVisitor
{
using
StmtExprVisitor
::
VisitExpr_
;
using
StmtExprVisitor
::
VisitExpr_
;
using
StmtExprVisitor
::
VisitStmt_
;
using
StmtExprVisitor
::
VisitStmt_
;
Visitor
(
PrimFunc
func
)
:
func
(
func
)
{}
Visitor
(
PrimFunc
func
)
:
func
(
func
)
{}
void
VisitExpr_
(
const
CallNode
*
op
)
override
{
void
VisitExpr_
(
const
CallNode
*
op
)
override
{
for
(
const
auto
&
arg
:
op
->
args
)
{
for
(
const
auto
&
arg
:
op
->
args
)
{
for
(
const
auto
&
it
:
func
->
buffer_map
)
{
for
(
const
auto
&
it
:
func
->
buffer_map
)
{
if
(
Downcast
<
PrimExpr
>
(
it
.
second
.
get
()
->
data
).
same_as
(
arg
))
{
if
(
Downcast
<
PrimExpr
>
(
it
.
second
.
get
()
->
data
).
same_as
(
arg
))
{
used_in_buffer_def_
.
insert
(
it
.
second
.
get
());
used_in_buffer_def_
.
insert
(
it
.
second
.
get
());
}
}
}
}
}
StmtExprVisitor
::
VisitExpr_
(
op
);
}
StmtExprVisitor
::
VisitExpr_
(
op
);
}
}
void
VisitExpr_
(
const
BufferLoadNode
*
op
)
override
{
void
VisitExpr_
(
const
BufferLoadNode
*
op
)
override
{
VisitBuffer
(
op
->
buffer
);
VisitBuffer
(
op
->
buffer
);
StmtExprVisitor
::
VisitExpr_
(
op
);
StmtExprVisitor
::
VisitExpr_
(
op
);
}
}
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
override
{
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
override
{
VisitBuffer
(
op
->
buffer
);
VisitBuffer
(
op
->
buffer
);
StmtExprVisitor
::
VisitStmt_
(
op
);
StmtExprVisitor
::
VisitStmt_
(
op
);
}
}
void
VisitStmt_
(
const
BlockNode
*
op
)
override
{
void
VisitStmt_
(
const
BlockNode
*
op
)
override
{
for
(
const
auto
&
buffer
:
op
->
alloc_buffers
)
{
for
(
const
auto
&
buffer
:
op
->
alloc_buffers
)
{
for
(
const
auto
&
it
:
func
->
buffer_map
)
{
for
(
const
auto
&
it
:
func
->
buffer_map
)
{
if
(
it
.
second
.
get
()
->
data
.
same_as
(
buffer
.
get
()
->
data
))
{
if
(
it
.
second
.
get
()
->
data
.
same_as
(
buffer
.
get
()
->
data
))
{
used_in_buffer_def_
.
insert
(
it
.
second
.
get
());
used_in_buffer_def_
.
insert
(
it
.
second
.
get
());
}
}
}
}
}
for
(
const
auto
&
buffer
:
op
->
reads
)
{
}
for
(
const
auto
&
it
:
func
->
buffer_map
)
{
for
(
const
auto
&
buffer
:
op
->
reads
)
{
if
(
it
.
second
.
get
()
->
data
.
same_as
(
buffer
->
buffer
.
get
()
->
data
)
)
{
for
(
const
auto
&
it
:
func
->
buffer_map
)
{
used_in_buffer_def_
.
insert
(
it
.
second
.
get
());
if
(
it
.
second
.
get
()
->
data
.
same_as
(
buffer
->
buffer
.
get
()
->
data
))
{
}
used_in_buffer_def_
.
insert
(
it
.
second
.
get
());
}
}
}
}
for
(
const
auto
&
buffer
:
op
->
writes
)
{
}
for
(
const
auto
&
it
:
func
->
buffer_map
)
{
for
(
const
auto
&
buffer
:
op
->
writes
)
{
if
(
it
.
second
.
get
()
->
data
.
same_as
(
buffer
->
buffer
.
get
()
->
data
)
)
{
for
(
const
auto
&
it
:
func
->
buffer_map
)
{
used_in_buffer_def_
.
insert
(
it
.
second
.
get
());
if
(
it
.
second
.
get
()
->
data
.
same_as
(
buffer
->
buffer
.
get
()
->
data
))
{
}
used_in_buffer_def_
.
insert
(
it
.
second
.
get
());
}
}
}
}
StmtExprVisitor
::
VisitStmt_
(
op
);
}
StmtExprVisitor
::
VisitStmt_
(
op
);
}
}
void
VisitBuffer
(
const
Buffer
&
buf
)
{
void
VisitBuffer
(
const
Buffer
&
buf
)
{
// Collect buffers that should remain defined
// Collect buffers that should remain defined
VarUseDefAnalyzer
usage
(
Array
<
Var
>
{});
VarUseDefAnalyzer
usage
(
Array
<
Var
>
{});
usage
(
buf
->
data
);
usage
(
buf
->
data
);
for
(
const
auto
&
dim
:
buf
->
shape
)
{
for
(
const
auto
&
dim
:
buf
->
shape
)
{
usage
(
dim
);
usage
(
dim
);
}
}
for
(
const
auto
&
dim
:
buf
->
strides
)
{
for
(
const
auto
&
dim
:
buf
->
strides
)
{
usage
(
dim
);
usage
(
dim
);
}
}
usage
(
buf
->
elem_offset
);
usage
(
buf
->
elem_offset
);
for
(
const
auto
&
buffer
:
usage
.
buffer_use_count_
)
{
for
(
const
auto
&
buffer
:
usage
.
buffer_use_count_
)
{
if
(
buffer
.
second
>=
1
)
{
if
(
buffer
.
second
>=
1
)
{
used_in_buffer_def_
.
insert
(
buffer
.
first
);
used_in_buffer_def_
.
insert
(
buffer
.
first
);
}
}
}
}
for
(
const
auto
&
buffer
:
usage
.
undefined_buffers_
)
{
for
(
const
auto
&
buffer
:
usage
.
undefined_buffers_
)
{
used_in_buffer_def_
.
insert
(
buffer
.
get
());
used_in_buffer_def_
.
insert
(
buffer
.
get
());
}
}
}
}
PrimFunc
func
;
PrimFunc
func
;
std
::
unordered_set
<
const
BufferNode
*>
used_in_buffer_def_
;
std
::
unordered_set
<
const
BufferNode
*>
used_in_buffer_def_
;
};
};
Visitor
visitor
(
func
);
Visitor
visitor
(
func
);
...
@@ -154,41 +155,42 @@ std::unordered_set<const BufferNode*> CollectUsedBuffers(const PrimFunc& func) {
...
@@ -154,41 +155,42 @@ std::unordered_set<const BufferNode*> CollectUsedBuffers(const PrimFunc& func) {
return
visitor
.
used_in_buffer_def_
;
return
visitor
.
used_in_buffer_def_
;
}
}
/* \brief Utility function to collect vars that should be retained. Used in
/* \brief Utility function to collect vars that should be retained. Used in Letstmt Only
* Letstmt Only
*/
*/
std
::
unordered_set
<
const
VarNode
*>
CollectVarsUsedInBufferDefinition
(
const
Stmt
&
stmt
)
{
std
::
unordered_set
<
const
VarNode
*>
CollectVarsUsedInBufferDefinition
(
const
Stmt
&
stmt
)
{
struct
Visitor
:
StmtExprVisitor
{
struct
Visitor
:
StmtExprVisitor
{
using
StmtExprVisitor
::
VisitExpr_
;
using
StmtExprVisitor
::
VisitExpr_
;
using
StmtExprVisitor
::
VisitStmt_
;
using
StmtExprVisitor
::
VisitStmt_
;
void
VisitExpr_
(
const
BufferLoadNode
*
op
)
override
{
void
VisitExpr_
(
const
BufferLoadNode
*
op
)
override
{
VisitBuffer
(
op
->
buffer
);
VisitBuffer
(
op
->
buffer
);
StmtExprVisitor
::
VisitExpr_
(
op
);
StmtExprVisitor
::
VisitExpr_
(
op
);
}
}
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
override
{
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
override
{
VisitBuffer
(
op
->
buffer
);
VisitBuffer
(
op
->
buffer
);
StmtExprVisitor
::
VisitStmt_
(
op
);
StmtExprVisitor
::
VisitStmt_
(
op
);
}
}
void
VisitBuffer
(
const
Buffer
&
buf
)
{
void
VisitBuffer
(
const
Buffer
&
buf
)
{
// Collect variables that should remain defined
// Collect variables that should remain defined
VarUseDefAnalyzer
usage
(
Array
<
Var
>
{});
VarUseDefAnalyzer
usage
(
Array
<
Var
>
{});
usage
(
buf
->
data
);
usage
(
buf
->
data
);
for
(
const
auto
&
dim
:
buf
->
shape
)
{
for
(
const
auto
&
dim
:
buf
->
shape
)
{
usage
(
dim
);
usage
(
dim
);
}
}
for
(
const
auto
&
dim
:
buf
->
strides
)
{
for
(
const
auto
&
dim
:
buf
->
strides
)
{
usage
(
dim
);
usage
(
dim
);
}
}
usage
(
buf
->
elem_offset
);
usage
(
buf
->
elem_offset
);
// Track for use in LetStmtNode mutator
// Track for use in LetStmtNode mutator
for
(
const
auto
&
var
:
usage
.
undefined_
)
{
for
(
const
auto
&
var
:
usage
.
undefined_
)
{
used_in_buffer_def_
.
insert
(
var
.
get
());
used_in_buffer_def_
.
insert
(
var
.
get
());
}
}
}
}
std
::
unordered_set
<
const
VarNode
*>
used_in_buffer_def_
;
std
::
unordered_set
<
const
VarNode
*>
used_in_buffer_def_
;
};
};
Visitor
visitor
;
Visitor
visitor
;
...
@@ -197,20 +199,21 @@ std::unordered_set<const VarNode*> CollectVarsUsedInBufferDefinition(const Stmt&
...
@@ -197,20 +199,21 @@ std::unordered_set<const VarNode*> CollectVarsUsedInBufferDefinition(const Stmt&
}
}
class
SimplifyConfig
:
public
Attrs
{
class
SimplifyConfig
:
public
Attrs
{
public:
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS
(
SimplifyConfig
,
Attrs
,
SimplifyConfigNode
);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS
(
SimplifyConfig
,
Attrs
,
SimplifyConfigNode
);
};
};
TVM_REGISTER_NODE_TYPE
(
SimplifyConfigNode
);
TVM_REGISTER_NODE_TYPE
(
SimplifyConfigNode
);
TVM_REGISTER_PASS_CONFIG_OPTION
(
"tl.Simplify"
,
SimplifyConfig
);
TVM_REGISTER_PASS_CONFIG_OPTION
(
"tl.Simplify"
,
SimplifyConfig
);
class
StmtSimplifier
:
public
IRMutatorWithAnalyzer
{
class
StmtSimplifier
:
public
IRMutatorWithAnalyzer
{
public:
public:
static
PrimFunc
Apply
(
PrimFunc
func
,
Analyzer
*
analyzer
,
static
PrimFunc
Apply
(
PrimFunc
func
,
Analyzer
*
analyzer
,
Optional
<
SimplifyConfig
>
config_opt
=
NullOpt
)
{
Optional
<
SimplifyConfig
>
config_opt
=
NullOpt
)
{
auto
config
=
config_opt
.
value_or
(
AttrsWithDefaultValues
<
SimplifyConfig
>
());
auto
config
=
config_opt
.
value_or
(
AttrsWithDefaultValues
<
SimplifyConfig
>
());
analyzer
->
rewrite_simplify
.
SetEnabledExtensions
(
config
->
GetEnabledExtensions
());
analyzer
->
rewrite_simplify
.
SetEnabledExtensions
(
config
->
GetEnabledExtensions
());
std
::
optional
<
ControlFlowGraph
>
touch_pattern
=
std
::
nullopt
;
std
::
optional
<
ControlFlowGraph
>
touch_pattern
=
std
::
nullopt
;
if
(
config
->
propagate_knowns_to_prove_conditional
||
if
(
config
->
propagate_knowns_to_prove_conditional
||
...
@@ -218,7 +221,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
...
@@ -218,7 +221,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
touch_pattern
=
ControlFlowGraph
(
func
->
body
);
touch_pattern
=
ControlFlowGraph
(
func
->
body
);
}
}
std
::
unordered_set
<
const
VarNode
*>
used_in_buffer_def
=
std
::
unordered_set
<
const
VarNode
*>
used_in_buffer_def
=
CollectVarsUsedInBufferDefinition
(
func
->
body
);
CollectVarsUsedInBufferDefinition
(
func
->
body
);
StmtSimplifier
simplifier
(
analyzer
,
config
,
std
::
move
(
touch_pattern
),
StmtSimplifier
simplifier
(
analyzer
,
config
,
std
::
move
(
touch_pattern
),
std
::
move
(
used_in_buffer_def
));
std
::
move
(
used_in_buffer_def
));
...
@@ -232,41 +235,44 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
...
@@ -232,41 +235,44 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
Array
<
Var
>
new_params
;
Array
<
Var
>
new_params
;
Map
<
Var
,
Buffer
>
new_buffer_map
;
Map
<
Var
,
Buffer
>
new_buffer_map
;
// Check whether each buffer is used
// Check whether each buffer is used
for
(
const
auto
&
var
:
func
->
params
)
{
for
(
const
auto
&
var
:
func
->
params
)
{
if
(
func
->
buffer_map
.
find
(
var
)
!=
func
->
buffer_map
.
end
())
{
if
(
func
->
buffer_map
.
find
(
var
)
!=
func
->
buffer_map
.
end
())
{
if
(
simplifier
.
used_buffers_
.
find
(
func
->
buffer_map
[
var
].
get
())
!=
simplifier
.
used_buffers_
.
end
())
{
if
(
simplifier
.
used_buffers_
.
find
(
func
->
buffer_map
[
var
].
get
())
!=
new_params
.
push_back
(
var
);
simplifier
.
used_buffers_
.
end
())
{
new_buffer_map
.
Set
(
var
,
func
->
buffer_map
[
var
]
);
new_params
.
push_back
(
var
);
}
else
{
new_buffer_map
.
Set
(
var
,
func
->
buffer_map
[
var
]);
param_updated
=
true
;
}
else
{
}
param_updated
=
true
;
}
}
}
}
}
// return func;
// return func;
if
(
param_updated
)
{
if
(
param_updated
)
{
return
PrimFunc
(
new_params
,
func
.
CopyOnWrite
()
->
body
,
func
->
ret_type
,
new_buffer_map
,
func
->
attrs
,
func
->
span
);
return
PrimFunc
(
new_params
,
func
.
CopyOnWrite
()
->
body
,
func
->
ret_type
,
new_buffer_map
,
func
->
attrs
,
func
->
span
);
}
else
{
}
else
{
return
func
;
return
func
;
}
}
}
}
private:
private:
explicit
StmtSimplifier
(
Analyzer
*
analyzer
,
SimplifyConfig
config
,
explicit
StmtSimplifier
(
std
::
optional
<
ControlFlowGraph
>
touch_pattern
,
Analyzer
*
analyzer
,
SimplifyConfig
config
,
std
::
unordered_set
<
const
VarNode
*>
used_in_buffer_def
)
std
::
optional
<
ControlFlowGraph
>
touch_pattern
,
:
IRMutatorWithAnalyzer
(
analyzer
),
std
::
unordered_set
<
const
VarNode
*>
used_in_buffer_def
)
config_
(
config
),
:
IRMutatorWithAnalyzer
(
analyzer
),
config_
(
config
),
touch_pattern_
(
touch_pattern
),
touch_pattern_
(
touch_pattern
),
used_in_buffer_def_
(
used_in_buffer_def
)
{
used_in_buffer_def_
(
used_in_buffer_def
)
{
}
}
using
Parent
=
IRMutatorWithAnalyzer
;
using
Parent
=
IRMutatorWithAnalyzer
;
using
Parent
::
VisitExpr_
;
using
Parent
::
VisitExpr_
;
using
Parent
::
VisitStmt
;
using
Parent
::
VisitStmt
;
using
Parent
::
VisitStmt_
;
using
Parent
::
VisitStmt_
;
PrimExpr
VisitExpr
(
const
PrimExpr
&
expr
)
final
{
PrimExpr
VisitExpr
(
const
PrimExpr
&
expr
)
final
{
if
(
config_
->
propagate_knowns_to_simplify_expressions
)
{
if
(
config_
->
propagate_knowns_to_simplify_expressions
)
{
return
touch_pattern_
->
SimplifyInContext
(
expr
,
current_stmt_
.
value
(),
analyzer_
);
return
touch_pattern_
->
SimplifyInContext
(
expr
,
current_stmt_
.
value
(),
analyzer_
);
}
else
{
}
else
{
return
analyzer_
->
Simplify
(
expr
);
return
analyzer_
->
Simplify
(
expr
);
}
}
...
@@ -274,7 +280,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
...
@@ -274,7 +280,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
Stmt
Simplify
(
Stmt
stmt
)
{
return
operator
()(
std
::
move
(
stmt
));
}
Stmt
Simplify
(
Stmt
stmt
)
{
return
operator
()(
std
::
move
(
stmt
));
}
Stmt
VisitStmt
(
const
Stmt
&
stmt
)
override
{
Stmt
VisitStmt
(
const
Stmt
&
stmt
)
override
{
Optional
<
Stmt
>
cache
=
this
->
current_stmt_
;
Optional
<
Stmt
>
cache
=
this
->
current_stmt_
;
this
->
current_stmt_
=
stmt
;
this
->
current_stmt_
=
stmt
;
Stmt
output
=
Parent
::
VisitStmt
(
stmt
);
Stmt
output
=
Parent
::
VisitStmt
(
stmt
);
...
@@ -282,23 +288,28 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
...
@@ -282,23 +288,28 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return
output
;
return
output
;
}
}
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
analyzer_
->
Bind
(
op
->
loop_var
,
Range
::
FromMinExtent
(
op
->
min
,
op
->
extent
));
analyzer_
->
Bind
(
op
->
loop_var
,
Range
::
FromMinExtent
(
op
->
min
,
op
->
extent
));
With
<
ConstraintContext
>
ctx1
(
analyzer_
,
op
->
loop_var
>=
op
->
min
);
With
<
ConstraintContext
>
ctx1
(
analyzer_
,
op
->
loop_var
>=
op
->
min
);
With
<
ConstraintContext
>
ctx2
(
analyzer_
,
op
->
loop_var
<
op
->
min
+
op
->
extent
);
With
<
ConstraintContext
>
ctx2
(
analyzer_
,
op
->
loop_var
<
op
->
min
+
op
->
extent
);
return
Parent
::
VisitStmt_
(
op
);
return
Parent
::
VisitStmt_
(
op
);
}
}
bool
CanInlineLetStmt
(
const
LetStmtNode
*
op
)
{
bool
CanInlineLetStmt
(
const
LetStmtNode
*
op
)
{
if
(
is_const_number
(
op
->
value
))
return
true
;
if
(
is_const_number
(
op
->
value
))
if
(
op
->
value
.
as
<
VarNode
>
())
return
true
;
return
true
;
if
(
op
->
value
.
as
<
VarNode
>
())
return
true
;
// Won't face the deep expression explosion problem as in Let expression.
// Won't face the deep expression explosion problem as in Let expression.
// attempt to inline as much as possible if the value integer type(can be index).
// attempt to inline as much as possible if the value integer type(can be
if
(
!
op
->
value
.
dtype
().
is_int
())
return
false
;
// index).
if
(
!
op
->
value
.
dtype
().
is_int
())
return
false
;
return
SideEffect
(
op
->
value
)
<=
CallEffectKind
::
kPure
;
return
SideEffect
(
op
->
value
)
<=
CallEffectKind
::
kPure
;
}
}
Stmt
VisitStmt_
(
const
LetStmtNode
*
op
)
override
{
Stmt
VisitStmt_
(
const
LetStmtNode
*
op
)
override
{
PrimExpr
value
=
this
->
VisitExpr
(
op
->
value
);
PrimExpr
value
=
this
->
VisitExpr
(
op
->
value
);
bool
can_inline
=
CanInlineLetStmt
(
op
);
bool
can_inline
=
CanInlineLetStmt
(
op
);
if
(
can_inline
)
{
if
(
can_inline
)
{
...
@@ -339,7 +350,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
...
@@ -339,7 +350,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
}
}
}
}
Stmt
VisitStmt_
(
const
IfThenElseNode
*
op
)
override
{
Stmt
VisitStmt_
(
const
IfThenElseNode
*
op
)
override
{
if
(
Optional
<
Bool
>
cond
=
ProveCondition
(
op
->
condition
))
{
if
(
Optional
<
Bool
>
cond
=
ProveCondition
(
op
->
condition
))
{
if
(
cond
.
value
()
->
value
)
{
if
(
cond
.
value
()
->
value
)
{
return
this
->
VisitStmt
(
op
->
then_case
);
return
this
->
VisitStmt
(
op
->
then_case
);
...
@@ -353,7 +364,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
...
@@ -353,7 +364,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
}
}
}
}
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
override
{
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
override
{
if
(
op
->
op
.
same_as
(
builtin
::
if_then_else
()))
{
if
(
op
->
op
.
same_as
(
builtin
::
if_then_else
()))
{
if
(
Optional
<
Bool
>
cond
=
ProveCondition
(
op
->
args
[
0
]))
{
if
(
Optional
<
Bool
>
cond
=
ProveCondition
(
op
->
args
[
0
]))
{
if
(
cond
.
value
()
->
value
)
{
if
(
cond
.
value
()
->
value
)
{
...
@@ -366,26 +377,27 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
...
@@ -366,26 +377,27 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return
Parent
::
VisitExpr_
(
op
);
return
Parent
::
VisitExpr_
(
op
);
}
}
PrimExpr
VisitExpr_
(
const
VarNode
*
op
)
override
{
PrimExpr
VisitExpr_
(
const
VarNode
*
op
)
override
{
used_vars_
.
insert
(
op
);
used_vars_
.
insert
(
op
);
return
Parent
::
VisitExpr_
(
op
);
return
Parent
::
VisitExpr_
(
op
);
}
}
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
override
{
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
override
{
auto
buffer
=
op
->
buffer
.
get
();
auto
buffer
=
op
->
buffer
.
get
();
if
(
used_buffers_
.
find
(
buffer
)
==
used_buffers_
.
end
())
{
if
(
used_buffers_
.
find
(
buffer
)
==
used_buffers_
.
end
())
{
used_buffers_
.
insert
(
buffer
);
used_buffers_
.
insert
(
buffer
);
}
}
return
Parent
::
VisitExpr_
(
op
);
return
Parent
::
VisitExpr_
(
op
);
}
}
// eliminate useless stores
// eliminate useless stores
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
override
{
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
override
{
BufferStore
store
=
Downcast
<
BufferStore
>
(
Parent
::
VisitStmt_
(
op
));
BufferStore
store
=
Downcast
<
BufferStore
>
(
Parent
::
VisitStmt_
(
op
));
if
(
const
BufferLoadNode
*
load
=
store
->
value
.
as
<
BufferLoadNode
>
())
{
if
(
const
BufferLoadNode
*
load
=
store
->
value
.
as
<
BufferLoadNode
>
())
{
if
(
load
->
buffer
->
data
.
same_as
(
store
->
buffer
->
data
)
&&
if
(
load
->
buffer
->
data
.
same_as
(
store
->
buffer
->
data
)
&&
ArrayDeepEqual
(
load
->
indices
,
store
->
indices
)
&&
ArrayDeepEqual
(
load
->
indices
,
store
->
indices
)
&&
tir
::
ExprDeepEqual
()(
load
->
buffer
->
elem_offset
,
store
->
buffer
->
elem_offset
)
&&
tir
::
ExprDeepEqual
()(
load
->
buffer
->
elem_offset
,
store
->
buffer
->
elem_offset
)
&&
ArrayDeepEqual
(
load
->
buffer
->
shape
,
store
->
buffer
->
shape
)
&&
ArrayDeepEqual
(
load
->
buffer
->
shape
,
store
->
buffer
->
shape
)
&&
ArrayDeepEqual
(
load
->
buffer
->
strides
,
store
->
buffer
->
strides
))
{
ArrayDeepEqual
(
load
->
buffer
->
strides
,
store
->
buffer
->
strides
))
{
return
Evaluate
(
0
);
return
Evaluate
(
0
);
...
@@ -393,13 +405,13 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
...
@@ -393,13 +405,13 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
}
}
auto
buffer
=
op
->
buffer
.
get
();
auto
buffer
=
op
->
buffer
.
get
();
if
(
used_buffers_
.
find
(
buffer
)
==
used_buffers_
.
end
())
{
if
(
used_buffers_
.
find
(
buffer
)
==
used_buffers_
.
end
())
{
used_buffers_
.
insert
(
buffer
);
used_buffers_
.
insert
(
buffer
);
}
}
return
std
::
move
(
store
);
return
std
::
move
(
store
);
}
}
private:
private:
bool
ArrayDeepEqual
(
const
Array
<
PrimExpr
>
&
lhs
,
const
Array
<
PrimExpr
>
&
rhs
)
{
bool
ArrayDeepEqual
(
const
Array
<
PrimExpr
>
&
lhs
,
const
Array
<
PrimExpr
>
&
rhs
)
{
if
(
lhs
.
size
()
!=
rhs
.
size
())
{
if
(
lhs
.
size
()
!=
rhs
.
size
())
{
return
false
;
return
false
;
}
}
...
@@ -420,11 +432,12 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
...
@@ -420,11 +432,12 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
condition
=
Substitute
(
condition
,
non_inlined_bindings_
);
condition
=
Substitute
(
condition
,
non_inlined_bindings_
);
if
(
config_
->
propagate_knowns_to_prove_conditional
)
{
if
(
config_
->
propagate_knowns_to_prove_conditional
)
{
ICHECK
(
touch_pattern_
.
has_value
());
ICHECK
(
touch_pattern_
.
has_value
());
condition
=
touch_pattern_
->
SimplifyInContext
(
condition
,
current_stmt_
.
value
(),
analyzer_
);
condition
=
touch_pattern_
->
SimplifyInContext
(
condition
,
current_stmt_
.
value
(),
analyzer_
);
}
else
{
}
else
{
condition
=
analyzer_
->
Simplify
(
condition
);
condition
=
analyzer_
->
Simplify
(
condition
);
}
}
if
(
const
int64_t
*
as_int
=
as_const_int
(
condition
))
{
if
(
const
int64_t
*
as_int
=
as_const_int
(
condition
))
{
return
Bool
(
*
as_int
);
return
Bool
(
*
as_int
);
}
else
{
}
else
{
return
NullOpt
;
return
NullOpt
;
...
@@ -436,21 +449,20 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
...
@@ -436,21 +449,20 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
Map
<
Var
,
PrimExpr
>
non_inlined_bindings_
;
Map
<
Var
,
PrimExpr
>
non_inlined_bindings_
;
Optional
<
Stmt
>
current_stmt_
{
NullOpt
};
Optional
<
Stmt
>
current_stmt_
{
NullOpt
};
std
::
unordered_set
<
const
VarNode
*>
used_in_buffer_def_
;
std
::
unordered_set
<
const
VarNode
*>
used_in_buffer_def_
;
std
::
unordered_set
<
const
VarNode
*>
used_vars_
;
std
::
unordered_set
<
const
VarNode
*>
used_vars_
;
std
::
unordered_set
<
const
BufferNode
*>
used_buffers_
;
std
::
unordered_set
<
const
BufferNode
*>
used_buffers_
;
};
};
using
namespace
tir
::
transform
;
using
namespace
tir
::
transform
;
tvm
::
transform
::
Pass
Simplify
()
{
tvm
::
transform
::
Pass
Simplify
()
{
auto
pass_func
=
[
=
](
PrimFunc
f
,
IRModule
m
,
PassContext
ctx
)
{
auto
pass_func
=
[
=
](
PrimFunc
f
,
IRModule
m
,
PassContext
ctx
)
{
arith
::
Analyzer
analyzer
;
arith
::
Analyzer
analyzer
;
auto
cfg
=
ctx
->
GetConfig
<
SimplifyConfig
>
(
"tl.Simplify"
);
auto
cfg
=
ctx
->
GetConfig
<
SimplifyConfig
>
(
"tl.Simplify"
);
return
StmtSimplifier
::
Apply
(
f
,
&
analyzer
,
cfg
);
return
StmtSimplifier
::
Apply
(
f
,
&
analyzer
,
cfg
);
};
};
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.Simplify"
,
{});
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.Simplify"
,
{});
}
}
TVM_REGISTER_GLOBAL
(
"tl.transform.Simplify"
).
set_body_typed
(
Simplify
);
TVM_REGISTER_GLOBAL
(
"tl.transform.Simplify"
).
set_body_typed
(
Simplify
);
...
...
src/transform/thread_partial_sync.cc
View file @
549416f7
...
@@ -25,26 +25,28 @@ namespace tl {
...
@@ -25,26 +25,28 @@ namespace tl {
using
namespace
tir
;
using
namespace
tir
;
class
ThreadPartialSyncPlanner
:
public
StorageAccessVisitor
{
class
ThreadPartialSyncPlanner
:
public
StorageAccessVisitor
{
public:
public:
explicit
ThreadPartialSyncPlanner
(
StorageScope
sync_scope
)
:
sync_scope_
(
sync_scope
)
{}
explicit
ThreadPartialSyncPlanner
(
StorageScope
sync_scope
)
:
sync_scope_
(
sync_scope
)
{}
// The syncs inserted before each statement
// The syncs inserted before each statement
std
::
unordered_set
<
const
Object
*>
syncs_inserted_
;
std
::
unordered_set
<
const
Object
*>
syncs_inserted_
;
std
::
unordered_map
<
const
Object
*
,
int
>
partial_syncs_inserted_
;
std
::
unordered_map
<
const
Object
*
,
int
>
partial_syncs_inserted_
;
protected:
protected:
bool
Enabled
(
const
VarNode
*
buf
,
const
StorageScope
&
scope
)
const
final
{
bool
Enabled
(
const
VarNode
*
buf
,
const
StorageScope
&
scope
)
const
final
{
return
in_device_env
()
&&
scope
==
sync_scope_
;
return
in_device_env
()
&&
scope
==
sync_scope_
;
}
}
// Plan the sync
// Plan the sync
std
::
vector
<
AccessEntry
>
Summarize
(
std
::
vector
<
StmtEntry
>
seq
,
const
ForNode
*
loop
)
final
{
std
::
vector
<
AccessEntry
>
Summarize
(
std
::
vector
<
StmtEntry
>
seq
,
const
ForNode
*
loop
)
final
{
// Redirect all "shared.dyn" buffer access to the same buffer var
// Redirect all "shared.dyn" buffer access to the same buffer var
// so that the accesses can be planned together.
// so that the accesses can be planned together.
Var
shared_dyn_buf
;
Var
shared_dyn_buf
;
for
(
StmtEntry
&
entry
:
seq
)
{
for
(
StmtEntry
&
entry
:
seq
)
{
for
(
AccessEntry
&
access
:
entry
.
access
)
{
for
(
AccessEntry
&
access
:
entry
.
access
)
{
if
(
access
.
scope
.
rank
==
StorageRank
::
kShared
&&
access
.
scope
.
tag
==
".dyn"
&&
if
(
access
.
scope
.
rank
==
StorageRank
::
kShared
&&
access
.
buffer
.
defined
())
{
access
.
scope
.
tag
==
".dyn"
&&
access
.
buffer
.
defined
())
{
if
(
!
shared_dyn_buf
.
defined
())
{
if
(
!
shared_dyn_buf
.
defined
())
{
shared_dyn_buf
=
access
.
buffer
;
shared_dyn_buf
=
access
.
buffer
;
}
else
{
}
else
{
...
@@ -60,7 +62,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
...
@@ -60,7 +62,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
// if it is a loop, rotate two times to consider effect of loop.
// if it is a loop, rotate two times to consider effect of loop.
// simulation based approach to find dependencies
// simulation based approach to find dependencies
for
(
size_t
i
=
0
;
i
<
seq
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
seq
.
size
();
++
i
)
{
const
StmtEntry
&
s
=
seq
[
i
];
const
StmtEntry
&
s
=
seq
[
i
];
// check if sync before statement is needed.
// check if sync before statement is needed.
bool
sync_before_stmt
=
(
syncs_inserted_
.
count
(
s
.
stmt
)
!=
0
);
bool
sync_before_stmt
=
(
syncs_inserted_
.
count
(
s
.
stmt
)
!=
0
);
// Apply the syncs added already.
// Apply the syncs added already.
...
@@ -68,7 +70,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
...
@@ -68,7 +70,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
reads
.
clear
();
reads
.
clear
();
writes
.
clear
();
writes
.
clear
();
}
}
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
if
(
acc
.
type
==
kRead
)
{
if
(
acc
.
type
==
kRead
)
{
if
(
FindConflict
(
writes
,
acc
,
false
))
{
if
(
FindConflict
(
writes
,
acc
,
false
))
{
sync_before_stmt
=
true
;
sync_before_stmt
=
true
;
...
@@ -90,7 +92,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
...
@@ -90,7 +92,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
writes
.
clear
();
writes
.
clear
();
}
}
// Add the read/write of current statement
// Add the read/write of current statement
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
if
(
acc
.
type
==
kRead
)
{
if
(
acc
.
type
==
kRead
)
{
reads
.
push_back
(
acc
);
reads
.
push_back
(
acc
);
}
else
if
(
acc
.
type
==
kWrite
)
{
}
else
if
(
acc
.
type
==
kWrite
)
{
...
@@ -106,11 +108,13 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
...
@@ -106,11 +108,13 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
}
}
if
(
loop
!=
nullptr
)
{
if
(
loop
!=
nullptr
)
{
for
(
size_t
i
=
0
;
i
<
seq
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
seq
.
size
();
++
i
)
{
const
StmtEntry
&
s
=
seq
[
i
];
const
StmtEntry
&
s
=
seq
[
i
];
if
(
syncs_inserted_
.
count
(
s
.
stmt
)
!=
0
)
break
;
if
(
syncs_inserted_
.
count
(
s
.
stmt
)
!=
0
)
if
(
reads
.
empty
()
&&
writes
.
empty
())
break
;
break
;
if
(
reads
.
empty
()
&&
writes
.
empty
())
break
;
bool
sync_before_stmt
=
false
;
bool
sync_before_stmt
=
false
;
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
if
(
acc
.
type
==
kRead
)
{
if
(
acc
.
type
==
kRead
)
{
if
(
FindConflict
(
writes
,
acc
,
true
))
{
if
(
FindConflict
(
writes
,
acc
,
true
))
{
sync_before_stmt
=
true
;
sync_before_stmt
=
true
;
...
@@ -141,7 +145,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
...
@@ -141,7 +145,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
esync
.
type
=
kSync
;
esync
.
type
=
kSync
;
esync
.
scope
=
sync_scope_
;
esync
.
scope
=
sync_scope_
;
for
(
const
StmtEntry
&
s
:
seq
)
{
for
(
const
StmtEntry
&
s
:
seq
)
{
if
(
syncs_inserted_
.
count
(
s
.
stmt
))
{
if
(
syncs_inserted_
.
count
(
s
.
stmt
))
{
if
(
sync_count
!=
0
)
{
if
(
sync_count
!=
0
)
{
tail
.
clear
();
tail
.
clear
();
...
@@ -150,7 +154,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
...
@@ -150,7 +154,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
}
}
++
sync_count
;
++
sync_count
;
}
}
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
for
(
const
AccessEntry
&
acc
:
s
.
access
)
{
if
(
acc
.
type
==
kSync
)
{
if
(
acc
.
type
==
kSync
)
{
if
(
sync_count
!=
0
)
{
if
(
sync_count
!=
0
)
{
tail
.
clear
();
tail
.
clear
();
...
@@ -170,18 +174,18 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
...
@@ -170,18 +174,18 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
head
.
insert
(
head
.
end
(),
tail
.
begin
(),
tail
.
end
());
head
.
insert
(
head
.
end
(),
tail
.
begin
(),
tail
.
end
());
if
(
loop
!=
nullptr
)
{
if
(
loop
!=
nullptr
)
{
// clear double buffer flag after a loop is finished.
// clear double buffer flag after a loop is finished.
for
(
AccessEntry
&
e
:
head
)
{
for
(
AccessEntry
&
e
:
head
)
{
e
.
double_buffer_write
=
false
;
e
.
double_buffer_write
=
false
;
}
}
}
}
return
head
;
return
head
;
}
}
private:
private:
// find conflicting entry in vec.
// find conflicting entry in vec.
bool
FindConflict
(
const
std
::
vector
<
AccessEntry
>
&
prev
,
const
AccessEntry
&
curr
,
bool
FindConflict
(
const
std
::
vector
<
AccessEntry
>
&
prev
,
bool
loop_carry
)
{
const
AccessEntry
&
curr
,
bool
loop_carry
)
{
for
(
const
AccessEntry
&
x
:
prev
)
{
for
(
const
AccessEntry
&
x
:
prev
)
{
if
(
FindConflict
(
x
,
curr
,
loop_carry
))
{
if
(
FindConflict
(
x
,
curr
,
loop_carry
))
{
return
true
;
return
true
;
}
}
...
@@ -189,7 +193,8 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
...
@@ -189,7 +193,8 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
return
false
;
return
false
;
}
}
bool
FindConflict
(
const
AccessEntry
&
prev
,
const
AccessEntry
&
curr
,
bool
loop_carry
)
{
bool
FindConflict
(
const
AccessEntry
&
prev
,
const
AccessEntry
&
curr
,
bool
loop_carry
)
{
// Access to different buffers does not conflict.
// Access to different buffers does not conflict.
if
(
!
prev
.
buffer
.
same_as
(
curr
.
buffer
))
{
if
(
!
prev
.
buffer
.
same_as
(
curr
.
buffer
))
{
return
false
;
return
false
;
...
@@ -202,21 +207,21 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
...
@@ -202,21 +207,21 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
// Even if access has the same index, those indices need to
// Even if access has the same index, those indices need to
// depend on the innermost thread id to avoid race condition
// depend on the innermost thread id to avoid race condition
bool
depends_on_thread_index
=
true
;
bool
depends_on_thread_index
=
true
;
const
VarNode
*
thread_index_var
=
nullptr
;
const
VarNode
*
thread_index_var
=
nullptr
;
if
(
!
curr
.
threads
.
empty
())
{
if
(
!
curr
.
threads
.
empty
())
{
thread_index_var
=
curr
.
threads
.
back
()
->
var
.
get
();
thread_index_var
=
curr
.
threads
.
back
()
->
var
.
get
();
}
}
for
(
size_t
i
=
0
;
i
<
prev
.
touched
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
prev
.
touched
.
size
();
i
++
)
{
const
auto
&
prev_intset
=
prev
.
touched
[
i
];
const
auto
&
prev_intset
=
prev
.
touched
[
i
];
const
auto
&
curr_intset
=
curr
.
touched
[
i
];
const
auto
&
curr_intset
=
curr
.
touched
[
i
];
if
(
prev_intset
.
IsSinglePoint
()
&&
curr_intset
.
IsSinglePoint
())
{
if
(
prev_intset
.
IsSinglePoint
()
&&
curr_intset
.
IsSinglePoint
())
{
PrimExpr
prev_index
=
prev_intset
.
PointValue
();
PrimExpr
prev_index
=
prev_intset
.
PointValue
();
PrimExpr
curr_index
=
curr_intset
.
PointValue
();
PrimExpr
curr_index
=
curr_intset
.
PointValue
();
has_same_index
=
ExprDeepEqual
()(
prev_index
,
curr_index
);
has_same_index
=
ExprDeepEqual
()(
prev_index
,
curr_index
);
if
(
thread_index_var
!=
nullptr
)
{
if
(
thread_index_var
!=
nullptr
)
{
auto
f_uses_thread_index
=
[
=
](
const
tvm
::
tir
::
VarNode
*
parameter
)
{
auto
f_uses_thread_index
=
[
=
](
const
tvm
::
tir
::
VarNode
*
parameter
)
{
return
parameter
==
thread_index_var
;
return
parameter
==
thread_index_var
;
};
};
depends_on_thread_index
=
depends_on_thread_index
&&
depends_on_thread_index
=
depends_on_thread_index
&&
...
@@ -246,7 +251,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
...
@@ -246,7 +251,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
return
true
;
return
true
;
}
}
void
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
void
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
if
(
op
->
attr_key
==
"kWarpSpecializationScope"
)
{
if
(
op
->
attr_key
==
"kWarpSpecializationScope"
)
{
IfThenElse
body
=
Downcast
<
IfThenElse
>
(
op
->
body
);
IfThenElse
body
=
Downcast
<
IfThenElse
>
(
op
->
body
);
auto
partitions
=
Downcast
<
Array
<
IntImm
>>
(
op
->
node
);
auto
partitions
=
Downcast
<
Array
<
IntImm
>>
(
op
->
node
);
...
@@ -273,27 +278,31 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
...
@@ -273,27 +278,31 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
}
}
}
}
void
insert_syncs
(
const
Object
*
obj
)
{
void
insert_syncs
(
const
Object
*
obj
)
{
// ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition";
// ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside
if
(
syncs_inserted_
.
count
(
obj
))
return
;
// condition";
if
(
syncs_inserted_
.
count
(
obj
))
return
;
if
(
num_partial_threads_
.
defined
())
{
if
(
num_partial_threads_
.
defined
())
{
syncs_inserted_
.
insert
(
obj
);
syncs_inserted_
.
insert
(
obj
);
partial_syncs_inserted_
[
obj
]
=
static_cast
<
int
>
(
num_partial_threads_
.
value
()
->
value
);
partial_syncs_inserted_
[
obj
]
=
static_cast
<
int
>
(
num_partial_threads_
.
value
()
->
value
);
}
else
{
}
else
{
syncs_inserted_
.
insert
(
obj
);
syncs_inserted_
.
insert
(
obj
);
}
}
}
}
private:
private:
Optional
<
IntImm
>
num_partial_threads_
;
Optional
<
IntImm
>
num_partial_threads_
;
// synchronization scope
// synchronization scope
StorageScope
sync_scope_
;
StorageScope
sync_scope_
;
};
};
// There are cases where necessary syncthreads is not inserted by ThreadPartialSyncInserter.
// There are cases where necessary syncthreads is not inserted by
// For example, syncthreads is needed after async_wait_queue in the second loop below,
// ThreadPartialSyncInserter. For example, syncthreads is needed after
// but since ThreadPartialSyncInserter is not aware of the asynchronous semantics, it cannot tell
// async_wait_queue in the second loop below, but since
// that the syncthreads is needed there.
// ThreadPartialSyncInserter is not aware of the asynchronous semantics, it
// cannot tell that the syncthreads is needed there.
//
//
// // Pipeline prologue
// // Pipeline prologue
// for i in range(125):
// for i in range(125):
...
@@ -307,21 +316,23 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
...
@@ -307,21 +316,23 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
// async_wait_queue(0, 2 - i):
// async_wait_queue(0, 2 - i):
// local[...] = shared[(i + 125) % 4]
// local[...] = shared[(i + 125) % 4]
class
ThreadPartialSyncInserter
:
public
StmtExprMutator
{
class
ThreadPartialSyncInserter
:
public
StmtExprMutator
{
public:
public:
// sync_scope:    storage scope guarded by the inserted syncs.
// syncs:         statements before which a synchronization is required.
// partial_syncs: statements mapped to the number of threads participating
//                in a partial sync.
//
// NOTE(fix): `partial_syncs` is taken by const reference instead of by
// value. It initializes the reference member `partial_syncs_`; binding a
// reference member to a by-value parameter leaves the member dangling once
// the constructor returns.
ThreadPartialSyncInserter(
    StorageScope sync_scope, const std::unordered_set<const Object *> &syncs,
    const std::unordered_map<const Object *, int> &partial_syncs)
    : sync_scope_(sync_scope), syncs_(syncs), partial_syncs_(partial_syncs) {}
Stmt
VisitStmt
(
const
Stmt
&
stmt
)
final
{
Stmt
VisitStmt
(
const
Stmt
&
stmt
)
final
{
if
(
syncs_
.
size
()
==
0
)
return
stmt
;
if
(
syncs_
.
size
()
==
0
)
return
stmt
;
if
(
syncs_
.
count
(
stmt
.
get
()))
{
if
(
syncs_
.
count
(
stmt
.
get
()))
{
Stmt
barrier
;
Stmt
barrier
;
if
(
partial_syncs_
.
count
(
stmt
.
get
()))
{
if
(
partial_syncs_
.
count
(
stmt
.
get
()))
{
auto
iter
=
partial_syncs_
.
find
(
stmt
.
get
());
auto
iter
=
partial_syncs_
.
find
(
stmt
.
get
());
ICHECK
(
sync_scope_
.
rank
==
StorageRank
::
kShared
);
ICHECK
(
sync_scope_
.
rank
==
StorageRank
::
kShared
);
barrier
=
Evaluate
(
Call
(
DataType
::
Int
(
32
),
tl
::
SyncThreadsPartialOp
(),
{
iter
->
second
}));
barrier
=
Evaluate
(
Call
(
DataType
::
Int
(
32
),
tl
::
SyncThreadsPartialOp
(),
{
iter
->
second
}));
}
else
{
}
else
{
return
StmtExprMutator
::
VisitStmt
(
stmt
);
return
StmtExprMutator
::
VisitStmt
(
stmt
);
}
}
...
@@ -334,11 +345,11 @@ class ThreadPartialSyncInserter : public StmtExprMutator {
...
@@ -334,11 +345,11 @@ class ThreadPartialSyncInserter : public StmtExprMutator {
}
}
}
}
private:
private:
// data structure.
// data structure.
StorageScope
sync_scope_
;
StorageScope
sync_scope_
;
const
std
::
unordered_set
<
const
Object
*>
&
syncs_
;
const
std
::
unordered_set
<
const
Object
*>
&
syncs_
;
const
std
::
unordered_map
<
const
Object
*
,
int
>
&
partial_syncs_
;
const
std
::
unordered_map
<
const
Object
*
,
int
>
&
partial_syncs_
;
};
};
Stmt
ThreadPartialSync
(
Stmt
stmt
,
std
::
string
storage_scope
)
{
Stmt
ThreadPartialSync
(
Stmt
stmt
,
std
::
string
storage_scope
)
{
...
@@ -346,7 +357,8 @@ Stmt ThreadPartialSync(Stmt stmt, std::string storage_scope) {
...
@@ -346,7 +357,8 @@ Stmt ThreadPartialSync(Stmt stmt, std::string storage_scope) {
ThreadPartialSyncPlanner
planner
(
sync_scope
);
ThreadPartialSyncPlanner
planner
(
sync_scope
);
planner
(
stmt
);
planner
(
stmt
);
return
ThreadPartialSyncInserter
(
sync_scope
,
planner
.
syncs_inserted_
,
return
ThreadPartialSyncInserter
(
sync_scope
,
planner
.
syncs_inserted_
,
planner
.
partial_syncs_inserted_
)(
std
::
move
(
stmt
));
planner
.
partial_syncs_inserted_
)(
std
::
move
(
stmt
));
}
}
using
namespace
tir
::
transform
;
using
namespace
tir
::
transform
;
...
@@ -355,15 +367,16 @@ namespace transform {
...
@@ -355,15 +367,16 @@ namespace transform {
// Create a PrimFunc pass that plans and inserts (partial) thread
// synchronization for buffers in `storage_scope`.
Pass ThreadPartialSync(String storage_scope) {
  auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) {
    auto *func_node = f.CopyOnWrite();
    func_node->body =
        tl::ThreadPartialSync(std::move(func_node->body), storage_scope);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.ThreadPartialSync", {});
}
// Expose the pass under the global name "tl.transform.ThreadPartialSync".
TVM_REGISTER_GLOBAL("tl.transform.ThreadPartialSync")
    .set_body_typed(ThreadPartialSync);
}
// namespace transform
}
// namespace transform
}
// namespace t
ir
}
// namespace t
l
}
// namespace tvm
}
// namespace tvm
src/transform/warp_specialized_rewriter.cc
View file @
549416f7
...
@@ -38,22 +38,23 @@ using namespace tir;
...
@@ -38,22 +38,23 @@ using namespace tir;
// Classification of a statement for warp specialization: kProducer issues
// asynchronous copies into shared memory, kConsumer uses the data, and
// kBoth contains statements of both kinds.
enum class Role { kConsumer, kProducer, kBoth };
class
WarpSpecializedRoleMarker
:
public
StmtVisitor
{
class
WarpSpecializedRoleMarker
:
public
StmtVisitor
{
public:
public:
// buffer_data_to_buffer: maps a buffer's backing data Var to its Buffer;
// stored for use while classifying statements.
WarpSpecializedRoleMarker(Map<Var, Buffer> buffer_data_to_buffer)
    : buffer_data_to_buffer_(buffer_data_to_buffer) {}
// Return the role previously assigned to `stmt`. The statement must have
// been visited (and therefore classified) before this is called.
Role GetRole(const StmtNode *stmt) const {
  const auto entry = map_.find(stmt);
  ICHECK(entry != map_.end());
  return entry->second;
}
// Convenience overload: resolve a Stmt reference to its node pointer.
Role GetRole(const Stmt &stmt) const { return GetRole(stmt.get()); }
void
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
void
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
Role
role
=
Role
::
kConsumer
;
Role
role
=
Role
::
kConsumer
;
if
(
auto
call
=
op
->
value
.
as
<
CallNode
>
())
{
if
(
auto
call
=
op
->
value
.
as
<
CallNode
>
())
{
if
(
call
->
op
.
same_as
(
TMALoadOp
())
||
call
->
op
.
same_as
(
TMALoadIm2ColOp
()))
{
if
(
call
->
op
.
same_as
(
TMALoadOp
())
||
call
->
op
.
same_as
(
TMALoadIm2ColOp
()))
{
role
=
Role
::
kProducer
;
role
=
Role
::
kProducer
;
has_bulk_copy_
=
true
;
has_bulk_copy_
=
true
;
}
}
...
@@ -61,8 +62,9 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
...
@@ -61,8 +62,9 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
SetRole
(
op
,
role
);
SetRole
(
op
,
role
);
}
}
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
bool
is_shared_store
=
op
->
buffer
.
scope
()
==
"shared.dyn"
||
op
->
buffer
.
scope
()
==
"shared"
;
bool
is_shared_store
=
op
->
buffer
.
scope
()
==
"shared.dyn"
||
op
->
buffer
.
scope
()
==
"shared"
;
if
(
!
is_shared_store
)
{
if
(
!
is_shared_store
)
{
SetRole
(
op
,
Role
::
kConsumer
);
SetRole
(
op
,
Role
::
kConsumer
);
return
;
return
;
...
@@ -80,11 +82,12 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
...
@@ -80,11 +82,12 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
break
;
break
;
}
}
}
}
if
(
role
==
Role
::
kProducer
)
has_simt_copy_
=
true
;
if
(
role
==
Role
::
kProducer
)
has_simt_copy_
=
true
;
SetRole
(
op
,
role
);
SetRole
(
op
,
role
);
}
}
void
VisitStmt_
(
const
SeqStmtNode
*
op
)
final
{
void
VisitStmt_
(
const
SeqStmtNode
*
op
)
final
{
StmtVisitor
::
VisitStmt_
(
op
);
StmtVisitor
::
VisitStmt_
(
op
);
auto
role
=
GetRole
(
op
->
seq
[
0
]);
auto
role
=
GetRole
(
op
->
seq
[
0
]);
for
(
auto
stmt
:
op
->
seq
)
{
for
(
auto
stmt
:
op
->
seq
)
{
...
@@ -96,41 +99,41 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
...
@@ -96,41 +99,41 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
SetRole
(
op
,
role
);
SetRole
(
op
,
role
);
}
}
void
VisitStmt_
(
const
IfThenElseNode
*
op
)
final
{
void
VisitStmt_
(
const
IfThenElseNode
*
op
)
final
{
StmtVisitor
::
VisitStmt_
(
op
);
StmtVisitor
::
VisitStmt_
(
op
);
auto
role
=
GetRole
(
op
->
then_case
);
auto
role
=
GetRole
(
op
->
then_case
);
if
(
op
->
else_case
.
defined
())
{
if
(
op
->
else_case
.
defined
())
{
auto
role_else
=
GetRole
(
op
->
else_case
.
value
());
auto
role_else
=
GetRole
(
op
->
else_case
.
value
());
if
(
role
!=
role_else
)
role
=
Role
::
kBoth
;
if
(
role
!=
role_else
)
role
=
Role
::
kBoth
;
}
}
SetRole
(
op
,
role
);
SetRole
(
op
,
role
);
}
}
// A BlockRealize inherits the role of the block it realizes.
void VisitStmt_(const BlockRealizeNode *op) final {
  StmtVisitor::VisitStmt_(op);
  SetRole(op, GetRole(op->block));
}
// Shared handler for statements whose role is simply the role of their
// single body.
template <class NodeType> void HandleBodyStmt(const NodeType *op) {
  StmtVisitor::VisitStmt_(op);
  SetRole(op, GetRole(op->body));
}
// These node kinds all take their role directly from their body.
void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); }
// True if any producer-side copy (SIMT or TMA bulk) was observed.
bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; }
// True if a SIMT (non-bulk) shared-memory copy was observed.
bool HasSimtCopy() { return has_simt_copy_; }
private:
private:
// Record the computed role for `stmt`.
void SetRole(const StmtNode *stmt, Role role) { map_[stmt] = role; }
Map
<
Var
,
Buffer
>
buffer_data_to_buffer_
;
Map
<
Var
,
Buffer
>
buffer_data_to_buffer_
;
std
::
unordered_map
<
const
StmtNode
*
,
Role
>
map_
;
std
::
unordered_map
<
const
StmtNode
*
,
Role
>
map_
;
bool
has_simt_copy_
=
false
;
bool
has_simt_copy_
=
false
;
bool
has_bulk_copy_
=
false
;
bool
has_bulk_copy_
=
false
;
};
};
...
@@ -140,23 +143,26 @@ static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
...
@@ -140,23 +143,26 @@ static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
}
}
// Build an `mbarrier expect_tx` call: declares how many bytes of
// asynchronous bulk-copy data the barrier should expect.
static Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) {
  return Evaluate(Call(DataType::Handle(), MBarrierExpectTX(),
                       {makeGetBarrier(barrier_id), bytes}));
}
// Build an arrive on the mbarrier identified by `barrier_id`.
static Stmt makeArriveBarrier(PrimExpr barrier_id) {
  return Evaluate(Call(DataType::Handle(), builtin::ptx_arrive_barrier(),
                       {makeGetBarrier(barrier_id)}));
}
// Build a cp.async barrier arrive on the mbarrier identified by
// `barrier_id`.
static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) {
  return Evaluate(Call(DataType::Handle(), builtin::ptx_cp_async_barrier(),
                       {makeGetBarrier(barrier_id)}));
}
// Build a parity wait on the mbarrier identified by `barrier_id`.
static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
  return Evaluate(Call(DataType::Handle(), MBarrierWaitParity(),
                       {makeGetBarrier(barrier_id), parity}));
}
...
@@ -177,7 +183,7 @@ static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
...
@@ -177,7 +183,7 @@ static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
// }
// }
class
ProducerTraitsCollector
:
public
StmtExprVisitor
{
class
ProducerTraitsCollector
:
public
StmtExprVisitor
{
public:
public:
// Start from a clean slate so the collector can be reused across statements.
ProducerTraitsCollector() { Clear(); }
void
Clear
()
{
void
Clear
()
{
...
@@ -192,8 +198,8 @@ class ProducerTraitsCollector : public StmtExprVisitor {
...
@@ -192,8 +198,8 @@ class ProducerTraitsCollector : public StmtExprVisitor {
PrimExpr
BulkCopyBytes
()
{
return
bulk_copy_bytes
;
}
PrimExpr
BulkCopyBytes
()
{
return
bulk_copy_bytes
;
}
private:
private:
void
VisitExpr_
(
const
CallNode
*
call
)
final
{
void
VisitExpr_
(
const
CallNode
*
call
)
final
{
if
(
call
->
op
.
same_as
(
TMALoadOp
())
||
call
->
op
.
same_as
(
TMALoadIm2ColOp
()))
{
if
(
call
->
op
.
same_as
(
TMALoadOp
())
||
call
->
op
.
same_as
(
TMALoadIm2ColOp
()))
{
Call
access_ptr
=
Downcast
<
Call
>
(
call
->
args
[
2
]);
Call
access_ptr
=
Downcast
<
Call
>
(
call
->
args
[
2
]);
ICHECK
(
access_ptr
->
op
.
same_as
(
builtin
::
tvm_access_ptr
()));
ICHECK
(
access_ptr
->
op
.
same_as
(
builtin
::
tvm_access_ptr
()));
...
@@ -203,14 +209,14 @@ class ProducerTraitsCollector : public StmtExprVisitor {
...
@@ -203,14 +209,14 @@ class ProducerTraitsCollector : public StmtExprVisitor {
StmtExprVisitor
::
VisitExpr_
(
call
);
StmtExprVisitor
::
VisitExpr_
(
call
);
}
}
void
VisitStmt_
(
const
ForNode
*
op
)
final
{
void
VisitStmt_
(
const
ForNode
*
op
)
final
{
PrimExpr
old_loop_evtents
=
loop_extents
;
PrimExpr
old_loop_evtents
=
loop_extents
;
loop_extents
*=
op
->
extent
;
loop_extents
*=
op
->
extent
;
StmtExprVisitor
::
VisitStmt_
(
op
);
StmtExprVisitor
::
VisitStmt_
(
op
);
loop_extents
=
old_loop_evtents
;
loop_extents
=
old_loop_evtents
;
}
}
// A BufferLoad inside a producer statement means the copy is performed by
// ordinary threads (SIMT) rather than entirely by the TMA engine.
void VisitExpr_(const BufferLoadNode *op) final {
  has_simt_copy = true;
  StmtExprVisitor::VisitExpr_(op);
}
...
@@ -222,15 +228,15 @@ class ProducerTraitsCollector : public StmtExprVisitor {
...
@@ -222,15 +228,15 @@ class ProducerTraitsCollector : public StmtExprVisitor {
// Rewrite the producer Stmt to use the correct barrier index
// Rewrite the producer Stmt to use the correct barrier index
class
MbarrierRewriter
:
public
StmtExprMutator
{
class
MbarrierRewriter
:
public
StmtExprMutator
{
public:
public:
// Rewrite the TMA loads inside `stmt` so they target the barrier given by
// `barrier_id`.
static Stmt Rewrite(Stmt stmt, PrimExpr barrier_id) {
  MbarrierRewriter rewriter;
  rewriter.producer_barrier_idx_ = std::move(barrier_id);
  return rewriter(stmt);
}
private:
private:
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
final
{
auto
call
=
Downcast
<
Call
>
(
StmtExprMutator
::
VisitExpr_
(
op
));
auto
call
=
Downcast
<
Call
>
(
StmtExprMutator
::
VisitExpr_
(
op
));
if
(
call
->
op
.
same_as
(
TMALoadOp
())
||
call
->
op
.
same_as
(
TMALoadIm2ColOp
()))
{
if
(
call
->
op
.
same_as
(
TMALoadOp
())
||
call
->
op
.
same_as
(
TMALoadIm2ColOp
()))
{
Call
access_ptr
=
Downcast
<
Call
>
(
call
->
args
[
2
]);
Call
access_ptr
=
Downcast
<
Call
>
(
call
->
args
[
2
]);
...
@@ -242,19 +248,18 @@ class MbarrierRewriter : public StmtExprMutator {
...
@@ -242,19 +248,18 @@ class MbarrierRewriter : public StmtExprMutator {
PrimExpr
producer_barrier_idx_
;
PrimExpr
producer_barrier_idx_
;
};
};
class
ThreadIdxRewriter
:
public
StmtExprMutator
{
class
ThreadIdxRewriter
:
public
StmtExprMutator
{
public:
public:
// Replace every occurrence of `thread_var` inside `stmt` with `replaced`.
static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced) {
  ThreadIdxRewriter rewriter(thread_var, replaced);
  return rewriter(stmt);
}
private:
private:
// thread_var: the thread-index variable to substitute.
// replaced:   the expression substituted for each occurrence of it.
ThreadIdxRewriter(Var thread_var, PrimExpr replaced)
    : thread_var_(thread_var), replaced_(replaced) {}
PrimExpr
VisitExpr_
(
const
VarNode
*
var
)
final
{
PrimExpr
VisitExpr_
(
const
VarNode
*
var
)
final
{
if
(
var
==
thread_var_
.
get
())
{
if
(
var
==
thread_var_
.
get
())
{
return
replaced_
;
return
replaced_
;
}
else
{
}
else
{
...
@@ -266,9 +271,12 @@ class ThreadIdxRewriter : public StmtExprMutator {
...
@@ -266,9 +271,12 @@ class ThreadIdxRewriter : public StmtExprMutator {
PrimExpr
replaced_
;
PrimExpr
replaced_
;
};
};
// Wrap `stmt` in an opaque block carrying `annotations` (callers use this
// to tag grouped statements, e.g. with the "stmt_group" annotation).
Block MakeGroupBlock(const Stmt &stmt,
                     const Map<String, ObjectRef> &annotations) {
  return Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
               /*name_hint=*/"", /*body=*/stmt, /*init=*/{},
               /*alloc_buffers=*/{}, /*match_buffers=*/{},
               /*annotations=*/annotations);
}
...
@@ -280,11 +288,8 @@ struct PipelineInfo {
...
@@ -280,11 +288,8 @@ struct PipelineInfo {
std
::
vector
<
OpInfo
>
op_infos
;
std
::
vector
<
OpInfo
>
op_infos
;
PipelineInfo
()
=
default
;
PipelineInfo
()
=
default
;
PipelineInfo
(
PipelineInfo
(
Array
<
Array
<
Integer
>>
group_info
,
Array
<
Integer
>
order_info
,
Array
<
Array
<
Integer
>>
group_info
,
Array
<
Integer
>
stage_info
)
{
Array
<
Integer
>
order_info
,
Array
<
Integer
>
stage_info
)
{
int
n
=
static_cast
<
int
>
(
group_info
.
size
());
int
n
=
static_cast
<
int
>
(
group_info
.
size
());
ICHECK
(
n
==
static_cast
<
int
>
(
order_info
.
size
()));
ICHECK
(
n
==
static_cast
<
int
>
(
order_info
.
size
()));
ICHECK
(
n
==
static_cast
<
int
>
(
stage_info
.
size
()));
ICHECK
(
n
==
static_cast
<
int
>
(
stage_info
.
size
()));
...
@@ -301,7 +306,7 @@ struct PipelineInfo {
...
@@ -301,7 +306,7 @@ struct PipelineInfo {
}
}
}
}
PipelineInfo
(
const
PipelineInfo
&
other
)
{
PipelineInfo
(
const
PipelineInfo
&
other
)
{
for
(
auto
op_info
:
other
.
op_infos
)
{
for
(
auto
op_info
:
other
.
op_infos
)
{
op_infos
.
push_back
(
op_info
);
op_infos
.
push_back
(
op_info
);
}
}
...
@@ -364,18 +369,19 @@ struct PipelineInfo {
...
@@ -364,18 +369,19 @@ struct PipelineInfo {
void
PrintPipelineInfo
()
{
void
PrintPipelineInfo
()
{
std
::
cout
<<
"Print op_infos:"
<<
std
::
endl
;
std
::
cout
<<
"Print op_infos:"
<<
std
::
endl
;
for
(
size_t
i
=
0
;
i
<
op_infos
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
op_infos
.
size
();
i
++
)
{
std
::
cout
<<
i
<<
" "
<<
op_infos
[
i
].
group_size
<<
" "
<<
op_infos
[
i
].
order
<<
" "
<<
op_infos
[
i
].
stage
<<
std
::
endl
;
std
::
cout
<<
i
<<
" "
<<
op_infos
[
i
].
group_size
<<
" "
<<
op_infos
[
i
].
order
<<
" "
<<
op_infos
[
i
].
stage
<<
std
::
endl
;
}
}
std
::
cout
<<
"End of print"
<<
std
::
endl
;
std
::
cout
<<
"End of print"
<<
std
::
endl
;
}
}
};
};
class
GroupOpRewriter
:
public
StmtExprMutator
{
class
GroupOpRewriter
:
public
StmtExprMutator
{
public:
public:
// pipeline_info: per-op grouping/order/stage metadata that drives the
// rewrite of pipelined loops.
GroupOpRewriter(PipelineInfo pipeline_info) : pipeline_info_(pipeline_info) {}
private:
private:
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
Map
<
String
,
ObjectRef
>
annotations
;
Map
<
String
,
ObjectRef
>
annotations
;
annotations
.
Set
(
String
(
"stmt_group"
),
Integer
(
1
));
annotations
.
Set
(
String
(
"stmt_group"
),
Integer
(
1
));
auto
original_node
=
(
op
->
body
).
as
<
SeqStmtNode
>
();
auto
original_node
=
(
op
->
body
).
as
<
SeqStmtNode
>
();
...
@@ -385,19 +391,24 @@ class GroupOpRewriter : public StmtExprMutator {
...
@@ -385,19 +391,24 @@ class GroupOpRewriter : public StmtExprMutator {
Array
<
Stmt
>
new_body
;
Array
<
Stmt
>
new_body
;
int
cur_id
=
0
;
int
cur_id
=
0
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
pipeline_info_
.
op_infos
.
size
());
i
++
)
{
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
pipeline_info_
.
op_infos
.
size
());
i
++
)
{
if
(
pipeline_info_
.
op_infos
[
i
].
group_size
==
0
)
continue
;
if
(
pipeline_info_
.
op_infos
[
i
].
group_size
==
0
)
continue
;
Array
<
Stmt
>
block_stmt
;
Array
<
Stmt
>
block_stmt
;
for
(
int
j
=
0
;
j
<
static_cast
<
int
>
(
pipeline_info_
.
op_infos
[
i
].
group_size
);
j
++
)
{
for
(
int
j
=
0
;
j
<
static_cast
<
int
>
(
pipeline_info_
.
op_infos
[
i
].
group_size
);
j
++
)
{
// ICHECK(group_info_[i][j].as<IntImmNode>());
// ICHECK(group_info_[i][j].as<IntImmNode>());
// int index = static_cast<int>(group_info_[i][j].as<IntImmNode>()->value);
// int index =
// static_cast<int>(group_info_[i][j].as<IntImmNode>()->value);
ICHECK
(
original_node
->
seq
[
cur_id
].
as
<
BlockNode
>
());
ICHECK
(
original_node
->
seq
[
cur_id
].
as
<
BlockNode
>
());
auto
block
=
original_node
->
seq
[
cur_id
].
as
<
BlockNode
>
();
auto
block
=
original_node
->
seq
[
cur_id
].
as
<
BlockNode
>
();
// TODO: handle nested seqstmt
// TODO: handle nested seqstmt
block_stmt
.
push_back
(
block
->
body
);
block_stmt
.
push_back
(
block
->
body
);
cur_id
++
;
cur_id
++
;
}
}
new_body
.
push_back
(
new_body
.
push_back
(
MakeGroupBlock
(
block_stmt
.
size
()
==
1
MakeGroupBlock
(
block_stmt
.
size
()
==
1
?
block_stmt
[
0
]
:
SeqStmt
(
std
::
move
(
block_stmt
)),
annotations
));
?
block_stmt
[
0
]
:
SeqStmt
(
std
::
move
(
block_stmt
)),
annotations
));
}
}
Array
<
Integer
>
order_anno
;
Array
<
Integer
>
order_anno
;
Array
<
Integer
>
stage_anno
;
Array
<
Integer
>
stage_anno
;
...
@@ -409,24 +420,26 @@ class GroupOpRewriter : public StmtExprMutator {
...
@@ -409,24 +420,26 @@ class GroupOpRewriter : public StmtExprMutator {
for_annotations
.
erase
(
"tl_pipeline_group"
);
for_annotations
.
erase
(
"tl_pipeline_group"
);
for_annotations
.
Set
(
"software_pipeline_order"
,
order_anno
);
for_annotations
.
Set
(
"software_pipeline_order"
,
order_anno
);
for_annotations
.
Set
(
"software_pipeline_stage"
,
stage_anno
);
for_annotations
.
Set
(
"software_pipeline_stage"
,
stage_anno
);
For
new_for
=
For
(
op
->
loop_var
,
op
->
min
,
op
->
extent
,
op
->
kind
,
new_body
.
size
()
==
1
?
new_body
[
0
]
:
SeqStmt
(
std
::
move
(
new_body
)),
op
->
thread_binding
,
for_annotations
);
For
new_for
=
For
(
op
->
loop_var
,
op
->
min
,
op
->
extent
,
op
->
kind
,
new_body
.
size
()
==
1
?
new_body
[
0
]
:
SeqStmt
(
std
::
move
(
new_body
)),
op
->
thread_binding
,
for_annotations
);
return
new_for
;
return
new_for
;
}
}
PipelineInfo
pipeline_info_
;
PipelineInfo
pipeline_info_
;
};
};
class
WSCodeEmitter
:
public
StmtMutator
{
class
WSCodeEmitter
:
public
StmtMutator
{
public:
public:
// is_emitting_producer: true when emitting the producer (async-copy) branch
//   of the warp-specialized kernel, false for the consumer branch.
// thread_iv: thread-binding IterVar; only its underlying var is retained.
// buffer_data_to_buffer: maps buffer data Vars back to Buffer objects.
// marker: precomputed per-statement producer/consumer roles.
WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
              Map<Var, Buffer> buffer_data_to_buffer,
              const WarpSpecializedRoleMarker &marker)
    : is_emitting_producer_(is_emitting_producer),
      buffer_data_to_buffer_(buffer_data_to_buffer), marker_(marker),
      thread_var_(thread_iv->var) {}
private:
private:
template
<
typename
NodeType
>
template
<
typename
NodeType
>
Stmt
FilterByRole
(
const
NodeType
*
op
)
{
Stmt
FilterByRole
(
const
NodeType
*
op
)
{
Role
role
=
marker_
.
GetRole
(
op
);
Role
role
=
marker_
.
GetRole
(
op
);
if
(
role
==
Role
::
kBoth
)
if
(
role
==
Role
::
kBoth
)
return
StmtMutator
::
VisitStmt_
(
op
);
return
StmtMutator
::
VisitStmt_
(
op
);
...
@@ -437,7 +450,7 @@ class WSCodeEmitter : public StmtMutator {
...
@@ -437,7 +450,7 @@ class WSCodeEmitter : public StmtMutator {
}
}
// TODO: only need to add block for ops in the loop
// TODO: only need to add block for ops in the loop
Stmt
VisitStmt_
(
const
SeqStmtNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
SeqStmtNode
*
op
)
final
{
bool
has_producer
=
false
;
bool
has_producer
=
false
;
for
(
auto
stmt
:
op
->
seq
)
{
for
(
auto
stmt
:
op
->
seq
)
{
if
(
marker_
.
GetRole
(
stmt
)
==
Role
::
kProducer
)
{
if
(
marker_
.
GetRole
(
stmt
)
==
Role
::
kProducer
)
{
...
@@ -445,19 +458,24 @@ class WSCodeEmitter : public StmtMutator {
...
@@ -445,19 +458,24 @@ class WSCodeEmitter : public StmtMutator {
break
;
break
;
}
}
}
}
bool
need_producer_sync
=
has_producer
&&
marker_
.
GetRole
(
op
)
==
Role
::
kBoth
;
bool
need_producer_sync
=
if
(
!
need_producer_sync
)
return
FilterByRole
(
op
);
has_producer
&&
marker_
.
GetRole
(
op
)
==
Role
::
kBoth
;
if
(
!
need_producer_sync
)
return
FilterByRole
(
op
);
auto
seq_transformed
=
op
->
seq
.
Map
([
&
](
Stmt
stmt
)
{
return
VisitStmt
(
stmt
);
});
auto
seq_transformed
=
op
->
seq
.
Map
([
&
](
Stmt
stmt
)
{
return
VisitStmt
(
stmt
);
});
auto
map
=
ExtractSyncPattern
(
op
->
seq
);
auto
map
=
ExtractSyncPattern
(
op
->
seq
);
// std::cout << "Print ExtractSyncPattern" << std::endl;
// std::cout << "Print ExtractSyncPattern" << std::endl;
// for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
// for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
// std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " " << map.release_after[i] << std::endl;
// std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " "
// << map.release_after[i] << std::endl;
// }
// }
// std::cout << "Print sync pattern" << std::endl;
// std::cout << "Print sync pattern" << std::endl;
// for (auto pattern : map.patterns) {
// for (auto pattern : map.patterns) {
// std::cout << pattern.release_idx << " " << pattern.acquire_idx << std::endl;
// std::cout << pattern.release_idx << " " << pattern.acquire_idx <<
// std::endl;
// }
// }
// std::cout << "End of ExtractSyncPattern" << std::endl;
// std::cout << "End of ExtractSyncPattern" << std::endl;
// pipeline_info_.PrintPipelineInfo();
// pipeline_info_.PrintPipelineInfo();
...
@@ -465,29 +483,38 @@ class WSCodeEmitter : public StmtMutator {
...
@@ -465,29 +483,38 @@ class WSCodeEmitter : public StmtMutator {
Map
<
String
,
ObjectRef
>
annotations
;
Map
<
String
,
ObjectRef
>
annotations
;
annotations
.
Set
(
String
(
"stmt_group"
),
Integer
(
1
));
annotations
.
Set
(
String
(
"stmt_group"
),
Integer
(
1
));
if
(
is_emitting_producer_
)
{
// producer case
if
(
is_emitting_producer_
)
{
// producer case
ProducerTraitsCollector
collector
;
ProducerTraitsCollector
collector
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
op
->
seq
.
size
());
i
++
)
{
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
op
->
seq
.
size
());
i
++
)
{
Array
<
Stmt
>
block_stmt
=
{};
Array
<
Stmt
>
block_stmt
=
{};
if
(
marker_
.
GetRole
(
op
->
seq
[
i
])
==
Role
::
kConsumer
)
continue
;
if
(
marker_
.
GetRole
(
op
->
seq
[
i
])
==
Role
::
kConsumer
)
continue
;
if
(
marker_
.
GetRole
(
op
->
seq
[
i
])
==
Role
::
kBoth
)
{
if
(
marker_
.
GetRole
(
op
->
seq
[
i
])
==
Role
::
kBoth
)
{
block_stmt
.
push_back
(
seq_transformed
[
i
]);
block_stmt
.
push_back
(
seq_transformed
[
i
]);
new_body
.
push_back
(
MakeGroupBlock
(
block_stmt
.
size
()
==
1
?
block_stmt
[
0
]
:
SeqStmt
(
std
::
move
(
block_stmt
)),
annotations
));
new_body
.
push_back
(
MakeGroupBlock
(
block_stmt
.
size
()
==
1
?
block_stmt
[
0
]
:
SeqStmt
(
std
::
move
(
block_stmt
)),
annotations
));
continue
;
continue
;
}
}
if
(
map
.
acquire
[
i
]
!=
-
1
)
{
if
(
map
.
acquire
[
i
]
!=
-
1
)
{
PrimExpr
acquire_barrier_id
=
stage_
+
num_barriers_
+
num_stages_
*
map
.
acquire
[
i
];
PrimExpr
acquire_barrier_id
=
PrimExpr
parity
=
stage_
+
num_barriers_
+
num_stages_
*
map
.
acquire
[
i
];
map
.
is_loop_dependency
(
map
.
acquire
[
i
])
?
bitwise_xor
(
parity_
,
1
)
:
parity_
;
PrimExpr
parity
=
map
.
is_loop_dependency
(
map
.
acquire
[
i
])
?
bitwise_xor
(
parity_
,
1
)
:
parity_
;
block_stmt
.
push_back
(
makeParityWait
(
acquire_barrier_id
,
parity
));
block_stmt
.
push_back
(
makeParityWait
(
acquire_barrier_id
,
parity
));
}
}
ICHECK
(
map
.
release
[
i
]
>=
0
);
ICHECK
(
map
.
release
[
i
]
>=
0
);
PrimExpr
release_barrier_id
=
stage_
+
num_barriers_
+
num_stages_
*
map
.
release
[
i
];
PrimExpr
release_barrier_id
=
auto
stmt
=
MbarrierRewriter
::
Rewrite
(
seq_transformed
[
i
],
release_barrier_id
);
stage_
+
num_barriers_
+
num_stages_
*
map
.
release
[
i
];
auto
stmt
=
MbarrierRewriter
::
Rewrite
(
seq_transformed
[
i
],
release_barrier_id
);
collector
.
Collect
(
stmt
);
collector
.
Collect
(
stmt
);
if
(
!
is_zero
(
collector
.
BulkCopyBytes
()))
{
if
(
!
is_zero
(
collector
.
BulkCopyBytes
()))
{
auto
expect_tx
=
IfThenElse
(
EQ
(
thread_var_
,
0
),
auto
expect_tx
=
IfThenElse
(
makeExpectTX
(
release_barrier_id
,
collector
.
BulkCopyBytes
()));
EQ
(
thread_var_
,
0
),
makeExpectTX
(
release_barrier_id
,
collector
.
BulkCopyBytes
()));
block_stmt
.
push_back
(
expect_tx
);
block_stmt
.
push_back
(
expect_tx
);
}
}
block_stmt
.
push_back
(
stmt
);
block_stmt
.
push_back
(
stmt
);
...
@@ -497,39 +524,53 @@ class WSCodeEmitter : public StmtMutator {
...
@@ -497,39 +524,53 @@ class WSCodeEmitter : public StmtMutator {
if
(
map
.
release_after
[
i
])
{
if
(
map
.
release_after
[
i
])
{
block_stmt
.
push_back
(
makeArriveBarrier
(
release_barrier_id
));
block_stmt
.
push_back
(
makeArriveBarrier
(
release_barrier_id
));
for
(
int
j
=
0
;
j
<
num_stages_
;
j
++
)
{
for
(
int
j
=
0
;
j
<
num_stages_
;
j
++
)
{
released_barrier_
.
insert
(
j
+
num_barriers_
+
num_stages_
*
map
.
release
[
i
]);
released_barrier_
.
insert
(
j
+
num_barriers_
+
num_stages_
*
map
.
release
[
i
]);
}
}
}
}
collector
.
Clear
();
collector
.
Clear
();
new_body
.
push_back
(
MakeGroupBlock
(
block_stmt
.
size
()
==
1
?
block_stmt
[
0
]
:
SeqStmt
(
std
::
move
(
block_stmt
)),
annotations
));
new_body
.
push_back
(
MakeGroupBlock
(
block_stmt
.
size
()
==
1
?
block_stmt
[
0
]
:
SeqStmt
(
std
::
move
(
block_stmt
)),
annotations
));
}
}
}
else
{
// consumer case
}
else
{
// consumer case
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
op
->
seq
.
size
());
i
++
)
{
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
op
->
seq
.
size
());
i
++
)
{
Array
<
Stmt
>
block_stmt
=
{};
Array
<
Stmt
>
block_stmt
=
{};
if
(
marker_
.
GetRole
(
op
->
seq
[
i
])
==
Role
::
kProducer
)
continue
;
if
(
marker_
.
GetRole
(
op
->
seq
[
i
])
==
Role
::
kProducer
)
continue
;
if
(
map
.
acquire
[
i
]
!=
-
1
)
{
if
(
map
.
acquire
[
i
]
!=
-
1
)
{
PrimExpr
acquire_barrier_id
=
stage_
+
num_barriers_
+
num_stages_
*
map
.
acquire
[
i
];
PrimExpr
acquire_barrier_id
=
PrimExpr
parity
=
stage_
+
num_barriers_
+
num_stages_
*
map
.
acquire
[
i
];
map
.
is_loop_dependency
(
map
.
acquire
[
i
])
?
bitwise_xor
(
parity_
,
1
)
:
parity_
;
PrimExpr
parity
=
map
.
is_loop_dependency
(
map
.
acquire
[
i
])
?
bitwise_xor
(
parity_
,
1
)
:
parity_
;
block_stmt
.
push_back
(
makeParityWait
(
acquire_barrier_id
,
parity
));
block_stmt
.
push_back
(
makeParityWait
(
acquire_barrier_id
,
parity
));
}
}
block_stmt
.
push_back
(
seq_transformed
[
i
]);
block_stmt
.
push_back
(
seq_transformed
[
i
]);
// new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations));
// new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ?
// block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations));
if
(
map
.
release_after
[
i
])
{
if
(
map
.
release_after
[
i
])
{
PrimExpr
release_barrier_id
=
stage_
+
num_barriers_
+
num_stages_
*
map
.
release
[
i
];
PrimExpr
release_barrier_id
=
stage_
+
num_barriers_
+
num_stages_
*
map
.
release
[
i
];
block_stmt
.
push_back
(
makeArriveBarrier
(
release_barrier_id
));
block_stmt
.
push_back
(
makeArriveBarrier
(
release_barrier_id
));
for
(
int
j
=
0
;
j
<
num_stages_
;
j
++
)
{
for
(
int
j
=
0
;
j
<
num_stages_
;
j
++
)
{
released_barrier_
.
insert
(
j
+
num_barriers_
+
num_stages_
*
map
.
release
[
i
]);
released_barrier_
.
insert
(
j
+
num_barriers_
+
num_stages_
*
map
.
release
[
i
]);
}
}
// Update the pipeline info
// Update the pipeline info
// Todo: handle sync
// Todo: handle sync
}
}
new_body
.
push_back
(
MakeGroupBlock
(
block_stmt
.
size
()
==
1
?
block_stmt
[
0
]
:
SeqStmt
(
std
::
move
(
block_stmt
)),
annotations
));
new_body
.
push_back
(
MakeGroupBlock
(
block_stmt
.
size
()
==
1
?
block_stmt
[
0
]
:
SeqStmt
(
std
::
move
(
block_stmt
)),
annotations
));
}
}
// Filter out the producer stmts
// Filter out the producer stmts
int
cur_id
=
0
;
int
cur_id
=
0
;
PipelineInfo
new_pipeline_info
;
PipelineInfo
new_pipeline_info
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
pipeline_info_
.
op_infos
.
size
());
i
++
)
{
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
pipeline_info_
.
op_infos
.
size
());
i
++
)
{
auto
op_info
=
pipeline_info_
.
op_infos
[
i
];
auto
op_info
=
pipeline_info_
.
op_infos
[
i
];
bool
is_producer
=
false
;
bool
is_producer
=
false
;
for
(
int
j
=
0
;
j
<
op_info
.
group_size
;
j
++
)
{
for
(
int
j
=
0
;
j
<
op_info
.
group_size
;
j
++
)
{
...
@@ -553,7 +594,7 @@ class WSCodeEmitter : public StmtMutator {
...
@@ -553,7 +594,7 @@ class WSCodeEmitter : public StmtMutator {
return
new_body
.
size
()
==
1
?
new_body
[
0
]
:
SeqStmt
(
std
::
move
(
new_body
));
return
new_body
.
size
()
==
1
?
new_body
[
0
]
:
SeqStmt
(
std
::
move
(
new_body
));
}
}
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
int
num_stages
=
1
;
int
num_stages
=
1
;
auto
num_stages_anno
=
op
->
annotations
.
Get
(
"num_stages"
);
auto
num_stages_anno
=
op
->
annotations
.
Get
(
"num_stages"
);
if
(
num_stages_anno
.
defined
())
{
if
(
num_stages_anno
.
defined
())
{
...
@@ -565,7 +606,7 @@ class WSCodeEmitter : public StmtMutator {
...
@@ -565,7 +606,7 @@ class WSCodeEmitter : public StmtMutator {
Array
<
Array
<
Integer
>>
group_info_array
;
Array
<
Array
<
Integer
>>
group_info_array
;
Array
<
Integer
>
order_info_array
;
Array
<
Integer
>
order_info_array
;
Array
<
Integer
>
stage_info_array
;
Array
<
Integer
>
stage_info_array
;
auto
group_anno
=
op
->
annotations
.
Get
(
"tl_pipeline_group"
);
auto
group_anno
=
op
->
annotations
.
Get
(
"tl_pipeline_group"
);
if
(
group_anno
.
defined
())
{
if
(
group_anno
.
defined
())
{
group_info_array
=
Downcast
<
Array
<
Array
<
Integer
>>>
(
group_anno
);
group_info_array
=
Downcast
<
Array
<
Array
<
Integer
>>>
(
group_anno
);
...
@@ -579,9 +620,11 @@ class WSCodeEmitter : public StmtMutator {
...
@@ -579,9 +620,11 @@ class WSCodeEmitter : public StmtMutator {
stage_info_array
=
Downcast
<
Array
<
Integer
>>
(
stage_anno
);
stage_info_array
=
Downcast
<
Array
<
Integer
>>
(
stage_anno
);
}
}
PipelineInfo
pipeline_info
(
group_info_array
,
order_info_array
,
stage_info_array
);
PipelineInfo
pipeline_info
(
group_info_array
,
order_info_array
,
stage_info_array
);
if
(
pipeline_info
.
op_infos
.
size
()
>
0
)
{
if
(
pipeline_info
.
op_infos
.
size
()
>
0
)
{
ICHECK
(
pipeline_info_
.
op_infos
.
size
()
==
0
)
<<
"Nested pipeline not supported."
;
ICHECK
(
pipeline_info_
.
op_infos
.
size
()
==
0
)
<<
"Nested pipeline not supported."
;
}
}
PrimExpr
parity_before
=
std
::
move
(
parity_
);
PrimExpr
parity_before
=
std
::
move
(
parity_
);
...
@@ -592,13 +635,15 @@ class WSCodeEmitter : public StmtMutator {
...
@@ -592,13 +635,15 @@ class WSCodeEmitter : public StmtMutator {
num_stages_
=
num_stages
;
num_stages_
=
num_stages
;
pipeline_info_
=
pipeline_info
;
pipeline_info_
=
pipeline_info
;
stage_
=
FloorMod
(
op
->
loop_var
-
op
->
min
,
num_stages
);
stage_
=
FloorMod
(
op
->
loop_var
-
op
->
min
,
num_stages
);
parity_
=
parity_
=
FloorMod
(
parity_before
*
op
->
extent
+
FloorMod
(
parity_before
*
op
->
extent
+
FloorDiv
(
op
->
loop_var
-
op
->
min
,
num_stages
),
2
);
FloorDiv
(
op
->
loop_var
-
op
->
min
,
num_stages
),
2
);
auto
result
=
FilterByRole
(
op
);
auto
result
=
FilterByRole
(
op
);
Stmt
grouped_for_node
;
Stmt
grouped_for_node
;
if
(
result
.
as
<
ForNode
>
()
&&
group_anno
.
defined
()
&&
group_info_array
.
size
()
>
0
&&
!
is_emitting_producer_
)
{
if
(
result
.
as
<
ForNode
>
()
&&
group_anno
.
defined
()
&&
group_info_array
.
size
()
>
0
&&
!
is_emitting_producer_
)
{
GroupOpRewriter
group_op_rewriter
(
pipeline_info_
);
GroupOpRewriter
group_op_rewriter
(
pipeline_info_
);
auto
for_node
=
Downcast
<
For
>
(
result
);
auto
for_node
=
Downcast
<
For
>
(
result
);
grouped_for_node
=
group_op_rewriter
(
for_node
);
grouped_for_node
=
group_op_rewriter
(
for_node
);
...
@@ -618,7 +663,8 @@ class WSCodeEmitter : public StmtMutator {
...
@@ -618,7 +663,8 @@ class WSCodeEmitter : public StmtMutator {
for_node
.
CopyOnWrite
()
->
annotations
.
erase
(
"tl_pipeline_order"
);
for_node
.
CopyOnWrite
()
->
annotations
.
erase
(
"tl_pipeline_order"
);
for_node
.
CopyOnWrite
()
->
annotations
.
erase
(
"tl_pipeline_stage"
);
for_node
.
CopyOnWrite
()
->
annotations
.
erase
(
"tl_pipeline_stage"
);
}
}
if
(
is_emitting_producer_
||
!
group_anno
.
defined
()
||
group_info_array
.
size
()
==
0
)
{
if
(
is_emitting_producer_
||
!
group_anno
.
defined
()
||
group_info_array
.
size
()
==
0
)
{
return
for_node
;
return
for_node
;
}
}
return
grouped_for_node
;
return
grouped_for_node
;
...
@@ -626,17 +672,17 @@ class WSCodeEmitter : public StmtMutator {
...
@@ -626,17 +672,17 @@ class WSCodeEmitter : public StmtMutator {
return
result
;
return
result
;
}
}
Stmt
VisitStmt_
(
const
IfThenElseNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
IfThenElseNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
LetStmtNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
LetStmtNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
AssertStmtNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
AssertStmtNode
*
op
)
final
{
return
FilterByRole
(
op
);
}
Stmt
VisitStmt_
(
const
BlockNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
BlockNode
*
op
)
final
{
ICHECK
(
0
);
ICHECK
(
0
);
return
Stmt
();
return
Stmt
();
}
}
Stmt
VisitStmt_
(
const
BlockRealizeNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
BlockRealizeNode
*
op
)
final
{
ICHECK
(
0
);
ICHECK
(
0
);
return
Stmt
();
return
Stmt
();
}
}
...
@@ -656,27 +702,32 @@ class WSCodeEmitter : public StmtMutator {
...
@@ -656,27 +702,32 @@ class WSCodeEmitter : public StmtMutator {
}
}
};
};
std
::
vector
<
SyncPattern
>
CreateBaseSyncPairs
(
Array
<
Stmt
>
seq_stmt
,
std
::
vector
<
SyncPattern
>
const
std
::
vector
<
bool
>&
is_producer
)
{
CreateBaseSyncPairs
(
Array
<
Stmt
>
seq_stmt
,
const
std
::
vector
<
bool
>
&
is_producer
)
{
const
int
n
=
seq_stmt
.
size
();
const
int
n
=
seq_stmt
.
size
();
std
::
vector
<
std
::
set
<
const
BufferNode
*>>
reads
,
writes
;
std
::
vector
<
std
::
set
<
const
BufferNode
*>>
reads
,
writes
;
reads
.
reserve
(
n
);
reads
.
reserve
(
n
);
writes
.
reserve
(
n
);
writes
.
reserve
(
n
);
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
Block
block
(
/*iter_vars=*/
{},
/*reads=*/
{},
/*writes=*/
{},
/*name_hint=*/
""
,
Block
block
(
/*iter_vars=*/
{},
/*reads=*/
{},
/*writes=*/
{},
/*name_hint=*/
""
,
/*body*/
seq_stmt
[
i
]);
/*body*/
seq_stmt
[
i
]);
auto
access
=
GetBlockAccessRegion
(
block
,
buffer_data_to_buffer_
);
auto
access
=
GetBlockAccessRegion
(
block
,
buffer_data_to_buffer_
);
std
::
set
<
const
BufferNode
*>
read_set
,
write_set
;
std
::
set
<
const
BufferNode
*>
read_set
,
write_set
;
for
(
auto
region
:
access
[
0
])
read_set
.
insert
(
region
->
buffer
.
get
());
for
(
auto
region
:
access
[
0
])
for
(
auto
region
:
access
[
1
])
write_set
.
insert
(
region
->
buffer
.
get
());
read_set
.
insert
(
region
->
buffer
.
get
());
for
(
auto
region
:
access
[
1
])
write_set
.
insert
(
region
->
buffer
.
get
());
reads
.
push_back
(
std
::
move
(
read_set
));
reads
.
push_back
(
std
::
move
(
read_set
));
writes
.
push_back
(
std
::
move
(
write_set
));
writes
.
push_back
(
std
::
move
(
write_set
));
}
}
auto
intersect_fn
=
[](
const
std
::
set
<
const
BufferNode
*>
&
lhs
,
auto
intersect_fn
=
[](
const
std
::
set
<
const
BufferNode
*>
&
lhs
,
const
std
::
set
<
const
BufferNode
*>
&
rhs
)
{
const
std
::
set
<
const
BufferNode
*>
&
rhs
)
{
for
(
auto
ptr
:
lhs
)
for
(
auto
ptr
:
lhs
)
if
(
rhs
.
count
(
ptr
))
return
true
;
if
(
rhs
.
count
(
ptr
))
return
true
;
return
false
;
return
false
;
};
};
...
@@ -686,7 +737,8 @@ class WSCodeEmitter : public StmtMutator {
...
@@ -686,7 +737,8 @@ class WSCodeEmitter : public StmtMutator {
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int
j
=
i
+
1
;
j
<
n
;
j
++
)
{
for
(
int
j
=
i
+
1
;
j
<
n
;
j
++
)
{
if
(
is_producer
[
i
]
!=
is_producer
[
j
]
&&
if
(
is_producer
[
i
]
!=
is_producer
[
j
]
&&
(
intersect_fn
(
writes
[
i
],
reads
[
j
])
||
intersect_fn
(
reads
[
i
],
writes
[
j
])))
{
(
intersect_fn
(
writes
[
i
],
reads
[
j
])
||
intersect_fn
(
reads
[
i
],
writes
[
j
])))
{
sync_patterns
.
push_back
({
i
,
j
});
sync_patterns
.
push_back
({
i
,
j
});
break
;
break
;
}
}
...
@@ -701,7 +753,8 @@ class WSCodeEmitter : public StmtMutator {
...
@@ -701,7 +753,8 @@ class WSCodeEmitter : public StmtMutator {
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
for
(
int
j
=
0
;
j
<
i
;
j
++
)
{
for
(
int
j
=
0
;
j
<
i
;
j
++
)
{
if
(
is_producer
[
i
]
!=
is_producer
[
j
]
&&
if
(
is_producer
[
i
]
!=
is_producer
[
j
]
&&
(
intersect_fn
(
writes
[
i
],
reads
[
j
])
||
intersect_fn
(
reads
[
i
],
writes
[
j
])))
{
(
intersect_fn
(
writes
[
i
],
reads
[
j
])
||
intersect_fn
(
reads
[
i
],
writes
[
j
])))
{
sync_patterns
.
push_back
({
i
,
j
});
sync_patterns
.
push_back
({
i
,
j
});
break
;
break
;
}
}
...
@@ -712,8 +765,9 @@ class WSCodeEmitter : public StmtMutator {
...
@@ -712,8 +765,9 @@ class WSCodeEmitter : public StmtMutator {
return
sync_patterns
;
return
sync_patterns
;
}
}
static
std
::
vector
<
SyncPattern
>
RemoveUnusedSyncPatterns
(
static
std
::
vector
<
SyncPattern
>
const
std
::
vector
<
SyncPattern
>&
sync_patterns
,
const
std
::
vector
<
bool
>&
is_producer
)
{
RemoveUnusedSyncPatterns
(
const
std
::
vector
<
SyncPattern
>
&
sync_patterns
,
const
std
::
vector
<
bool
>
&
is_producer
)
{
/*
/*
Simplify multiple release-acquire pairs into one
Simplify multiple release-acquire pairs into one
------------------
------------------
...
@@ -746,7 +800,8 @@ class WSCodeEmitter : public StmtMutator {
...
@@ -746,7 +800,8 @@ class WSCodeEmitter : public StmtMutator {
std
::
vector
<
SyncPattern
>
sync_pattern_cleaned
;
std
::
vector
<
SyncPattern
>
sync_pattern_cleaned
;
sync_pattern_cleaned
.
reserve
(
M
);
sync_pattern_cleaned
.
reserve
(
M
);
for
(
int
i
=
0
;
i
<
M
;
i
++
)
for
(
int
i
=
0
;
i
<
M
;
i
++
)
if
(
!
removed
[
i
])
sync_pattern_cleaned
.
push_back
(
sync_patterns
[
i
]);
if
(
!
removed
[
i
])
sync_pattern_cleaned
.
push_back
(
sync_patterns
[
i
]);
return
sync_pattern_cleaned
;
return
sync_pattern_cleaned
;
}
}
...
@@ -760,10 +815,12 @@ class WSCodeEmitter : public StmtMutator {
...
@@ -760,10 +815,12 @@ class WSCodeEmitter : public StmtMutator {
}
}
auto
sync_patterns_base
=
CreateBaseSyncPairs
(
seq_stmt
,
is_producer
);
auto
sync_patterns_base
=
CreateBaseSyncPairs
(
seq_stmt
,
is_producer
);
auto
sync_patterns
=
RemoveUnusedSyncPatterns
(
sync_patterns_base
,
is_producer
);
auto
sync_patterns
=
RemoveUnusedSyncPatterns
(
sync_patterns_base
,
is_producer
);
// for (auto pattern : sync_patterns) {
// for (auto pattern : sync_patterns) {
// std::cout << pattern.release_idx << " " << pattern.acquire_idx << std::endl;
// std::cout << pattern.release_idx << " " << pattern.acquire_idx <<
// std::endl;
// }
// }
SyncPatternMap
map
;
SyncPatternMap
map
;
...
@@ -799,7 +856,7 @@ class WSCodeEmitter : public StmtMutator {
...
@@ -799,7 +856,7 @@ class WSCodeEmitter : public StmtMutator {
const
bool
is_emitting_producer_
;
const
bool
is_emitting_producer_
;
Map
<
Var
,
Buffer
>
buffer_data_to_buffer_
;
Map
<
Var
,
Buffer
>
buffer_data_to_buffer_
;
std
::
unordered_set
<
int
>
released_barrier_
;
std
::
unordered_set
<
int
>
released_barrier_
;
const
WarpSpecializedRoleMarker
&
marker_
;
const
WarpSpecializedRoleMarker
&
marker_
;
int
num_barriers_
=
0
;
int
num_barriers_
=
0
;
PrimExpr
parity_
=
0
;
PrimExpr
parity_
=
0
;
...
@@ -811,17 +868,18 @@ class WSCodeEmitter : public StmtMutator {
...
@@ -811,17 +868,18 @@ class WSCodeEmitter : public StmtMutator {
};
};
class
WarpSpecializedRewriter
:
public
StmtExprMutator
{
class
WarpSpecializedRewriter
:
public
StmtExprMutator
{
public:
public:
static
PrimFunc
Substitute
(
PrimFunc
f
)
{
static
PrimFunc
Substitute
(
PrimFunc
f
)
{
auto
T
=
WarpSpecializedRewriter
();
auto
T
=
WarpSpecializedRewriter
();
T
.
buffer_lca_
=
DetectBufferAccessLCA
(
f
);
T
.
buffer_lca_
=
DetectBufferAccessLCA
(
f
);
for
(
auto
[
buffer
,
_
]
:
T
.
buffer_lca_
)
T
.
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
for
(
auto
[
buffer
,
_
]
:
T
.
buffer_lca_
)
T
.
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
f
.
CopyOnWrite
()
->
body
=
T
(
f
->
body
);
f
.
CopyOnWrite
()
->
body
=
T
(
f
->
body
);
return
f
;
return
f
;
}
}
private:
private:
Stmt
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
if
(
op
->
attr_key
==
tir
::
attr
::
thread_extent
&&
if
(
op
->
attr_key
==
tir
::
attr
::
thread_extent
&&
Downcast
<
IterVar
>
(
op
->
node
)
->
thread_tag
==
"threadIdx.x"
)
{
Downcast
<
IterVar
>
(
op
->
node
)
->
thread_tag
==
"threadIdx.x"
)
{
thread_iv_
=
Downcast
<
IterVar
>
(
op
->
node
);
thread_iv_
=
Downcast
<
IterVar
>
(
op
->
node
);
...
@@ -839,9 +897,10 @@ class WarpSpecializedRewriter : public StmtExprMutator {
...
@@ -839,9 +897,10 @@ class WarpSpecializedRewriter : public StmtExprMutator {
}
}
}
}
// If users define a thread binding, we will replace the thread binding with threadIdx.x
// If users define a thread binding, we will replace the thread binding with
// We require the thread binding is threadIdx.x, and the extent is the same as the thread extent
// threadIdx.x We require the thread binding is threadIdx.x, and the extent is
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
// the same as the thread extent
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
ICHECK
(
thread_iv_
.
defined
());
ICHECK
(
thread_iv_
.
defined
());
For
for_node
=
Downcast
<
For
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
For
for_node
=
Downcast
<
For
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
if
(
for_node
->
kind
==
ForKind
::
kThreadBinding
)
{
if
(
for_node
->
kind
==
ForKind
::
kThreadBinding
)
{
...
@@ -849,14 +908,16 @@ class WarpSpecializedRewriter : public StmtExprMutator {
...
@@ -849,14 +908,16 @@ class WarpSpecializedRewriter : public StmtExprMutator {
String
thread_tag
=
for_node
->
thread_binding
.
value
()
->
thread_tag
;
String
thread_tag
=
for_node
->
thread_binding
.
value
()
->
thread_tag
;
ICHECK
(
thread_tag
==
"threadIdx.x"
)
<<
"Only support threadIdx.x"
;
ICHECK
(
thread_tag
==
"threadIdx.x"
)
<<
"Only support threadIdx.x"
;
Var
thread_iv
=
Downcast
<
Var
>
(
for_node
->
loop_var
);
Var
thread_iv
=
Downcast
<
Var
>
(
for_node
->
loop_var
);
Stmt
new_body
=
ThreadIdxRewriter
::
Rewrite
(
for_node
->
body
,
thread_iv
,
thread_iv_
);
Stmt
new_body
=
ThreadIdxRewriter
::
Rewrite
(
for_node
->
body
,
thread_iv
,
thread_iv_
);
return
new_body
;
return
new_body
;
}
}
return
for_node
;
return
for_node
;
}
}
Stmt
VisitStmt_
(
const
BlockRealizeNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
BlockRealizeNode
*
op
)
final
{
BlockRealize
block_realize
=
Downcast
<
BlockRealize
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
BlockRealize
block_realize
=
Downcast
<
BlockRealize
>
(
StmtExprMutator
::
VisitStmt_
(
op
));
if
(
!
thread_iv_
.
defined
())
{
if
(
!
thread_iv_
.
defined
())
{
return
block_realize
;
return
block_realize
;
}
}
...
@@ -877,17 +938,21 @@ class WarpSpecializedRewriter : public StmtExprMutator {
...
@@ -877,17 +938,21 @@ class WarpSpecializedRewriter : public StmtExprMutator {
PrimExpr
consumer_thread_extent
=
thread_iv_
->
dom
->
extent
;
PrimExpr
consumer_thread_extent
=
thread_iv_
->
dom
->
extent
;
PrimExpr
producer_thread_extent
=
thread_iv_
->
dom
->
extent
;
PrimExpr
producer_thread_extent
=
thread_iv_
->
dom
->
extent
;
// Need one warp-group for bulk-copy only case
// Need one warp-group for bulk-copy only case
if
(
!
marker
.
HasSimtCopy
())
producer_thread_extent
=
128
;
if
(
!
marker
.
HasSimtCopy
())
producer_thread_extent
=
128
;
// TODO: estimate the correct reg usage.
// TODO: estimate the correct reg usage.
auto
inc_reg_stmt
=
Evaluate
(
Call
(
DataType
::
Handle
(),
SetMaxNReg
(),
{
240
,
1
}));
auto
inc_reg_stmt
=
auto
dec_reg_stmt
=
Evaluate
(
Call
(
DataType
::
Handle
(),
SetMaxNReg
(),
{
24
,
0
}));
Evaluate
(
Call
(
DataType
::
Handle
(),
SetMaxNReg
(),
{
240
,
1
}));
auto
dec_reg_stmt
=
Evaluate
(
Call
(
DataType
::
Handle
(),
SetMaxNReg
(),
{
24
,
0
}));
producer_code
=
SeqStmt
({
dec_reg_stmt
,
producer_code
});
producer_code
=
SeqStmt
({
dec_reg_stmt
,
producer_code
});
consumer_code
=
SeqStmt
({
inc_reg_stmt
,
consumer_code
});
consumer_code
=
SeqStmt
({
inc_reg_stmt
,
consumer_code
});
producer_code
=
ThreadIdxRewriter
::
Rewrite
(
producer_code
,
thread_iv_
->
var
,
producer_code
=
thread_iv_
->
var
-
consumer_thread_extent
);
ThreadIdxRewriter
::
Rewrite
(
producer_code
,
thread_iv_
->
var
,
thread_iv_
->
var
-
consumer_thread_extent
);
updated_thread_extent_
=
consumer_thread_extent
+
producer_thread_extent
;
updated_thread_extent_
=
consumer_thread_extent
+
producer_thread_extent
;
need_update_thread_extent_
=
true
;
need_update_thread_extent_
=
true
;
...
@@ -897,15 +962,16 @@ class WarpSpecializedRewriter : public StmtExprMutator {
...
@@ -897,15 +962,16 @@ class WarpSpecializedRewriter : public StmtExprMutator {
Array
<
PrimExpr
>
barrier_num_threads
;
Array
<
PrimExpr
>
barrier_num_threads
;
barrier_num_threads
.
reserve
(
num_barriers
);
barrier_num_threads
.
reserve
(
num_barriers
);
for
(
int
i
=
0
;
i
<
num_barriers
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_barriers
;
i
++
)
{
PrimExpr
arrive_thread_count
=
PrimExpr
arrive_thread_count
=
producer
.
released_barrier_
.
count
(
i
)
producer
.
released_barrier_
.
count
(
i
)
?
producer_thread_extent
:
consumer_thread_extent
;
?
producer_thread_extent
:
consumer_thread_extent
;
barrier_num_threads
.
push_back
(
arrive_thread_count
);
barrier_num_threads
.
push_back
(
arrive_thread_count
);
}
}
Stmt
init_barrier
=
Stmt
init_barrier
=
Evaluate
(
Call
(
Evaluate
(
Call
(
DataType
::
Handle
(),
CreateListofMBarrierOp
(),
barrier_num_threads
));
DataType
::
Handle
(),
CreateListofMBarrierOp
(),
barrier_num_threads
));
Stmt
body
=
Stmt
body
=
IfThenElse
(
GE
(
thread_iv_
->
var
,
consumer_thread_extent
),
IfThenElse
(
GE
(
thread_iv_
->
var
,
consumer_thread_extent
),
producer_code
,
consumer_code
);
producer_code
,
consumer_code
);
// Add an attr here to handle the partial thread count in THreadSync pass.
// Add an attr here to handle the partial thread count in THreadSync pass.
Array
<
IntImm
>
ws_partition
=
{
Downcast
<
IntImm
>
(
producer_thread_extent
),
Array
<
IntImm
>
ws_partition
=
{
Downcast
<
IntImm
>
(
producer_thread_extent
),
Downcast
<
IntImm
>
(
consumer_thread_extent
)};
Downcast
<
IntImm
>
(
consumer_thread_extent
)};
...
@@ -935,7 +1001,8 @@ tvm::transform::Pass WarpSpecialized() {
...
@@ -935,7 +1001,8 @@ tvm::transform::Pass WarpSpecialized() {
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.WarpSpecialized"
,
{});
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.WarpSpecialized"
,
{});
}
}
TVM_REGISTER_GLOBAL
(
"tl.transform.WarpSpecialized"
).
set_body_typed
(
WarpSpecialized
);
TVM_REGISTER_GLOBAL
(
"tl.transform.WarpSpecialized"
)
.
set_body_typed
(
WarpSpecialized
);
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
View file @
549416f7
...
@@ -5,7 +5,6 @@ import torch
...
@@ -5,7 +5,6 @@ import torch
import
torch.backends
import
torch.backends
import
tilelang.testing
import
tilelang.testing
from
tilelang
import
tvm
as
tvm
from
tilelang
import
tvm
as
tvm
from
tvm
import
DataType
import
tilelang
as
TL
import
tilelang
as
TL
import
tilelang.language
as
T
import
tilelang.language
as
T
from
tilelang.intrinsics
import
make_mfma_swizzle_layout
as
make_swizzle_layout
from
tilelang.intrinsics
import
make_mfma_swizzle_layout
as
make_swizzle_layout
...
...
testing/python/kernel/test_tilelang_gemm.py
View file @
549416f7
...
@@ -30,13 +30,11 @@ def matmul(
...
@@ -30,13 +30,11 @@ def matmul(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Buffer
(
A_shape
,
in_dtype
),
A
:
T
.
Buffer
(
A_shape
,
in_dtype
),
B
:
T
.
Buffer
(
B_shape
,
in_dtype
),
B
:
T
.
Buffer
(
B_shape
,
in_dtype
),
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
@@ -169,9 +167,7 @@ def test_gemm_f32f32f32_nn():
...
@@ -169,9 +167,7 @@ def test_gemm_f32f32f32_nn():
def
test_gemm_i8i8i32_nn
():
def
test_gemm_i8i8i32_nn
():
run_gemm
(
run_gemm
(
512
,
1024
,
768
,
False
,
False
,
"int8"
,
"int8"
,
"int32"
,
128
,
128
,
64
)
512
,
1024
,
768
,
False
,
False
,
"int8"
,
"int8"
,
"int32"
,
128
,
128
,
64
)
def
test_gemm_f16f16f16_tn
():
def
test_gemm_f16f16f16_tn
():
...
@@ -217,9 +213,7 @@ def test_gemm_i8i8i32_tn():
...
@@ -217,9 +213,7 @@ def test_gemm_i8i8i32_tn():
def
test_gemm_f64f64f64_nt
():
def
test_gemm_f64f64f64_nt
():
run_gemm
(
run_gemm
(
512
,
512
,
512
,
False
,
True
,
"float64"
,
"float64"
,
"float64"
,
64
,
32
,
16
)
512
,
512
,
512
,
False
,
True
,
"float64"
,
"float64"
,
"float64"
,
64
,
32
,
16
)
def
test_gemm_f32f32f32_nt
():
def
test_gemm_f32f32f32_nt
():
...
...
testing/python/kernel/test_tilelang_gemm_mma_intrinsic.py
View file @
549416f7
...
@@ -10,8 +10,7 @@ import tilelang as TL
...
@@ -10,8 +10,7 @@ import tilelang as TL
import
tilelang.language
as
T
import
tilelang.language
as
T
from
tilelang.intrinsics
import
get_swizzle_layout
from
tilelang.intrinsics
import
get_swizzle_layout
from
tilelang.intrinsics.mma_macro_generator
import
(
from
tilelang.intrinsics.mma_macro_generator
import
(
TensorCoreIntrinEmitter
,
TensorCoreIntrinEmitter
,)
)
from
tilelang.transform
import
simplify_prim_func
from
tilelang.transform
import
simplify_prim_func
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
...
...
testing/python/primitives/test_tilelang_primitives_mma.py
View file @
549416f7
...
@@ -6,6 +6,7 @@ import tilelang.testing
...
@@ -6,6 +6,7 @@ import tilelang.testing
import
tilelang
as
tl
import
tilelang
as
tl
from
tilelang
import
primitives
as
P
from
tilelang
import
primitives
as
P
def
matmul_ssr
(
def
matmul_ssr
(
M
,
M
,
N
,
N
,
...
@@ -30,13 +31,11 @@ def matmul_ssr(
...
@@ -30,13 +31,11 @@ def matmul_ssr(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Buffer
(
A_shape
,
in_dtype
),
A
:
T
.
Buffer
(
A_shape
,
in_dtype
),
B
:
T
.
Buffer
(
B_shape
,
in_dtype
),
B
:
T
.
Buffer
(
B_shape
,
in_dtype
),
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
C_local
=
T
.
alloc_fragment
((
block_M
,
block_N
),
accum_dtype
)
...
@@ -145,13 +144,11 @@ def matmul_rsr(
...
@@ -145,13 +144,11 @@ def matmul_rsr(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Buffer
(
A_shape
,
in_dtype
),
A
:
T
.
Buffer
(
A_shape
,
in_dtype
),
B
:
T
.
Buffer
(
B_shape
,
in_dtype
),
B
:
T
.
Buffer
(
B_shape
,
in_dtype
),
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_local
=
T
.
alloc_fragment
(
A_local_shape
,
in_dtype
)
A_local
=
T
.
alloc_fragment
(
A_local_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
...
@@ -264,13 +261,11 @@ def matmul_rrr(
...
@@ -264,13 +261,11 @@ def matmul_rrr(
@
T
.
prim_func
@
T
.
prim_func
def
main
(
def
main
(
A
:
T
.
Buffer
(
A_shape
,
in_dtype
),
A
:
T
.
Buffer
(
A_shape
,
in_dtype
),
B
:
T
.
Buffer
(
B_shape
,
in_dtype
),
B
:
T
.
Buffer
(
B_shape
,
in_dtype
),
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
C
:
T
.
Buffer
((
M
,
N
),
out_dtype
),
):
):
with
T
.
Kernel
(
with
T
.
Kernel
(
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
T
.
ceildiv
(
N
,
block_N
),
T
.
ceildiv
(
M
,
block_M
),
threads
=
threads
)
as
(
bx
,
by
):
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_shared
=
T
.
alloc_shared
(
A_shared_shape
,
in_dtype
)
A_local
=
T
.
alloc_fragment
(
A_local_shape
,
in_dtype
)
A_local
=
T
.
alloc_fragment
(
A_local_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
B_shared
=
T
.
alloc_shared
(
B_shared_shape
,
in_dtype
)
...
...
tilelang/intrinsics/__init__.py
View file @
549416f7
...
@@ -11,8 +11,7 @@ from .mma_macro_generator import (
...
@@ -11,8 +11,7 @@ from .mma_macro_generator import (
TensorCoreIntrinEmitterWithLadderTransform
,
# noqa: F401
TensorCoreIntrinEmitterWithLadderTransform
,
# noqa: F401
)
)
from
.mma_layout
import
get_swizzle_layout
# noqa: F401
from
.mma_layout
import
get_swizzle_layout
# noqa: F401
from
.mma_layout
import
make_mma_swizzle_layout
# noqa: F401
from
.mma_layout
import
make_mma_swizzle_layout
# noqa: F401
from
.mfma_layout
import
make_mfma_swizzle_layout
# noqa: F401
from
.mfma_layout
import
make_mfma_swizzle_layout
# noqa: F401
tilelang/intrinsics/mma_macro_generator.py
View file @
549416f7
...
@@ -14,6 +14,7 @@ from .utils import (
...
@@ -14,6 +14,7 @@ from .utils import (
lift
=
convert
lift
=
convert
# TODO(lei): Add Typing for this file
# TODO(lei): Add Typing for this file
class
TensorCoreIntrinEmitter
(
object
):
class
TensorCoreIntrinEmitter
(
object
):
"""
"""
...
@@ -75,9 +76,11 @@ class TensorCoreIntrinEmitter(object):
...
@@ -75,9 +76,11 @@ class TensorCoreIntrinEmitter(object):
self
.
reduce_k
=
reduce_k
self
.
reduce_k
=
reduce_k
self
.
threads
=
self
.
WARP_SIZE
*
(
block_row_warps
*
block_col_warps
)
*
reduce_k
self
.
threads
=
self
.
WARP_SIZE
*
(
block_row_warps
*
block_col_warps
)
*
reduce_k
self
.
num_elems_per_byte
=
num_elems_per_byte
self
.
num_elems_per_byte
=
num_elems_per_byte
if
self
.
warp_rows
==
0
or
self
.
warp_cols
==
0
:
if
self
.
warp_rows
==
0
or
self
.
warp_cols
==
0
:
raise
ValueError
(
f
"Invalid threads configuration for this tile shape,
{
self
.
warp_rows
}
x
{
self
.
warp_cols
}
with threads
{
self
.
threads
}
"
)
raise
ValueError
(
f
"Invalid threads configuration for this tile shape,
{
self
.
warp_rows
}
x
{
self
.
warp_cols
}
with threads
{
self
.
threads
}
"
)
def
_initialize_k_dim
(
self
,
a_dtype
=
"float16"
):
def
_initialize_k_dim
(
self
,
a_dtype
=
"float16"
):
if
isinstance
(
a_dtype
,
str
):
if
isinstance
(
a_dtype
,
str
):
...
@@ -272,12 +275,9 @@ class TensorCoreIntrinEmitter(object):
...
@@ -272,12 +275,9 @@ class TensorCoreIntrinEmitter(object):
A_local_buf
.
data
,
A_local_buf
.
data
,
k_inner
*
warp_rows
*
local_size_a
+
i
*
local_size_a
,
k_inner
*
warp_rows
*
local_size_a
+
i
*
local_size_a
,
B_local_buf
.
data
,
B_local_buf
.
data
,
k_inner
*
warp_cols
*
local_size_b
+
j
*
local_size_b
k_inner
*
warp_cols
*
local_size_b
+
j
*
local_size_b
+
lift
(
local_size_b
)
//
2
,
+
lift
(
local_size_b
)
//
2
,
C_local_buf
.
data
,
C_local_buf
.
data
,
i
*
warp_cols
*
local_size_out
i
*
warp_cols
*
local_size_out
+
j
*
local_size_out
+
lift
(
local_size_out
)
//
2
,
+
j
*
local_size_out
+
lift
(
local_size_out
)
//
2
,
T
.
bool
(
False
),
T
.
bool
(
False
),
)
)
...
@@ -328,7 +328,9 @@ class TensorCoreIntrinEmitter(object):
...
@@ -328,7 +328,9 @@ class TensorCoreIntrinEmitter(object):
return
(
_warp_stmatrix_global
(
C_local_buf
,
C_buf
,
thread_bindings
)
return
(
_warp_stmatrix_global
(
C_local_buf
,
C_buf
,
thread_bindings
)
if
is_global
else
_warp_stmatrix_shared
(
C_local_buf
,
C_buf
,
thread_bindings
))
if
is_global
else
_warp_stmatrix_shared
(
C_local_buf
,
C_buf
,
thread_bindings
))
def
make_mma_load_layout
(
self
,
local_buf
:
Buffer
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
)
->
T
.
Fragment
:
def
make_mma_load_layout
(
self
,
local_buf
:
Buffer
,
matrix
:
Literal
[
"A"
,
"B"
]
=
"A"
)
->
T
.
Fragment
:
"""
"""
Create a layout function for storing MMA results into a fragment buffer.
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
This layout is used in conjunction with `inverse_mma_store_layout` to
...
@@ -372,12 +374,14 @@ class TensorCoreIntrinEmitter(object):
...
@@ -372,12 +374,14 @@ class TensorCoreIntrinEmitter(object):
elif
matrix
==
"A"
and
not
transposed
:
elif
matrix
==
"A"
and
not
transposed
:
transform_func
=
ldmatrix_16x32_to_shared_16x32_layout_a
transform_func
=
ldmatrix_16x32_to_shared_16x32_layout_a
else
:
else
:
raise
ValueError
(
"ldmatrix only supports B transposed and A non-transposed for int8"
)
raise
ValueError
(
"ldmatrix only supports B transposed and A non-transposed for int8"
)
else
:
else
:
raise
ValueError
(
f
"Unsupported dtype
{
dtype
}
"
)
raise
ValueError
(
f
"Unsupported dtype
{
dtype
}
"
)
shape
=
local_buf
.
shape
shape
=
local_buf
.
shape
assert
is_fragment
(
local_buf
),
"local_buf must be a fragment, but got {}"
.
format
(
local_buf
.
scope
())
assert
is_fragment
(
local_buf
),
"local_buf must be a fragment, but got {}"
.
format
(
local_buf
.
scope
())
if
matrix
==
"A"
:
if
matrix
==
"A"
:
micro_size_x
,
micro_size_y
=
self
.
micro_size_x
,
self
.
micro_size_k
micro_size_x
,
micro_size_y
=
self
.
micro_size_x
,
self
.
micro_size_k
...
@@ -397,7 +401,8 @@ class TensorCoreIntrinEmitter(object):
...
@@ -397,7 +401,8 @@ class TensorCoreIntrinEmitter(object):
transform_func
=
transform_func
if
not
transposed
else
transform_func_trans
transform_func
=
transform_func
if
not
transposed
else
transform_func_trans
warp_size
,
local_size_a
,
local_size_b
=
self
.
WARP_SIZE
,
self
.
local_size_a
,
self
.
local_size_b
warp_size
,
local_size_a
,
local_size_b
=
self
.
WARP_SIZE
,
self
.
local_size_a
,
self
.
local_size_b
local_size
=
local_size_a
if
matrix
==
"A"
else
local_size_b
local_size
=
local_size_a
if
matrix
==
"A"
else
local_size_b
inverse_mma_load_layout
=
IndexMap
.
from_func
(
transform_func
,
index_dtype
=
"int32"
).
inverse
([
warp_size
,
local_size
])
inverse_mma_load_layout
=
IndexMap
.
from_func
(
transform_func
,
index_dtype
=
"int32"
).
inverse
([
warp_size
,
local_size
])
def
forward_thread
(
i
:
int
,
j
:
int
)
->
int
:
def
forward_thread
(
i
:
int
,
j
:
int
)
->
int
:
"""
"""
...
@@ -406,29 +411,19 @@ class TensorCoreIntrinEmitter(object):
...
@@ -406,29 +411,19 @@ class TensorCoreIntrinEmitter(object):
"""
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
# the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
block_i
,
block_j
=
(
i
//
micro_size_x
)
//
warp_rows
,
(
block_i
,
block_j
=
(
i
//
micro_size_x
)
//
warp_rows
,
(
j
//
micro_size_y
)
//
warp_cols
j
//
micro_size_y
)
//
warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i
,
warp_j
=
(
i
//
micro_size_x
)
%
warp_rows
,
(
warp_i
,
warp_j
=
(
i
//
micro_size_x
)
%
warp_rows
,
(
j
//
micro_size_y
)
%
warp_cols
j
//
micro_size_y
)
%
warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i
,
mma_j
=
i
%
micro_size_x
,
j
%
micro_size_y
mma_i
,
mma_j
=
i
%
micro_size_x
,
j
%
micro_size_y
lane_id
,
_
=
inverse_mma_load_layout
.
map_indices
([
mma_i
,
mma_j
])
lane_id
,
_
=
inverse_mma_load_layout
.
map_indices
([
mma_i
,
mma_j
])
if
is_m_first
:
if
is_m_first
:
thread_id
=
(
thread_id
=
(
block_i
*
(
block_col_warps
*
warp_cols
)
block_i
*
(
block_col_warps
*
warp_cols
)
+
block_j
*
warp_rows
+
+
block_j
*
warp_rows
warp_i
*
warp_cols
+
warp_j
)
+
warp_i
*
warp_cols
+
warp_j
)
else
:
else
:
thread_id
=
(
thread_id
=
(
block_j
*
(
block_row_warps
*
warp_size
)
block_j
*
(
block_row_warps
*
warp_size
)
+
block_i
*
warp_size
+
lane_id
)
+
block_i
*
warp_size
+
lane_id
)
return
thread_id
return
thread_id
def
forward_index
(
i
:
int
,
j
:
int
)
->
int
:
def
forward_index
(
i
:
int
,
j
:
int
)
->
int
:
...
@@ -439,21 +434,13 @@ class TensorCoreIntrinEmitter(object):
...
@@ -439,21 +434,13 @@ class TensorCoreIntrinEmitter(object):
"""
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
# the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
block_i
,
block_j
=
(
i
//
micro_size_x
)
//
warp_rows
,
(
block_i
,
block_j
=
(
i
//
micro_size_x
)
//
warp_rows
,
(
j
//
micro_size_y
)
//
warp_cols
j
//
micro_size_y
)
//
warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i
,
warp_j
=
(
i
//
micro_size_x
)
%
warp_rows
,
(
warp_i
,
warp_j
=
(
i
//
micro_size_x
)
%
warp_rows
,
(
j
//
micro_size_y
)
%
warp_cols
j
//
micro_size_y
)
%
warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i
,
mma_j
=
i
%
micro_size_x
,
j
%
micro_size_y
mma_i
,
mma_j
=
i
%
micro_size_x
,
j
%
micro_size_y
_
,
local_id
=
inverse_mma_load_layout
.
map_indices
([
mma_i
,
mma_j
])
_
,
local_id
=
inverse_mma_load_layout
.
map_indices
([
mma_i
,
mma_j
])
return
(
return
(
warp_i
*
(
warp_cols
*
local_size_out
)
+
warp_j
*
local_size_out
+
local_id
)
warp_i
*
(
warp_cols
*
local_size_out
)
+
warp_j
*
local_size_out
+
local_id
)
fragment
=
T
.
Fragment
(
fragment
=
T
.
Fragment
(
shape
,
shape
,
...
@@ -465,9 +452,7 @@ class TensorCoreIntrinEmitter(object):
...
@@ -465,9 +452,7 @@ class TensorCoreIntrinEmitter(object):
print
(
f
"fragment.index:
{
fragment
.
index
}
"
)
print
(
f
"fragment.index:
{
fragment
.
index
}
"
)
return
fragment
return
fragment
def
make_mma_store_layout
(
def
make_mma_store_layout
(
self
,
local_buf
:
Buffer
)
->
T
.
Fragment
:
self
,
local_buf
:
Buffer
)
->
T
.
Fragment
:
"""
"""
Create a layout function for storing MMA results into a fragment buffer.
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
This layout is used in conjunction with `inverse_mma_store_layout` to
...
@@ -500,6 +485,7 @@ class TensorCoreIntrinEmitter(object):
...
@@ -500,6 +485,7 @@ class TensorCoreIntrinEmitter(object):
warp_rows
,
warp_cols
=
self
.
warp_rows
,
self
.
warp_cols
warp_rows
,
warp_cols
=
self
.
warp_rows
,
self
.
warp_cols
warp_size
=
self
.
WARP_SIZE
warp_size
=
self
.
WARP_SIZE
is_m_first
=
self
.
is_m_first
is_m_first
=
self
.
is_m_first
def
forward_thread
(
i
:
int
,
j
:
int
)
->
int
:
def
forward_thread
(
i
:
int
,
j
:
int
)
->
int
:
"""
"""
Given the row index `i` and column index `j` in the fragment,
Given the row index `i` and column index `j` in the fragment,
...
@@ -514,7 +500,8 @@ class TensorCoreIntrinEmitter(object):
...
@@ -514,7 +500,8 @@ class TensorCoreIntrinEmitter(object):
mma_i
,
mma_j
=
i
%
micro_size_x
,
j
%
micro_size_y
mma_i
,
mma_j
=
i
%
micro_size_x
,
j
%
micro_size_y
lane_id
,
_
=
inverse_mma_store_layout
.
map_indices
([
mma_i
,
mma_j
])
lane_id
,
_
=
inverse_mma_store_layout
.
map_indices
([
mma_i
,
mma_j
])
if
is_m_first
:
if
is_m_first
:
thread_id
=
block_i
*
(
block_col_warps
*
warp_cols
)
+
block_j
*
warp_rows
+
warp_i
*
warp_cols
+
warp_j
thread_id
=
block_i
*
(
block_col_warps
*
warp_cols
)
+
block_j
*
warp_rows
+
warp_i
*
warp_cols
+
warp_j
else
:
else
:
thread_id
=
block_j
*
(
block_row_warps
*
warp_size
)
+
block_i
*
warp_size
+
lane_id
thread_id
=
block_j
*
(
block_row_warps
*
warp_size
)
+
block_i
*
warp_size
+
lane_id
return
thread_id
return
thread_id
...
@@ -527,13 +514,9 @@ class TensorCoreIntrinEmitter(object):
...
@@ -527,13 +514,9 @@ class TensorCoreIntrinEmitter(object):
"""
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
# the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
block_i
,
block_j
=
(
i
//
micro_size_x
)
//
warp_rows
,
(
block_i
,
block_j
=
(
i
//
micro_size_x
)
//
warp_rows
,
(
j
//
micro_size_y
)
//
warp_cols
j
//
micro_size_y
)
//
warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i
,
warp_j
=
(
i
//
micro_size_x
)
%
warp_rows
,
(
warp_i
,
warp_j
=
(
i
//
micro_size_x
)
%
warp_rows
,
(
j
//
micro_size_y
)
%
warp_cols
j
//
micro_size_y
)
%
warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i
,
mma_j
=
i
%
micro_size_x
,
j
%
micro_size_y
mma_i
,
mma_j
=
i
%
micro_size_x
,
j
%
micro_size_y
_
,
local_id
=
inverse_mma_store_layout
.
map_indices
([
mma_i
,
mma_j
])
_
,
local_id
=
inverse_mma_store_layout
.
map_indices
([
mma_i
,
mma_j
])
...
@@ -545,6 +528,7 @@ class TensorCoreIntrinEmitter(object):
...
@@ -545,6 +528,7 @@ class TensorCoreIntrinEmitter(object):
forward_index_fn
=
forward_index
,
forward_index_fn
=
forward_index
,
)
)
class
TensorCoreIntrinEmitterWithLadderTransform
(
TensorCoreIntrinEmitter
):
class
TensorCoreIntrinEmitterWithLadderTransform
(
TensorCoreIntrinEmitter
):
"""
"""
To eliminate Python syntax within TIR Macro.
To eliminate Python syntax within TIR Macro.
...
...
tilelang/language/allocate.py
View file @
549416f7
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
from
tvm.script
import
tir
as
T
from
tvm.script
import
tir
as
T
def
alloc_shared
(
shape
,
dtype
,
scope
=
"shared.dyn"
):
def
alloc_shared
(
shape
,
dtype
,
scope
=
"shared.dyn"
):
return
T
.
alloc_buffer
(
shape
,
dtype
,
scope
=
scope
)
return
T
.
alloc_buffer
(
shape
,
dtype
,
scope
=
scope
)
...
...
tilelang/language/copy.py
View file @
549416f7
...
@@ -6,11 +6,10 @@ from typing import Union, List, Optional
...
@@ -6,11 +6,10 @@ from typing import Union, List, Optional
from
tvm
import
tir
from
tvm
import
tir
from
tvm.script
import
tir
as
T
from
tvm.script
import
tir
as
T
def
region
(
buffer
:
tir
.
BufferLoad
,
access_type
:
str
,
*
args
:
tir
.
PrimExpr
):
def
region
(
buffer
:
tir
.
BufferLoad
,
access_type
:
str
,
*
args
:
tir
.
PrimExpr
):
access_type
=
{
"r"
:
1
,
"w"
:
2
,
"rw"
:
3
}[
access_type
]
access_type
=
{
"r"
:
1
,
"w"
:
2
,
"rw"
:
3
}[
access_type
]
return
tir
.
call_intrin
(
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.region"
),
buffer
,
access_type
,
*
args
)
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.region"
),
buffer
,
access_type
,
*
args
)
def
buffer_to_tile_region
(
buffer
:
tir
.
Buffer
,
access_type
:
str
):
def
buffer_to_tile_region
(
buffer
:
tir
.
Buffer
,
access_type
:
str
):
...
@@ -19,20 +18,14 @@ def buffer_to_tile_region(buffer: tir.Buffer, access_type: str):
...
@@ -19,20 +18,14 @@ def buffer_to_tile_region(buffer: tir.Buffer, access_type: str):
return
region
(
T
.
BufferLoad
(
buffer
,
mins
),
access_type
,
*
extents
)
return
region
(
T
.
BufferLoad
(
buffer
,
mins
),
access_type
,
*
extents
)
def
buffer_load_to_tile_region
(
def
buffer_load_to_tile_region
(
load
:
tir
.
BufferLoad
,
access_type
:
str
,
extents
:
List
[
tir
.
PrimExpr
]):
load
:
tir
.
BufferLoad
,
access_type
:
str
,
extents
:
List
[
tir
.
PrimExpr
]
):
return
region
(
load
,
access_type
,
*
extents
)
return
region
(
load
,
access_type
,
*
extents
)
def
buffer_region_to_tile_region
(
def
buffer_region_to_tile_region
(
buffer_region
:
tir
.
BufferRegion
,
access_type
:
str
):
buffer_region
:
tir
.
BufferRegion
,
access_type
:
str
):
mins
=
[
x
.
min
for
x
in
buffer_region
.
region
]
mins
=
[
x
.
min
for
x
in
buffer_region
.
region
]
extents
=
[
x
.
extent
for
x
in
buffer_region
.
region
]
extents
=
[
x
.
extent
for
x
in
buffer_region
.
region
]
return
region
(
return
region
(
T
.
BufferLoad
(
buffer_region
.
buffer
,
mins
),
access_type
,
*
extents
)
T
.
BufferLoad
(
buffer_region
.
buffer
,
mins
),
access_type
,
*
extents
)
def
copy
(
def
copy
(
...
@@ -71,9 +64,7 @@ def copy(
...
@@ -71,9 +64,7 @@ def copy(
src
=
_to_region
(
src
,
"r"
)
src
=
_to_region
(
src
,
"r"
)
dst
=
_to_region
(
dst
,
"w"
)
dst
=
_to_region
(
dst
,
"w"
)
if
coalesced_width
is
not
None
:
if
coalesced_width
is
not
None
:
return
tir
.
call_intrin
(
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.copy"
),
src
,
dst
,
coalesced_width
)
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.copy"
),
src
,
dst
,
coalesced_width
)
else
:
else
:
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.copy"
),
src
,
dst
)
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.copy"
),
src
,
dst
)
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment