Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
bbbf4207
Unverified
Commit
bbbf4207
authored
Nov 14, 2025
by
guchaoyang
Committed by
GitHub
Nov 14, 2025
Browse files
Merge branch 'main' into dcu
parents
8f4628e0
5eb30a4f
Changes
286
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
758 additions
and
161 deletions
+758
-161
src/transform/atomicadd_vectorize.cc
src/transform/atomicadd_vectorize.cc
+1
-2
src/transform/cluster_planning.cc
src/transform/cluster_planning.cc
+7
-4
src/transform/common/loop_parallel_transform_utils.h
src/transform/common/loop_parallel_transform_utils.h
+2
-2
src/transform/common/loop_vectorization_utils.h
src/transform/common/loop_vectorization_utils.h
+28
-28
src/transform/config_index_bitwidth.cc
src/transform/config_index_bitwidth.cc
+7
-7
src/transform/eliminate_storage_sync_for_mbarrier.cc
src/transform/eliminate_storage_sync_for_mbarrier.cc
+4
-6
src/transform/flatten_buffer.cc
src/transform/flatten_buffer.cc
+7
-7
src/transform/frontend_legalize.cc
src/transform/frontend_legalize.cc
+2
-2
src/transform/if_stmt_binding.cc
src/transform/if_stmt_binding.cc
+3
-3
src/transform/inject_assumes.cc
src/transform/inject_assumes.cc
+2
-2
src/transform/inject_fence_proxy.cc
src/transform/inject_fence_proxy.cc
+4
-3
src/transform/inject_pipeline.cc
src/transform/inject_pipeline.cc
+52
-16
src/transform/inject_ptx_async_copy.cc
src/transform/inject_ptx_async_copy.cc
+2
-2
src/transform/inject_tma_barrier.cc
src/transform/inject_tma_barrier.cc
+100
-27
src/transform/layout_inference.cc
src/transform/layout_inference.cc
+264
-28
src/transform/layout_reducer.cc
src/transform/layout_reducer.cc
+31
-13
src/transform/layout_reducer.h
src/transform/layout_reducer.h
+4
-4
src/transform/legalize_negative_index.cc
src/transform/legalize_negative_index.cc
+233
-0
src/transform/legalize_safe_memory_access.cc
src/transform/legalize_safe_memory_access.cc
+3
-3
src/transform/legalize_vectorized_loop.cc
src/transform/legalize_vectorized_loop.cc
+2
-2
No files found.
src/transform/atomicadd_vectorize.cc
View file @
bbbf4207
...
@@ -249,7 +249,6 @@ private:
...
@@ -249,7 +249,6 @@ private:
new_args
.
push_back
(
dst_node
);
new_args
.
push_back
(
dst_node
);
new_args
.
push_back
(
value_node
);
new_args
.
push_back
(
value_node
);
}
}
new_args
.
push_back
(
memory_order
);
new_args
.
push_back
(
memory_order
);
Call
new_call
=
Call
new_call
=
...
@@ -284,4 +283,4 @@ For VectorizeAtomicAdd(const For &for_node, int compute_capability) {
...
@@ -284,4 +283,4 @@ For VectorizeAtomicAdd(const For &for_node, int compute_capability) {
}
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
\ No newline at end of file
src/transform/cluster_planning.cc
View file @
bbbf4207
...
@@ -10,6 +10,8 @@
...
@@ -10,6 +10,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/transform.h>
#include "../support/ffi_aliases.h"
namespace
tvm
{
namespace
tvm
{
namespace
tir
{
namespace
tir
{
...
@@ -66,7 +68,8 @@ public:
...
@@ -66,7 +68,8 @@ public:
}
}
if
(
mem_reuse_max
>
0
)
{
if
(
mem_reuse_max
>
0
)
{
std
::
string
tag_str
=
cluster_tag
;
// Convert to std::string
std
::
string
tag_str
=
static_cast
<
std
::
string
>
(
cluster_tag
);
// Convert to std::string
if
(
tag_str
.
rfind
(
"blockIdx"
,
0
)
==
0
)
{
if
(
tag_str
.
rfind
(
"blockIdx"
,
0
)
==
0
)
{
// starts with "blockIdx"
// starts with "blockIdx"
tag_str
=
"clusterIdx"
+
tag_str
.
substr
(
strlen
(
"blockIdx"
));
tag_str
=
"clusterIdx"
+
tag_str
.
substr
(
strlen
(
"blockIdx"
));
...
@@ -74,7 +77,7 @@ public:
...
@@ -74,7 +77,7 @@ public:
// Unexpected format — maybe just prefix
// Unexpected format — maybe just prefix
tag_str
=
"clusterIdx"
+
tag_str
;
tag_str
=
"clusterIdx"
+
tag_str
;
}
}
cluster_tag
=
tvm
::
ffi
::
String
(
tag_str
);
// Convert back
cluster_tag
=
String
(
tag_str
);
// Convert back
return
WithAttr
(
f
,
cluster_tag
,
Integer
(
cluster_size_
));
return
WithAttr
(
f
,
cluster_tag
,
Integer
(
cluster_size_
));
}
else
{
}
else
{
return
f
;
return
f
;
...
@@ -122,10 +125,10 @@ tvm::transform::Pass ClusterPlanning() {
...
@@ -122,10 +125,10 @@ tvm::transform::Pass ClusterPlanning() {
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.ClusterPlanning"
,
{});
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.ClusterPlanning"
,
{});
}
}
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.ClusterPlanning"
,
ClusterPlanning
);
refl
::
GlobalDef
().
def
(
"tl.transform.ClusterPlanning"
,
ClusterPlanning
);
}
);
}
}
// namespace transform
}
// namespace transform
}
// namespace tir
}
// namespace tir
...
...
src/transform/common/loop_parallel_transform_utils.h
View file @
bbbf4207
...
@@ -41,7 +41,7 @@ public:
...
@@ -41,7 +41,7 @@ public:
return
StmtMutator
::
VisitStmt_
(
op
);
return
StmtMutator
::
VisitStmt_
(
op
);
// Collect loop variables and ranges
// Collect loop variables and ranges
auto
for_node
=
GetRef
<
For
>
(
op
);
auto
for_node
=
tvm
::
ffi
::
GetRef
<
For
>
(
op
);
Array
<
Var
>
loop_vars
;
Array
<
Var
>
loop_vars
;
Array
<
PrimExpr
>
loop_extents
;
Array
<
PrimExpr
>
loop_extents
;
Stmt
body
=
op
->
body
;
Stmt
body
=
op
->
body
;
...
@@ -81,7 +81,7 @@ public:
...
@@ -81,7 +81,7 @@ public:
// post order visit the index
// post order visit the index
PostOrderVisit
(
index
,
[
&
](
const
ObjectRef
&
obj
)
{
PostOrderVisit
(
index
,
[
&
](
const
ObjectRef
&
obj
)
{
if
(
const
VarNode
*
v
=
obj
.
as
<
VarNode
>
())
{
if
(
const
VarNode
*
v
=
obj
.
as
<
VarNode
>
())
{
used_vars
.
insert
(
GetRef
<
Var
>
(
v
));
used_vars
.
insert
(
tvm
::
ffi
::
GetRef
<
Var
>
(
v
));
}
}
});
});
if
(
used_vars
.
empty
())
{
if
(
used_vars
.
empty
())
{
...
...
src/transform/common/loop_vectorization_utils.h
View file @
bbbf4207
...
@@ -211,7 +211,7 @@ public:
...
@@ -211,7 +211,7 @@ public:
PrimExpr
a
=
this
->
VisitExpr
(
op
->
a
);
PrimExpr
a
=
this
->
VisitExpr
(
op
->
a
);
PrimExpr
b
=
this
->
VisitExpr
(
op
->
b
);
PrimExpr
b
=
this
->
VisitExpr
(
op
->
b
);
if
(
a
.
same_as
(
op
->
a
)
&&
b
.
same_as
(
op
->
b
))
{
if
(
a
.
same_as
(
op
->
a
)
&&
b
.
same_as
(
op
->
b
))
{
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
else
{
}
else
{
bool
is_vec_a
=
a
.
dtype
().
is_scalable_or_fixed_length_vector
();
bool
is_vec_a
=
a
.
dtype
().
is_scalable_or_fixed_length_vector
();
bool
is_vec_b
=
b
.
dtype
().
is_scalable_or_fixed_length_vector
();
bool
is_vec_b
=
b
.
dtype
().
is_scalable_or_fixed_length_vector
();
...
@@ -265,7 +265,7 @@ public:
...
@@ -265,7 +265,7 @@ public:
PrimExpr
VisitExpr_
(
const
NotNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
NotNode
*
op
)
final
{
PrimExpr
a
=
this
->
VisitExpr
(
op
->
a
);
PrimExpr
a
=
this
->
VisitExpr
(
op
->
a
);
if
(
a
.
same_as
(
op
->
a
))
{
if
(
a
.
same_as
(
op
->
a
))
{
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
else
{
}
else
{
return
!
(
a
);
return
!
(
a
);
}
}
...
@@ -306,10 +306,10 @@ public:
...
@@ -306,10 +306,10 @@ public:
PrimExpr
value
=
this
->
VisitExpr
(
op
->
value
);
PrimExpr
value
=
this
->
VisitExpr
(
op
->
value
);
if
(
value
.
dtype
().
is_scalable_or_fixed_length_vector
())
{
if
(
value
.
dtype
().
is_scalable_or_fixed_length_vector
())
{
need_scalarize_
=
true
;
need_scalarize_
=
true
;
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
}
if
(
value
.
same_as
(
op
->
value
))
{
if
(
value
.
same_as
(
op
->
value
))
{
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
else
{
}
else
{
return
Broadcast
(
op
->
value
,
op
->
lanes
);
return
Broadcast
(
op
->
value
,
op
->
lanes
);
}
}
...
@@ -321,7 +321,7 @@ public:
...
@@ -321,7 +321,7 @@ public:
PrimExpr
f
=
this
->
VisitExpr
(
op
->
false_value
);
PrimExpr
f
=
this
->
VisitExpr
(
op
->
false_value
);
if
(
cond
.
same_as
(
op
->
condition
)
&&
t
.
same_as
(
op
->
true_value
)
&&
if
(
cond
.
same_as
(
op
->
condition
)
&&
t
.
same_as
(
op
->
true_value
)
&&
f
.
same_as
(
op
->
false_value
))
{
f
.
same_as
(
op
->
false_value
))
{
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
else
{
}
else
{
int
cond_lanes
=
cond
.
dtype
().
get_lanes_or_vscale_factor
();
int
cond_lanes
=
cond
.
dtype
().
get_lanes_or_vscale_factor
();
int
t_lanes
=
t
.
dtype
().
get_lanes_or_vscale_factor
();
int
t_lanes
=
t
.
dtype
().
get_lanes_or_vscale_factor
();
...
@@ -339,7 +339,7 @@ public:
...
@@ -339,7 +339,7 @@ public:
PrimExpr
VisitExpr_
(
const
CastNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
CastNode
*
op
)
final
{
PrimExpr
value
=
this
->
VisitExpr
(
op
->
value
);
PrimExpr
value
=
this
->
VisitExpr
(
op
->
value
);
if
(
value
.
same_as
(
op
->
value
))
{
if
(
value
.
same_as
(
op
->
value
))
{
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
else
{
}
else
{
if
(
value
.
dtype
().
is_scalable_vector
())
{
if
(
value
.
dtype
().
is_scalable_vector
())
{
return
Cast
(
op
->
dtype
.
with_scalable_vscale_factor
(
return
Cast
(
op
->
dtype
.
with_scalable_vscale_factor
(
...
@@ -352,20 +352,20 @@ public:
...
@@ -352,20 +352,20 @@ public:
}
}
PrimExpr
VisitExpr_
(
const
FloatImmNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
FloatImmNode
*
op
)
final
{
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
}
PrimExpr
VisitExpr_
(
const
IntImmNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
IntImmNode
*
op
)
final
{
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
}
PrimExpr
VisitExpr_
(
const
StringImmNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
StringImmNode
*
op
)
final
{
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
}
// Variable
// Variable
PrimExpr
VisitExpr_
(
const
VarNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
VarNode
*
op
)
final
{
Var
var
=
GetRef
<
Var
>
(
op
);
Var
var
=
tvm
::
ffi
::
GetRef
<
Var
>
(
op
);
if
(
var
.
same_as
(
var_
))
{
if
(
var
.
same_as
(
var_
))
{
return
ramp_
;
return
ramp_
;
...
@@ -382,13 +382,13 @@ public:
...
@@ -382,13 +382,13 @@ public:
PrimExpr
cond
=
this
->
VisitExpr
(
op
->
args
[
0
]);
PrimExpr
cond
=
this
->
VisitExpr
(
op
->
args
[
0
]);
if
(
cond
.
dtype
().
is_scalable_or_fixed_length_vector
())
{
if
(
cond
.
dtype
().
is_scalable_or_fixed_length_vector
())
{
need_scalarize_
=
true
;
need_scalarize_
=
true
;
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
}
PrimExpr
t
=
this
->
VisitExpr
(
op
->
args
[
1
]);
PrimExpr
t
=
this
->
VisitExpr
(
op
->
args
[
1
]);
PrimExpr
f
=
this
->
VisitExpr
(
op
->
args
[
2
]);
PrimExpr
f
=
this
->
VisitExpr
(
op
->
args
[
2
]);
if
(
cond
.
same_as
(
op
->
args
[
0
])
&&
t
.
same_as
(
op
->
args
[
1
])
&&
if
(
cond
.
same_as
(
op
->
args
[
0
])
&&
t
.
same_as
(
op
->
args
[
1
])
&&
f
.
same_as
(
op
->
args
[
2
]))
{
f
.
same_as
(
op
->
args
[
2
]))
{
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
else
{
}
else
{
int
t_lanes
=
t
.
dtype
().
get_lanes_or_vscale_factor
();
int
t_lanes
=
t
.
dtype
().
get_lanes_or_vscale_factor
();
int
f_lanes
=
f
.
dtype
().
get_lanes_or_vscale_factor
();
int
f_lanes
=
f
.
dtype
().
get_lanes_or_vscale_factor
();
...
@@ -410,7 +410,7 @@ public:
...
@@ -410,7 +410,7 @@ public:
ICHECK
(
op
->
op
.
same_as
(
builtin
::
reinterpret
()));
ICHECK
(
op
->
op
.
same_as
(
builtin
::
reinterpret
()));
PrimExpr
value
=
this
->
VisitExpr
(
op
->
args
[
0
]);
PrimExpr
value
=
this
->
VisitExpr
(
op
->
args
[
0
]);
if
(
value
.
same_as
(
op
->
args
[
0
]))
{
if
(
value
.
same_as
(
op
->
args
[
0
]))
{
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
else
{
}
else
{
int
lanes
=
value
.
dtype
().
get_lanes_or_vscale_factor
();
int
lanes
=
value
.
dtype
().
get_lanes_or_vscale_factor
();
if
(
value
.
dtype
().
is_scalable_vector
())
{
if
(
value
.
dtype
().
is_scalable_vector
())
{
...
@@ -455,12 +455,12 @@ public:
...
@@ -455,12 +455,12 @@ public:
auto
new_arg
=
this
->
VisitExpr
(
arg
);
auto
new_arg
=
this
->
VisitExpr
(
arg
);
if
(
new_arg
.
dtype
().
is_scalable_or_fixed_length_vector
())
{
if
(
new_arg
.
dtype
().
is_scalable_or_fixed_length_vector
())
{
need_scalarize_
=
true
;
need_scalarize_
=
true
;
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
}
new_args
.
push_back
(
new_arg
);
new_args
.
push_back
(
new_arg
);
}
}
if
(
op
->
args
.
same_as
(
new_args
))
{
if
(
op
->
args
.
same_as
(
new_args
))
{
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
else
{
}
else
{
return
Call
(
op
->
dtype
,
op
->
op
,
new_args
);
return
Call
(
op
->
dtype
,
op
->
op
,
new_args
);
}
}
...
@@ -469,7 +469,7 @@ public:
...
@@ -469,7 +469,7 @@ public:
Array
<
PrimExpr
>
new_args
=
MutateArray
(
op
->
args
,
&
lane
);
Array
<
PrimExpr
>
new_args
=
MutateArray
(
op
->
args
,
&
lane
);
// normal code path.
// normal code path.
if
(
op
->
args
.
same_as
(
new_args
))
{
if
(
op
->
args
.
same_as
(
new_args
))
{
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
else
{
}
else
{
return
Call
(
op
->
dtype
.
with_lanes
(
lane
),
op
->
op
,
new_args
);
return
Call
(
op
->
dtype
.
with_lanes
(
lane
),
op
->
op
,
new_args
);
}
}
...
@@ -477,7 +477,7 @@ public:
...
@@ -477,7 +477,7 @@ public:
}
}
// BufferLoad
// BufferLoad
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
auto
load
=
GetRef
<
BufferLoad
>
(
op
);
auto
load
=
tvm
::
ffi
::
GetRef
<
BufferLoad
>
(
op
);
auto
fmutate
=
[
this
](
const
PrimExpr
&
index
)
{
auto
fmutate
=
[
this
](
const
PrimExpr
&
index
)
{
return
this
->
VisitExpr
(
index
);
return
this
->
VisitExpr
(
index
);
...
@@ -514,7 +514,7 @@ public:
...
@@ -514,7 +514,7 @@ public:
let_binding_
[
op
->
var
]
=
op
->
var
;
let_binding_
[
op
->
var
]
=
op
->
var
;
PrimExpr
body
=
this
->
VisitExpr
(
op
->
body
);
PrimExpr
body
=
this
->
VisitExpr
(
op
->
body
);
if
(
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
if
(
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
else
{
}
else
{
return
Let
(
op
->
var
,
value
,
body
);
return
Let
(
op
->
var
,
value
,
body
);
}
}
...
@@ -522,7 +522,7 @@ public:
...
@@ -522,7 +522,7 @@ public:
}
}
// BufferStore
// BufferStore
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
auto
store
=
GetRef
<
BufferStore
>
(
op
);
auto
store
=
tvm
::
ffi
::
GetRef
<
BufferStore
>
(
op
);
auto
fmutate
=
[
this
](
const
PrimExpr
&
index
)
{
auto
fmutate
=
[
this
](
const
PrimExpr
&
index
)
{
return
this
->
VisitExpr
(
index
);
return
this
->
VisitExpr
(
index
);
...
@@ -585,11 +585,11 @@ public:
...
@@ -585,11 +585,11 @@ public:
ICHECK
(
!
op
->
extent
.
dtype
().
is_scalable_or_fixed_length_vector
());
ICHECK
(
!
op
->
extent
.
dtype
().
is_scalable_or_fixed_length_vector
());
PrimExpr
extent
=
this
->
VisitExpr
(
op
->
extent
);
PrimExpr
extent
=
this
->
VisitExpr
(
op
->
extent
);
if
(
extent
.
dtype
().
is_scalable_or_fixed_length_vector
())
{
if
(
extent
.
dtype
().
is_scalable_or_fixed_length_vector
())
{
return
Scalarize
(
GetRef
<
Stmt
>
(
op
));
return
Scalarize
(
tvm
::
ffi
::
GetRef
<
Stmt
>
(
op
));
}
}
Stmt
body
=
this
->
VisitStmt
(
op
->
body
);
Stmt
body
=
this
->
VisitStmt
(
op
->
body
);
if
(
extent
.
same_as
(
op
->
extent
)
&&
body
.
same_as
(
op
->
body
))
{
if
(
extent
.
same_as
(
op
->
extent
)
&&
body
.
same_as
(
op
->
body
))
{
return
GetRef
<
Stmt
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
Stmt
>
(
op
);
}
else
{
}
else
{
return
For
(
op
->
loop_var
,
op
->
min
,
extent
,
op
->
kind
,
body
,
return
For
(
op
->
loop_var
,
op
->
min
,
extent
,
op
->
kind
,
body
,
op
->
thread_binding
,
op
->
annotations
);
op
->
thread_binding
,
op
->
annotations
);
...
@@ -600,7 +600,7 @@ public:
...
@@ -600,7 +600,7 @@ public:
ICHECK
(
!
op
->
condition
.
dtype
().
is_scalable_or_fixed_length_vector
());
ICHECK
(
!
op
->
condition
.
dtype
().
is_scalable_or_fixed_length_vector
());
PrimExpr
condition
=
this
->
VisitExpr
(
op
->
condition
);
PrimExpr
condition
=
this
->
VisitExpr
(
op
->
condition
);
if
(
condition
.
dtype
().
is_scalable_or_fixed_length_vector
())
{
if
(
condition
.
dtype
().
is_scalable_or_fixed_length_vector
())
{
return
Scalarize
(
GetRef
<
Stmt
>
(
op
));
return
Scalarize
(
tvm
::
ffi
::
GetRef
<
Stmt
>
(
op
));
}
}
Stmt
then_case
=
this
->
VisitStmt
(
op
->
then_case
);
Stmt
then_case
=
this
->
VisitStmt
(
op
->
then_case
);
Optional
<
Stmt
>
else_case
=
std
::
nullopt
;
Optional
<
Stmt
>
else_case
=
std
::
nullopt
;
...
@@ -609,7 +609,7 @@ public:
...
@@ -609,7 +609,7 @@ public:
}
}
if
(
condition
.
same_as
(
op
->
condition
)
&&
then_case
.
same_as
(
op
->
then_case
)
&&
if
(
condition
.
same_as
(
op
->
condition
)
&&
then_case
.
same_as
(
op
->
then_case
)
&&
else_case
.
same_as
(
op
->
else_case
))
{
else_case
.
same_as
(
op
->
else_case
))
{
return
GetRef
<
Stmt
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
Stmt
>
(
op
);
}
else
{
}
else
{
return
IfThenElse
(
condition
,
then_case
,
else_case
);
return
IfThenElse
(
condition
,
then_case
,
else_case
);
}
}
...
@@ -634,7 +634,7 @@ public:
...
@@ -634,7 +634,7 @@ public:
let_binding_
[
op
->
var
]
=
op
->
var
;
let_binding_
[
op
->
var
]
=
op
->
var
;
Stmt
body
=
this
->
VisitStmt
(
op
->
body
);
Stmt
body
=
this
->
VisitStmt
(
op
->
body
);
if
(
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
if
(
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
return
GetRef
<
Stmt
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
Stmt
>
(
op
);
}
else
{
}
else
{
return
LetStmt
(
op
->
var
,
value
,
body
);
return
LetStmt
(
op
->
var
,
value
,
body
);
}
}
...
@@ -647,7 +647,7 @@ public:
...
@@ -647,7 +647,7 @@ public:
if
(
condition
.
dtype
().
is_scalable_or_fixed_length_vector
())
{
if
(
condition
.
dtype
().
is_scalable_or_fixed_length_vector
())
{
LOG
(
WARNING
)
<<
"Cannot handle vector extent in alloc of "
LOG
(
WARNING
)
<<
"Cannot handle vector extent in alloc of "
<<
op
->
buffer_var
->
name_hint
;
<<
op
->
buffer_var
->
name_hint
;
return
Scalarize
(
GetRef
<
Stmt
>
(
op
));
return
Scalarize
(
tvm
::
ffi
::
GetRef
<
Stmt
>
(
op
));
}
}
// Mutate the extents
// Mutate the extents
...
@@ -657,7 +657,7 @@ public:
...
@@ -657,7 +657,7 @@ public:
if
(
new_ext
.
dtype
().
is_scalable_or_fixed_length_vector
())
{
if
(
new_ext
.
dtype
().
is_scalable_or_fixed_length_vector
())
{
LOG
(
WARNING
)
<<
"Cannot handle vector extent in alloc of "
LOG
(
WARNING
)
<<
"Cannot handle vector extent in alloc of "
<<
op
->
buffer_var
->
name_hint
;
<<
op
->
buffer_var
->
name_hint
;
return
Scalarize
(
GetRef
<
Stmt
>
(
op
));
return
Scalarize
(
tvm
::
ffi
::
GetRef
<
Stmt
>
(
op
));
}
}
extents
.
push_back
(
new_ext
);
extents
.
push_back
(
new_ext
);
}
}
...
@@ -738,7 +738,7 @@ private:
...
@@ -738,7 +738,7 @@ private:
PrimExpr
a
=
this
->
VisitExpr
(
op
->
a
);
PrimExpr
a
=
this
->
VisitExpr
(
op
->
a
);
PrimExpr
b
=
this
->
VisitExpr
(
op
->
b
);
PrimExpr
b
=
this
->
VisitExpr
(
op
->
b
);
if
(
a
.
same_as
(
op
->
a
)
&&
b
.
same_as
(
op
->
b
))
{
if
(
a
.
same_as
(
op
->
a
)
&&
b
.
same_as
(
op
->
b
))
{
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
else
{
}
else
{
int
a_lanes
=
a
.
dtype
().
get_lanes_or_vscale_factor
();
int
a_lanes
=
a
.
dtype
().
get_lanes_or_vscale_factor
();
int
b_lanes
=
b
.
dtype
().
get_lanes_or_vscale_factor
();
int
b_lanes
=
b
.
dtype
().
get_lanes_or_vscale_factor
();
...
@@ -754,7 +754,7 @@ private:
...
@@ -754,7 +754,7 @@ private:
PrimExpr
a
=
this
->
VisitExpr
(
op
->
a
);
PrimExpr
a
=
this
->
VisitExpr
(
op
->
a
);
PrimExpr
b
=
this
->
VisitExpr
(
op
->
b
);
PrimExpr
b
=
this
->
VisitExpr
(
op
->
b
);
if
(
a
.
same_as
(
op
->
a
)
&&
b
.
same_as
(
op
->
b
))
{
if
(
a
.
same_as
(
op
->
a
)
&&
b
.
same_as
(
op
->
b
))
{
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
else
{
}
else
{
int
a_lanes
=
a
.
dtype
().
get_lanes_or_vscale_factor
();
int
a_lanes
=
a
.
dtype
().
get_lanes_or_vscale_factor
();
int
b_lanes
=
b
.
dtype
().
get_lanes_or_vscale_factor
();
int
b_lanes
=
b
.
dtype
().
get_lanes_or_vscale_factor
();
...
...
src/transform/config_index_bitwidth.cc
View file @
bbbf4207
...
@@ -38,7 +38,7 @@ protected:
...
@@ -38,7 +38,7 @@ protected:
if
(
is_enabled_
&&
op
->
dtype
.
is_int
()
&&
op
->
dtype
.
bits
()
<
64
)
{
if
(
is_enabled_
&&
op
->
dtype
.
is_int
()
&&
op
->
dtype
.
bits
()
<
64
)
{
return
IntImm
(
DataType
::
Int
(
_index_bitwidth_
),
op
->
value
);
return
IntImm
(
DataType
::
Int
(
_index_bitwidth_
),
op
->
value
);
}
}
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
}
PrimExpr
VisitExpr_
(
const
CastNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
CastNode
*
op
)
final
{
...
@@ -88,23 +88,23 @@ private:
...
@@ -88,23 +88,23 @@ private:
PrimExpr
VisitExpr_
(
const
VarNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
VarNode
*
op
)
final
{
if
(
op
->
dtype
.
is_int
()
&&
op
->
dtype
.
bits
()
<
64
)
{
if
(
op
->
dtype
.
is_int
()
&&
op
->
dtype
.
bits
()
<
64
)
{
return
cast
(
DataType
::
Int
(
64
),
GetRef
<
Var
>
(
op
));
return
cast
(
DataType
::
Int
(
64
),
tvm
::
ffi
::
GetRef
<
Var
>
(
op
));
}
}
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
}
PrimExpr
VisitExpr_
(
const
IntImmNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
IntImmNode
*
op
)
final
{
if
(
op
->
dtype
.
is_int
()
&&
op
->
dtype
.
bits
()
<
64
)
{
if
(
op
->
dtype
.
is_int
()
&&
op
->
dtype
.
bits
()
<
64
)
{
return
IntImm
(
DataType
::
Int
(
64
),
op
->
value
);
return
IntImm
(
DataType
::
Int
(
64
),
op
->
value
);
}
}
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
}
PrimExpr
VisitExpr_
(
const
CastNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
CastNode
*
op
)
final
{
if
(
op
->
dtype
.
is_int
()
&&
op
->
dtype
.
bits
()
<
64
)
{
if
(
op
->
dtype
.
is_int
()
&&
op
->
dtype
.
bits
()
<
64
)
{
return
cast
(
DataType
::
Int
(
64
),
op
->
value
);
return
cast
(
DataType
::
Int
(
64
),
op
->
value
);
}
}
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
}
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
...
@@ -183,11 +183,11 @@ tvm::transform::Pass ConfigIndexBitwidth() {
...
@@ -183,11 +183,11 @@ tvm::transform::Pass ConfigIndexBitwidth() {
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.ConfigIndexBitwidth"
,
{});
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.ConfigIndexBitwidth"
,
{});
}
}
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.ConfigIndexBitwidth"
,
refl
::
GlobalDef
().
def
(
"tl.transform.ConfigIndexBitwidth"
,
ConfigIndexBitwidth
);
ConfigIndexBitwidth
);
}
);
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/transform/eliminate_storage_sync_for_mbarrier.cc
View file @
bbbf4207
...
@@ -35,9 +35,7 @@ public:
...
@@ -35,9 +35,7 @@ public:
Stmt
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
if
(
op
->
attr_key
==
"thread_extent"
)
{
if
(
op
->
attr_key
==
"thread_extent"
)
{
const
VarNode
*
var
=
nullptr
;
if
(
const
auto
*
var
=
op
->
node
.
as
<
VarNode
>
())
{
if
(
op
->
node
->
IsInstance
<
VarNode
>
())
{
var
=
op
->
node
.
as
<
VarNode
>
();
if
(
var
->
name_hint
==
"threadIdx.x"
)
{
if
(
var
->
name_hint
==
"threadIdx.x"
)
{
thread_extent_
=
op
;
thread_extent_
=
op
;
}
}
...
@@ -82,7 +80,7 @@ public:
...
@@ -82,7 +80,7 @@ public:
}
}
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
ForNode
*
op
)
final
{
PostOrderVisit
(
GetRef
<
For
>
(
op
),
[
&
](
const
ObjectRef
&
node
)
{
PostOrderVisit
(
tvm
::
ffi
::
GetRef
<
For
>
(
op
),
[
&
](
const
ObjectRef
&
node
)
{
if
(
const
auto
*
call
=
node
.
as
<
CallNode
>
())
{
if
(
const
auto
*
call
=
node
.
as
<
CallNode
>
())
{
if
(
call
->
op
.
same_as
(
create_list_of_mbarrier
())
||
if
(
call
->
op
.
same_as
(
create_list_of_mbarrier
())
||
call
->
op
.
same_as
(
mbarrier_wait_parity
())
||
call
->
op
.
same_as
(
mbarrier_wait_parity
())
||
...
@@ -116,11 +114,11 @@ tvm::transform::Pass EliminateStorageSyncForMBarrier() {
...
@@ -116,11 +114,11 @@ tvm::transform::Pass EliminateStorageSyncForMBarrier() {
{});
{});
}
}
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.EliminateStorageSyncForMBarrier"
,
refl
::
GlobalDef
().
def
(
"tl.transform.EliminateStorageSyncForMBarrier"
,
EliminateStorageSyncForMBarrier
);
EliminateStorageSyncForMBarrier
);
}
);
}
}
// namespace transform
}
// namespace transform
}
// namespace tl
}
// namespace tl
...
...
src/transform/flatten_buffer.cc
View file @
bbbf4207
...
@@ -75,23 +75,23 @@ private:
...
@@ -75,23 +75,23 @@ private:
PrimExpr
VisitExpr_
(
const
VarNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
VarNode
*
op
)
final
{
if
(
op
->
dtype
.
is_int
()
&&
op
->
dtype
.
bits
()
<
64
)
{
if
(
op
->
dtype
.
is_int
()
&&
op
->
dtype
.
bits
()
<
64
)
{
return
cast
(
DataType
::
Int
(
64
),
GetRef
<
Var
>
(
op
));
return
cast
(
DataType
::
Int
(
64
),
tvm
::
ffi
::
GetRef
<
Var
>
(
op
));
}
}
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
}
PrimExpr
VisitExpr_
(
const
IntImmNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
IntImmNode
*
op
)
final
{
if
(
op
->
dtype
.
is_int
()
&&
op
->
dtype
.
bits
()
<
64
)
{
if
(
op
->
dtype
.
is_int
()
&&
op
->
dtype
.
bits
()
<
64
)
{
return
IntImm
(
DataType
::
Int
(
64
),
op
->
value
);
return
IntImm
(
DataType
::
Int
(
64
),
op
->
value
);
}
}
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
}
PrimExpr
VisitExpr_
(
const
CastNode
*
op
)
final
{
PrimExpr
VisitExpr_
(
const
CastNode
*
op
)
final
{
if
(
op
->
dtype
.
is_int
()
&&
op
->
dtype
.
bits
()
<
64
)
{
if
(
op
->
dtype
.
is_int
()
&&
op
->
dtype
.
bits
()
<
64
)
{
return
cast
(
DataType
::
Int
(
64
),
op
->
value
);
return
cast
(
DataType
::
Int
(
64
),
op
->
value
);
}
}
return
GetRef
<
PrimExpr
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
PrimExpr
>
(
op
);
}
}
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
Stmt
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
...
@@ -115,7 +115,7 @@ private:
...
@@ -115,7 +115,7 @@ private:
<<
"All MatchBufferRegion should be removed in "
<<
"All MatchBufferRegion should be removed in "
"tir.transform.LowerMatchBuffer."
;
"tir.transform.LowerMatchBuffer."
;
Block
block
=
GetRef
<
Block
>
(
op
);
Block
block
=
tvm
::
ffi
::
GetRef
<
Block
>
(
op
);
Array
<
Buffer
>
alloc_buffers
=
op
->
alloc_buffers
;
Array
<
Buffer
>
alloc_buffers
=
op
->
alloc_buffers
;
alloc_buffers
.
MutateByApply
(
alloc_buffers
.
MutateByApply
(
...
@@ -385,10 +385,10 @@ tvm::transform::Pass FlattenBuffer() {
...
@@ -385,10 +385,10 @@ tvm::transform::Pass FlattenBuffer() {
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.FlattenBuffer"
,
{});
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.FlattenBuffer"
,
{});
}
}
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.FlattenBuffer"
,
FlattenBuffer
);
refl
::
GlobalDef
().
def
(
"tl.transform.FlattenBuffer"
,
FlattenBuffer
);
}
);
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/transform/frontend_legalize.cc
View file @
bbbf4207
...
@@ -89,10 +89,10 @@ Pass LetInline() {
...
@@ -89,10 +89,10 @@ Pass LetInline() {
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.LetInline"
,
{});
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.LetInline"
,
{});
}
}
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.LetInline"
,
LetInline
);
refl
::
GlobalDef
().
def
(
"tl.transform.LetInline"
,
LetInline
);
}
);
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/transform/if_stmt_binding.cc
View file @
bbbf4207
...
@@ -33,7 +33,7 @@ private:
...
@@ -33,7 +33,7 @@ private:
auto
then_case
=
VisitStmt
(
op
->
then_case
);
auto
then_case
=
VisitStmt
(
op
->
then_case
);
Optional
<
Stmt
>
else_case
=
op
->
else_case
;
Optional
<
Stmt
>
else_case
=
op
->
else_case
;
if
(
else_case
.
defined
())
{
if
(
else_case
.
defined
())
{
return
GetRef
<
Stmt
>
(
op
);
return
tvm
::
ffi
::
GetRef
<
Stmt
>
(
op
);
}
}
ICHECK
(
then_case
.
defined
())
<<
"then_case must be defined"
;
ICHECK
(
then_case
.
defined
())
<<
"then_case must be defined"
;
ICHECK
(
!
else_case
.
defined
())
<<
"else_case must be undefined"
;
ICHECK
(
!
else_case
.
defined
())
<<
"else_case must be undefined"
;
...
@@ -81,10 +81,10 @@ tvm::transform::Pass IfStmtBinding() {
...
@@ -81,10 +81,10 @@ tvm::transform::Pass IfStmtBinding() {
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.IfStmtBinding"
,
{});
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.IfStmtBinding"
,
{});
}
}
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.IfStmtBinding"
,
IfStmtBinding
);
refl
::
GlobalDef
().
def
(
"tl.transform.IfStmtBinding"
,
IfStmtBinding
);
}
);
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/transform/inject_assumes.cc
View file @
bbbf4207
...
@@ -156,9 +156,9 @@ tvm::transform::Pass InjectAssumes() {
...
@@ -156,9 +156,9 @@ tvm::transform::Pass InjectAssumes() {
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.InjectAssumes"
,
{});
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.InjectAssumes"
,
{});
}
}
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.InjectAssumes"
,
InjectAssumes
);
refl
::
GlobalDef
().
def
(
"tl.transform.InjectAssumes"
,
InjectAssumes
);
}
);
}
}
// namespace tvm::tl
}
// namespace tvm::tl
src/transform/inject_fence_proxy.cc
View file @
bbbf4207
...
@@ -108,7 +108,8 @@ bool IsKnownGeneric(const CallNode *call) {
...
@@ -108,7 +108,8 @@ bool IsKnownGeneric(const CallNode *call) {
return
false
;
return
false
;
}
}
return
call
->
op
.
same_as
(
ptx_ldmatrix
())
||
call
->
op
.
same_as
(
ptx_stmatrix
())
||
return
call
->
op
.
same_as
(
ptx_ldmatrix
())
||
call
->
op
.
same_as
(
ptx_stmatrix
())
||
call
->
op
.
same_as
(
initialize_descriptor
());
call
->
op
.
same_as
(
initialize_wgmma_descriptor
())
||
call
->
op
.
same_as
(
initialize_tcgen05_descriptor
());
}
}
ProxyKind
ProxyFromAttrValue
(
const
ObjectRef
&
value
)
{
ProxyKind
ProxyFromAttrValue
(
const
ObjectRef
&
value
)
{
...
@@ -319,10 +320,10 @@ tvm::transform::Pass InjectFenceProxy() {
...
@@ -319,10 +320,10 @@ tvm::transform::Pass InjectFenceProxy() {
{});
{});
}
}
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.InjectFenceProxy"
,
InjectFenceProxy
);
refl
::
GlobalDef
().
def
(
"tl.transform.InjectFenceProxy"
,
InjectFenceProxy
);
}
);
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/transform/inject_pipeline.cc
View file @
bbbf4207
...
@@ -37,9 +37,14 @@
...
@@ -37,9 +37,14 @@
namespace
tvm
{
namespace
tvm
{
namespace
tl
{
namespace
tl
{
using
namespace
tir
;
using
namespace
tir
;
using
namespace
ffi
;
namespace
software_pipeline
{
namespace
software_pipeline
{
struct
LetWrapper
{
Var
var
;
PrimExpr
value
;
};
/*!
/*!
* \brief Create a block and infer the access region with the given body.
* \brief Create a block and infer the access region with the given body.
*
*
...
@@ -233,10 +238,12 @@ class PipelineRewriter : public StmtExprMutator {
...
@@ -233,10 +238,12 @@ class PipelineRewriter : public StmtExprMutator {
public:
public:
PipelineRewriter
(
Map
<
Var
,
Buffer
>
buffer_data_to_buffer
,
PipelineRewriter
(
Map
<
Var
,
Buffer
>
buffer_data_to_buffer
,
const
Array
<
Buffer
>
&
pipeline_allocs
,
const
Array
<
Buffer
>
&
pipeline_allocs
,
const
For
&
pipeline_loop
,
const
PipelineInfo
&
pipeline_info
)
const
For
&
pipeline_loop
,
const
PipelineInfo
&
pipeline_info
,
const
std
::
vector
<
LetWrapper
>
&
loop_var_let_wrappers
)
:
buffer_data_to_buffer_
(
std
::
move
(
buffer_data_to_buffer
)),
:
buffer_data_to_buffer_
(
std
::
move
(
buffer_data_to_buffer
)),
pipeline_allocs_
(
pipeline_allocs
),
pipeline_loop_
(
pipeline_loop
),
pipeline_allocs_
(
pipeline_allocs
),
pipeline_loop_
(
pipeline_loop
),
pipeline_info_
(
pipeline_info
)
{}
pipeline_info_
(
pipeline_info
),
loop_var_let_wrappers_
(
loop_var_let_wrappers
)
{}
Stmt
BuildPipeline
()
{
Stmt
BuildPipeline
()
{
// Step 1: Analyze accesses to the buffers in the pipeline and compute the
// Step 1: Analyze accesses to the buffers in the pipeline and compute the
...
@@ -459,7 +466,8 @@ private:
...
@@ -459,7 +466,8 @@ private:
* \return The resized buffer.
* \return The resized buffer.
*/
*/
Buffer
RewriteAllocBuffer
(
const
Buffer
&
buffer
,
int
num_versions
)
{
Buffer
RewriteAllocBuffer
(
const
Buffer
&
buffer
,
int
num_versions
)
{
ObjectPtr
<
BufferNode
>
new_buffer
=
make_object
<
BufferNode
>
(
*
(
buffer
.
get
()));
ObjectPtr
<
BufferNode
>
new_buffer
=
tvm
::
ffi
::
make_object
<
BufferNode
>
(
*
(
buffer
.
get
()));
new_buffer
->
shape
.
insert
(
new_buffer
->
shape
.
begin
(),
PrimExpr
(
num_versions
));
new_buffer
->
shape
.
insert
(
new_buffer
->
shape
.
begin
(),
PrimExpr
(
num_versions
));
if
(
!
new_buffer
->
strides
.
empty
())
{
if
(
!
new_buffer
->
strides
.
empty
())
{
ICHECK
(
new_buffer
->
strides
.
size
()
+
1
==
new_buffer
->
shape
.
size
());
ICHECK
(
new_buffer
->
strides
.
size
()
+
1
==
new_buffer
->
shape
.
size
());
...
@@ -676,6 +684,20 @@ private:
...
@@ -676,6 +684,20 @@ private:
new_block
=
Downcast
<
Block
>
(
Substitute
(
new_block
=
Downcast
<
Block
>
(
Substitute
(
new_block
,
{{
pipeline_loop_
->
loop_var
,
normalized_access_index
}}));
new_block
,
{{
pipeline_loop_
->
loop_var
,
normalized_access_index
}}));
// If there were Let-wrappers outside the original pipeline body that
// depended on the pipeline loop var, push them into each rewritten
// block with the correct per-block substitution.
if
(
!
loop_var_let_wrappers_
.
empty
())
{
BlockNode
*
n
=
new_block
.
CopyOnWrite
();
Stmt
inner
=
n
->
body
;
for
(
const
auto
&
lw
:
loop_var_let_wrappers_
)
{
PrimExpr
substituted
=
Substitute
(
lw
.
value
,
{{
pipeline_loop_
->
loop_var
,
normalized_access_index
}});
inner
=
LetStmt
(
lw
.
var
,
substituted
,
inner
);
}
n
->
body
=
inner
;
}
if
(
pipeline_info_
[
block
].
async
)
{
if
(
pipeline_info_
[
block
].
async
)
{
auto
&
local_state
=
async_states_local
[
stage
];
auto
&
local_state
=
async_states_local
[
stage
];
local_state
.
producer_head
=
normalized_access_index
;
local_state
.
producer_head
=
normalized_access_index
;
...
@@ -737,6 +759,7 @@ private:
...
@@ -737,6 +759,7 @@ private:
Map
<
Buffer
,
Buffer
>
buffer_remap_
;
Map
<
Buffer
,
Buffer
>
buffer_remap_
;
Array
<
Block
>
ordered_stmts_
;
Array
<
Block
>
ordered_stmts_
;
std
::
map
<
int
,
AsyncStateGlobal
>
async_states
;
std
::
map
<
int
,
AsyncStateGlobal
>
async_states
;
std
::
vector
<
LetWrapper
>
loop_var_let_wrappers_
;
};
};
/*!
/*!
...
@@ -864,8 +887,9 @@ private:
...
@@ -864,8 +887,9 @@ private:
const
SeqStmtNode
*
pipeline_body_seq
=
nullptr
;
const
SeqStmtNode
*
pipeline_body_seq
=
nullptr
;
std
::
vector
<
std
::
function
<
Stmt
(
Stmt
)
>>
rewrap_fns
;
std
::
vector
<
std
::
function
<
Stmt
(
Stmt
)
>>
rewrap_fns
;
std
::
vector
<
LetWrapper
>
loop_var_let_wrappers
;
auto
append_attr_wrapper
=
[
&
rewrap_fns
](
const
AttrStmtNode
*
attr
)
{
auto
append_attr_wrapper
=
[
&
rewrap_fns
](
const
AttrStmtNode
*
attr
)
{
ObjectRef
node
=
attr
->
node
;
Any
node
=
attr
->
node
;
String
attr_key
=
attr
->
attr_key
;
String
attr_key
=
attr
->
attr_key
;
PrimExpr
value
=
attr
->
value
;
PrimExpr
value
=
attr
->
value
;
Span
span
=
attr
->
span
;
Span
span
=
attr
->
span
;
...
@@ -896,14 +920,25 @@ private:
...
@@ -896,14 +920,25 @@ private:
continue
;
continue
;
}
}
if
(
const
auto
*
let_stmt
=
current
.
as
<
LetStmtNode
>
())
{
if
(
const
auto
*
let_stmt
=
current
.
as
<
LetStmtNode
>
())
{
Var
var
=
let_stmt
->
var
;
// If this Let value uses the pipeline loop var, record it and push
PrimExpr
value
=
let_stmt
->
value
;
// inside each rewritten block later so the loop var can be
Span
span
=
let_stmt
->
span
;
// substituted with the correct per-iteration index. Otherwise, keep
rewrap_fns
.
emplace_back
([
var
=
std
::
move
(
var
),
// it as a normal wrapper.
value
=
std
::
move
(
value
),
bool
uses_loop_var
=
UsesVar
(
span
](
Stmt
body
)
->
Stmt
{
let_stmt
->
value
,
return
LetStmt
(
var
,
value
,
body
,
span
);
[
v
=
op
->
loop_var
.
get
()](
const
VarNode
*
vn
)
{
return
vn
==
v
;
});
});
if
(
uses_loop_var
)
{
loop_var_let_wrappers
.
push_back
({
let_stmt
->
var
,
let_stmt
->
value
});
}
else
{
Var
var
=
let_stmt
->
var
;
PrimExpr
value
=
let_stmt
->
value
;
Span
span
=
let_stmt
->
span
;
rewrap_fns
.
emplace_back
([
var
=
std
::
move
(
var
),
value
=
std
::
move
(
value
),
span
](
Stmt
body
)
->
Stmt
{
return
LetStmt
(
var
,
value
,
body
,
span
);
});
}
current
=
let_stmt
->
body
;
current
=
let_stmt
->
body
;
continue
;
continue
;
}
}
...
@@ -981,7 +1016,8 @@ private:
...
@@ -981,7 +1016,8 @@ private:
// Step 4: Rewrite the pipeline body.
// Step 4: Rewrite the pipeline body.
Stmt
pipeline
=
PipelineRewriter
(
buffer_data_to_buffer_
,
pipeline_allocs
,
Stmt
pipeline
=
PipelineRewriter
(
buffer_data_to_buffer_
,
pipeline_allocs
,
GetRef
<
For
>
(
op
),
pipeline_info
)
tvm
::
ffi
::
GetRef
<
For
>
(
op
),
pipeline_info
,
loop_var_let_wrappers
)
.
BuildPipeline
();
.
BuildPipeline
();
auto
apply_wrappers
=
[
&
](
Stmt
stmt
)
{
auto
apply_wrappers
=
[
&
](
Stmt
stmt
)
{
for
(
auto
it
=
rewrap_fns
.
rbegin
();
it
!=
rewrap_fns
.
rend
();
++
it
)
{
for
(
auto
it
=
rewrap_fns
.
rbegin
();
it
!=
rewrap_fns
.
rend
();
++
it
)
{
...
@@ -1072,11 +1108,11 @@ tir::transform::Pass InjectSoftwarePipeline() {
...
@@ -1072,11 +1108,11 @@ tir::transform::Pass InjectSoftwarePipeline() {
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.InjectSoftwarePipeline"
,
{});
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.InjectSoftwarePipeline"
,
{});
}
}
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.InjectSoftwarePipeline"
,
refl
::
GlobalDef
().
def
(
"tl.transform.InjectSoftwarePipeline"
,
InjectSoftwarePipeline
);
InjectSoftwarePipeline
);
}
);
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/transform/inject_ptx_async_copy.cc
View file @
bbbf4207
...
@@ -232,10 +232,10 @@ tvm::transform::Pass InjectPTXAsyncCopy() {
...
@@ -232,10 +232,10 @@ tvm::transform::Pass InjectPTXAsyncCopy() {
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.InjectPTXAsyncCopy"
,
{});
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.InjectPTXAsyncCopy"
,
{});
}
}
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.InjectPTXAsyncCopy"
,
InjectPTXAsyncCopy
);
refl
::
GlobalDef
().
def
(
"tl.transform.InjectPTXAsyncCopy"
,
InjectPTXAsyncCopy
);
}
);
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/transform/inject_tma_barrier.cc
View file @
bbbf4207
...
@@ -204,9 +204,9 @@ private:
...
@@ -204,9 +204,9 @@ private:
void
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
void
VisitStmt_
(
const
EvaluateNode
*
op
)
final
{
if
(
const
auto
*
call
=
op
->
value
.
as
<
CallNode
>
())
{
if
(
const
auto
*
call
=
op
->
value
.
as
<
CallNode
>
())
{
if
(
call
->
op
.
same_as
(
tma_load
())
||
call
->
op
.
same_as
(
tma_load_im2col
()))
{
if
(
call
->
op
.
same_as
(
tma_load
())
||
call
->
op
.
same_as
(
tma_load_im2col
()))
{
pending_tma_ops_
.
push_back
(
GetRef
<
Call
>
(
call
));
pending_tma_ops_
.
push_back
(
tvm
::
ffi
::
GetRef
<
Call
>
(
call
));
}
else
if
(
call
->
op
.
same_as
(
mbarrier_expect_tx
()))
{
}
else
if
(
call
->
op
.
same_as
(
mbarrier_expect_tx
()))
{
pending_tma_ops_
.
push_back
(
GetRef
<
Call
>
(
call
));
pending_tma_ops_
.
push_back
(
tvm
::
ffi
::
GetRef
<
Call
>
(
call
));
}
else
if
(
call
->
op
.
same_as
(
builtin
::
ptx_arrive_barrier
()))
{
}
else
if
(
call
->
op
.
same_as
(
builtin
::
ptx_arrive_barrier
()))
{
PrimExpr
barrier_id
=
call
->
args
[
0
];
PrimExpr
barrier_id
=
call
->
args
[
0
];
for
(
const
auto
&
tma_call
:
pending_tma_ops_
)
{
for
(
const
auto
&
tma_call
:
pending_tma_ops_
)
{
...
@@ -295,13 +295,15 @@ public:
...
@@ -295,13 +295,15 @@ public:
void
VisitExpr_
(
const
CallNode
*
op
)
final
{
void
VisitExpr_
(
const
CallNode
*
op
)
final
{
if
(
op
->
op
.
same_as
(
mbarrier_expect_tx
()))
{
if
(
op
->
op
.
same_as
(
mbarrier_expect_tx
()))
{
PrimExpr
e
=
auto
call_ref
=
tvm
::
ffi
::
GetRef
<
Call
>
(
op
);
tma_op_to_barrier_id_
[
GetRef
<
Call
>
(
op
)].
as
<
CallNode
>
()
->
args
[
0
];
if
(
tma_op_to_barrier_id_
.
count
(
call_ref
))
{
auto
int_set
=
arith
::
EvalSet
(
e
,
var_int_set_
);
PrimExpr
e
=
tma_op_to_barrier_id_
[
call_ref
].
as
<
CallNode
>
()
->
args
[
0
];
expect_
.
push_back
(
if_depth_
==
1
);
auto
int_set
=
arith
::
EvalSet
(
e
,
var_int_set_
);
sequence
.
push_back
(
0
);
expect_
.
push_back
(
if_depth_
==
1
);
int_sets_
.
push_back
(
int_set
);
sequence
.
push_back
(
0
);
expect_tx_count_
+=
1
;
int_sets_
.
push_back
(
int_set
);
expect_tx_count_
+=
1
;
}
}
else
if
(
op
->
op
.
same_as
(
builtin
::
ptx_arrive_barrier
()))
{
}
else
if
(
op
->
op
.
same_as
(
builtin
::
ptx_arrive_barrier
()))
{
sequence
.
push_back
(
1
);
sequence
.
push_back
(
1
);
}
else
if
(
op
->
op
.
same_as
(
builtin
::
ptx_cp_async_barrier
()))
{
}
else
if
(
op
->
op
.
same_as
(
builtin
::
ptx_cp_async_barrier
()))
{
...
@@ -336,32 +338,61 @@ public:
...
@@ -336,32 +338,61 @@ public:
class
BarrierCreationRewriter
:
public
StmtExprMutator
{
class
BarrierCreationRewriter
:
public
StmtExprMutator
{
public:
public:
BarrierCreationRewriter
(
std
::
vector
<
int
>
restore_barrier_ids
,
BarrierCreationRewriter
(
std
::
vector
<
int
>
restore_barrier_ids
,
PrimExpr
producer_thread_extent
)
PrimExpr
producer_thread_extent
,
int
ensure_min_count
=
0
,
PrimExpr
default_barrier_thread_count
=
1
)
:
restore_barrier_ids_
(
std
::
move
(
restore_barrier_ids
)),
:
restore_barrier_ids_
(
std
::
move
(
restore_barrier_ids
)),
producer_thread_extent_
(
std
::
move
(
producer_thread_extent
))
{}
producer_thread_extent_
(
std
::
move
(
producer_thread_extent
)),
ensure_min_count_
(
ensure_min_count
),
default_barrier_thread_count_
(
std
::
move
(
default_barrier_thread_count
))
{
}
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
{
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
{
if
(
op
->
op
.
same_as
(
create_list_of_mbarrier
()))
{
if
(
op
->
op
.
same_as
(
create_list_of_mbarrier
()))
{
std
::
vector
<
bool
>
tmp_
(
op
->
args
.
size
(),
false
);
size_t
cur_n
=
op
->
args
.
size
();
Array
<
PrimExpr
>
new_args
;
size_t
need_n
=
std
::
max
<
size_t
>
(
cur_n
,
static_cast
<
size_t
>
(
ensure_min_count_
));
// Mark barriers to restore across the full needed length, not just the
// original length, so newly appended entries can be restored as well.
std
::
vector
<
bool
>
replace
(
need_n
,
false
);
for
(
auto
&
id
:
restore_barrier_ids_
)
{
for
(
auto
&
id
:
restore_barrier_ids_
)
{
tmp_
[
id
]
=
true
;
if
(
id
>=
0
&&
static_cast
<
size_t
>
(
id
)
<
replace
.
size
())
{
replace
[
id
]
=
true
;
}
}
}
for
(
size_t
i
{
0
};
i
<
op
->
args
.
size
();
++
i
)
{
Array
<
PrimExpr
>
new_args
;
if
(
tmp_
[
i
])
{
new_args
.
reserve
(
need_n
);
// Preserve/override existing entries
for
(
size_t
i
{
0
};
i
<
cur_n
;
++
i
)
{
if
(
replace
[
i
])
{
new_args
.
push_back
(
producer_thread_extent_
);
new_args
.
push_back
(
producer_thread_extent_
);
}
else
{
}
else
{
new_args
.
push_back
(
op
->
args
[
i
]);
new_args
.
push_back
(
op
->
args
[
i
]);
}
}
}
}
// Append additional barriers if required
for
(
size_t
i
=
cur_n
;
i
<
need_n
;
++
i
)
{
if
(
replace
[
i
])
{
new_args
.
push_back
(
producer_thread_extent_
);
}
else
{
new_args
.
push_back
(
default_barrier_thread_count_
);
}
}
return
Call
(
op
->
dtype
,
op
->
op
,
new_args
);
return
Call
(
op
->
dtype
,
op
->
op
,
new_args
);
}
else
{
}
else
{
return
StmtExprMutator
::
VisitExpr_
(
op
);
return
StmtExprMutator
::
VisitExpr_
(
op
);
}
}
}
}
private:
std
::
vector
<
int
>
restore_barrier_ids_
;
std
::
vector
<
int
>
restore_barrier_ids_
;
PrimExpr
producer_thread_extent_
;
PrimExpr
producer_thread_extent_
;
int
ensure_min_count_
{
0
};
PrimExpr
default_barrier_thread_count_
{
1
};
};
};
// we trust mbarrier_wait_parity to be correct
// we trust mbarrier_wait_parity to be correct
...
@@ -398,15 +429,38 @@ public:
...
@@ -398,15 +429,38 @@ public:
collector
.
barrier_id_to_range
(),
collector
.
barrier_id_to_range
(),
has_create_list_of_mbarrier
);
has_create_list_of_mbarrier
);
f
.
CopyOnWrite
()
->
body
=
rewriter
(
f
->
body
);
f
.
CopyOnWrite
()
->
body
=
rewriter
(
f
->
body
);
// Compute the minimum number of barriers actually referenced in the body
// after TMA barrier rewrites (e.g., get_mbarrier(0) inserted for TMA).
struct
GetMbarrierMaxIdxCollector
:
public
StmtExprVisitor
{
int
max_idx
{
-
1
};
void
VisitExpr_
(
const
CallNode
*
op
)
final
{
if
(
op
->
op
.
same_as
(
get_mbarrier
()))
{
if
(
op
->
args
.
size
()
==
1
)
{
if
(
const
auto
*
imm
=
op
->
args
[
0
].
as
<
IntImmNode
>
())
{
max_idx
=
std
::
max
(
max_idx
,
static_cast
<
int
>
(
imm
->
value
));
}
}
}
StmtExprVisitor
::
VisitExpr_
(
op
);
}
};
GetMbarrierMaxIdxCollector
max_idx_collector
;
max_idx_collector
(
f
->
body
);
int
ensure_min_count
=
max_idx_collector
.
max_idx
+
1
;
// 0-based -> count
// For simple TMA-only producers, default barrier arrive count should be 1
// (only the elected leader performs the TMA arrive/expect).
auto
barrier_creation_rewriter
=
BarrierCreationRewriter
(
auto
barrier_creation_rewriter
=
BarrierCreationRewriter
(
rewriter
.
restore_barrier_ids_
,
rewriter
.
producer_thread_extent_
);
rewriter
.
restore_barrier_ids_
,
rewriter
.
producer_thread_extent_
,
ensure_min_count
,
Integer
(
1
));
f
.
CopyOnWrite
()
->
body
=
barrier_creation_rewriter
(
f
->
body
);
f
.
CopyOnWrite
()
->
body
=
barrier_creation_rewriter
(
f
->
body
);
return
f
;
return
f
;
}
}
private:
private:
Stmt
VisitStmt_
(
const
BlockNode
*
op
)
{
Stmt
VisitStmt_
(
const
BlockNode
*
op
)
{
auto
block
=
GetRef
<
Block
>
(
op
);
auto
block
=
tvm
::
ffi
::
GetRef
<
Block
>
(
op
);
if
(
!
has_create_list_of_mbarrier_
&&
!
barrier_id_to_range_
.
empty
()
&&
if
(
!
has_create_list_of_mbarrier_
&&
!
barrier_id_to_range_
.
empty
()
&&
op
->
name_hint
==
MainBlockName
)
{
op
->
name_hint
==
MainBlockName
)
{
ICHECK
(
false
)
<<
"Please declare create_list_of_mbarrier."
;
ICHECK
(
false
)
<<
"Please declare create_list_of_mbarrier."
;
...
@@ -452,10 +506,27 @@ private:
...
@@ -452,10 +506,27 @@ private:
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
{
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
{
if
(
op
->
op
.
same_as
(
tma_load
())
||
op
->
op
.
same_as
(
tma_load_im2col
()))
{
if
(
op
->
op
.
same_as
(
tma_load
())
||
op
->
op
.
same_as
(
tma_load_im2col
()))
{
// check this must be in the tma_op_to_barrier_id_
auto
call_ref
=
tvm
::
ffi
::
GetRef
<
Call
>
(
op
);
ICHECK
(
tma_op_to_barrier_id_
.
count
(
GetRef
<
Call
>
(
op
)))
if
(
!
tma_op_to_barrier_id_
.
count
(
call_ref
))
{
<<
"tma_load must be in the tma_op_to_barrier_id_"
;
// For 1D TMA loads, promote raw integer barrier id to get_mbarrier(id)
auto
barrier_id
=
tma_op_to_barrier_id_
[
GetRef
<
Call
>
(
op
)];
// so codegen can emit mbarrier[index]. This handles degenerate
// producer-only kernels where no arrive() is seen and mapping is empty.
auto
arg0
=
op
->
args
[
0
].
as
<
Call
>
();
bool
is_1d_tma_load
=
arg0
&&
!
arg0
.
value
()
->
op
.
same_as
(
create_tma_descriptor
())
&&
!
arg0
.
value
()
->
op
.
same_as
(
create_tma_im2col_descriptor
());
if
(
is_1d_tma_load
&&
op
->
args
.
size
()
>=
3
)
{
if
(
const
auto
*
imm
=
op
->
args
[
2
].
as
<
IntImmNode
>
())
{
Array
<
PrimExpr
>
new_args
=
op
->
args
;
new_args
.
Set
(
2
,
Call
(
DataType
::
Handle
(),
get_mbarrier
(),
{
IntImm
(
DataType
::
Int
(
32
),
static_cast
<
int
>
(
imm
->
value
))}));
return
Call
(
op
->
dtype
,
op
->
op
,
new_args
);
}
}
return
IRMutatorWithAnalyzer
::
VisitExpr_
(
op
);
}
auto
barrier_id
=
tma_op_to_barrier_id_
[
call_ref
];
auto
new_args
=
op
->
args
;
auto
new_args
=
op
->
args
;
auto
arg0
=
op
->
args
[
0
].
as
<
Call
>
();
auto
arg0
=
op
->
args
[
0
].
as
<
Call
>
();
auto
is_1d_tma_load
=
auto
is_1d_tma_load
=
...
@@ -468,9 +539,11 @@ private:
...
@@ -468,9 +539,11 @@ private:
}
}
return
Call
(
op
->
dtype
,
op
->
op
,
new_args
);
return
Call
(
op
->
dtype
,
op
->
op
,
new_args
);
}
else
if
(
op
->
op
.
same_as
(
mbarrier_expect_tx
()))
{
}
else
if
(
op
->
op
.
same_as
(
mbarrier_expect_tx
()))
{
ICHECK
(
tma_op_to_barrier_id_
.
count
(
GetRef
<
Call
>
(
op
)))
auto
call_ref
=
tvm
::
ffi
::
GetRef
<
Call
>
(
op
);
<<
"mbarrier_expect_tx must be in the tma_op_to_barrier_id_"
;
if
(
!
tma_op_to_barrier_id_
.
count
(
call_ref
))
{
auto
barrier_id
=
tma_op_to_barrier_id_
[
GetRef
<
Call
>
(
op
)];
return
IRMutatorWithAnalyzer
::
VisitExpr_
(
op
);
}
auto
barrier_id
=
tma_op_to_barrier_id_
[
call_ref
];
auto
new_args
=
op
->
args
;
auto
new_args
=
op
->
args
;
new_args
.
Set
(
0
,
barrier_id
);
new_args
.
Set
(
0
,
barrier_id
);
if
(
!
has_warp_specialization_
)
if
(
!
has_warp_specialization_
)
...
@@ -522,10 +595,10 @@ tvm::transform::Pass InjectTmaBarrier() {
...
@@ -522,10 +595,10 @@ tvm::transform::Pass InjectTmaBarrier() {
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.InjectTmaBarrier"
,
{});
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.InjectTmaBarrier"
,
{});
}
}
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.InjectTmaBarrier"
,
InjectTmaBarrier
);
refl
::
GlobalDef
().
def
(
"tl.transform.InjectTmaBarrier"
,
InjectTmaBarrier
);
}
);
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/transform/layout_inference.cc
View file @
bbbf4207
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include <tvm/tir/transform.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include <tvm/tir/utils.h>
#include <algorithm>
#include <queue>
#include <queue>
#include "../layout/utils.h"
#include "../layout/utils.h"
...
@@ -105,20 +106,60 @@ public:
...
@@ -105,20 +106,60 @@ public:
"required for layout inference."
;
"required for layout inference."
;
// Run InferLayout
// Run InferLayout
DLOG
(
INFO
)
<<
"[RunInferStep] working on "
<<
cur_infer_id
<<
'\n'
;
auto
updates
=
auto
updates
=
next
->
InferLayout
(
LayoutInferArgs
{
target_
,
thread_bounds
,
layout_map
,
next
->
InferLayout
(
LayoutInferArgs
{
target_
,
thread_bounds
,
layout_map
,
&
analyzer_
,
buffer_oob
},
&
analyzer_
,
buffer_oob
},
level
);
level
);
// Process the returned updates
// Process the returned updates
for
(
const
auto
&
[
buffer
,
layout
]
:
updates
)
{
for
(
const
auto
&
[
buffer
,
layout
]
:
updates
)
{
DLOG
(
INFO
)
<<
" consider update "
<<
buffer
<<
" as "
<<
layout
->
DebugOutput
()
<<
'\n'
;
// Basic validity checks
// Basic validity checks
ICHECK
(
buffer
.
defined
())
<<
"InferLayout returned an undefined buffer."
;
ICHECK
(
buffer
.
defined
())
<<
"InferLayout returned an undefined buffer."
;
ICHECK
(
layout
.
defined
())
<<
"InferLayout returned an undefined layout."
;
ICHECK
(
layout
.
defined
())
<<
"InferLayout returned an undefined layout."
;
// Helper: propagate inferred layout to alias buffers (same data Var)
auto
propagate_alias
=
[
&
](
const
Buffer
&
src_buffer
,
const
Layout
&
src_layout
)
{
if
(
!
buffer_data_to_buffers_
.
count
(
src_buffer
->
data
))
return
;
const
auto
&
siblings
=
buffer_data_to_buffers_
[
src_buffer
->
data
];
for
(
const
auto
&
sib
:
siblings
)
{
if
(
sib
.
same_as
(
src_buffer
))
continue
;
bool
shapes_equal
=
src_layout
->
InputShape
().
size
()
==
sib
->
shape
.
size
();
if
(
shapes_equal
)
{
for
(
size_t
i
=
0
;
i
<
src_layout
->
InputShape
().
size
();
++
i
)
{
if
(
!
analyzer_
.
CanProveEqual
(
src_layout
->
InputShape
()[
i
],
sib
->
shape
[
i
]))
{
shapes_equal
=
false
;
break
;
}
}
}
Layout
target_layout
=
shapes_equal
?
src_layout
:
src_layout
->
Reshape
(
sib
->
shape
,
&
analyzer_
);
if
(
layout_map
.
count
(
sib
))
{
ICHECK
(
target_layout
->
IsEqual
(
layout_map
[
sib
].
get
()))
<<
"Get different layout for alias buffer "
<<
sib
<<
" (data-shared with "
<<
src_buffer
<<
")
\n
current: "
<<
target_layout
->
DebugOutput
()
<<
"
\n
previous: "
<<
layout_map
[
sib
]
->
DebugOutput
();
}
else
{
layout_map
.
Set
(
sib
,
target_layout
);
if
(
update_queue
&&
use_list_
.
count
(
sib
))
{
for
(
int
idx
:
use_list_
[
sib
])
{
if
(
!
in_queue
[
idx
]
&&
idx
!=
cur_infer_id
)
{
in_queue
[
idx
]
=
true
;
q
.
push
(
idx
);
}
}
}
}
}
};
if
(
layout_map
.
count
(
buffer
))
{
if
(
layout_map
.
count
(
buffer
))
{
// If new layout contains the old one, update map
// If new layout contains the old one, update map
if
(
buffer
.
scope
()
==
"local.fragment"
&&
if
(
buffer
.
scope
()
==
"local.fragment"
&&
...
@@ -153,8 +194,8 @@ public:
...
@@ -153,8 +194,8 @@ public:
if
(
ProveFragmentContains
(
src_layout
,
dst_layout
,
indices
,
indices
,
if
(
ProveFragmentContains
(
src_layout
,
dst_layout
,
indices
,
indices
,
inner_analyzer
))
{
inner_analyzer
))
{
layout_map
.
Set
(
buffer
,
layout
);
layout_map
.
Set
(
buffer
,
layout
);
DLOG
(
INFO
)
<<
" layout broadcast from "
// Propagate to alias buffers as well
<<
src_layout
->
DebugOutput
()
<<
", accepted"
<<
'\n'
;
propagate_alias
(
buffer
,
layout
)
;
continue
;
continue
;
}
}
}
}
...
@@ -163,10 +204,13 @@ public:
...
@@ -163,10 +204,13 @@ public:
<<
"Get different layout for "
<<
buffer
<<
"Get different layout for "
<<
buffer
<<
"
\n
current layout: "
<<
layout
->
DebugOutput
()
<<
"
\n
current layout: "
<<
layout
->
DebugOutput
()
<<
"
\n
previous layout: "
<<
layout_map
[
buffer
]
->
DebugOutput
();
<<
"
\n
previous layout: "
<<
layout_map
[
buffer
]
->
DebugOutput
();
// Ensure aliases are consistent too
propagate_alias
(
buffer
,
layout
);
}
else
{
}
else
{
// Otherwise, update map
// Otherwise, update map
layout_map
.
Set
(
buffer
,
layout
);
layout_map
.
Set
(
buffer
,
layout
);
DLOG
(
INFO
)
<<
" new layout accepted"
<<
'\n'
;
// Propagate to alias buffers (may enqueue their users)
propagate_alias
(
buffer
,
layout
);
if
(
!
update_queue
)
if
(
!
update_queue
)
continue
;
continue
;
...
@@ -272,6 +316,46 @@ public:
...
@@ -272,6 +316,46 @@ public:
// step 3: relax constraints to free and re-run
// step 3: relax constraints to free and re-run
InferInFreeMode
(
layout_map
,
strict_layout_map
);
InferInFreeMode
(
layout_map
,
strict_layout_map
);
// step 4: finalize alias layouts by Var
// For each storage var, if any buffer in the group has a layout,
// propagate (reshape if needed) to the rest to ensure completeness.
for
(
const
auto
&
[
var
,
buffers
]
:
buffer_data_to_buffers_
)
{
// Find a representative with existing layout
Optional
<
Buffer
>
rep
;
Optional
<
Layout
>
rep_layout
;
for
(
const
auto
&
buf
:
buffers
)
{
if
(
layout_map
.
count
(
buf
))
{
rep
=
buf
;
rep_layout
=
layout_map
[
buf
];
break
;
}
}
if
(
!
rep_layout
.
defined
())
continue
;
for
(
const
auto
&
buf
:
buffers
)
{
if
(
!
layout_map
.
count
(
buf
))
{
bool
shapes_equal
=
rep_layout
.
value
()
->
InputShape
().
size
()
==
buf
->
shape
.
size
();
if
(
shapes_equal
)
{
for
(
size_t
i
=
0
;
i
<
rep_layout
.
value
()
->
InputShape
().
size
();
++
i
)
{
if
(
!
analyzer_
.
CanProveEqual
(
rep_layout
.
value
()
->
InputShape
()[
i
],
buf
->
shape
[
i
]))
{
shapes_equal
=
false
;
break
;
}
}
}
Layout
reshaped
=
shapes_equal
?
rep_layout
.
value
()
:
rep_layout
.
value
()
->
Reshape
(
buf
->
shape
,
&
analyzer_
);
layout_map
.
Set
(
buf
,
reshaped
);
}
}
}
// Check that all local.fragment buffers have inferred layouts
// Check that all local.fragment buffers have inferred layouts
for
(
const
auto
&
[
buffer
,
_
]
:
use_list_
)
{
for
(
const
auto
&
[
buffer
,
_
]
:
use_list_
)
{
if
(
buffer
.
scope
()
==
"local.fragment"
)
{
if
(
buffer
.
scope
()
==
"local.fragment"
)
{
...
@@ -314,7 +398,13 @@ public:
...
@@ -314,7 +398,13 @@ public:
void
Collect
(
const
PrimFunc
&
f
)
{
void
Collect
(
const
PrimFunc
&
f
)
{
for
(
const
auto
&
[
_
,
buffer
]
:
f
->
buffer_map
)
{
for
(
const
auto
&
[
_
,
buffer
]
:
f
->
buffer_map
)
{
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
if
(
buffer_data_to_buffers_
.
count
(
buffer
->
data
))
{
auto
buffers
=
buffer_data_to_buffers_
[
buffer
->
data
];
buffers
.
push_back
(
buffer
);
buffer_data_to_buffers_
.
Set
(
buffer
->
data
,
buffers
);
}
else
{
buffer_data_to_buffers_
.
Set
(
buffer
->
data
,
{
buffer
});
}
}
}
auto
target
=
f
->
GetAttr
<
Target
>
(
tvm
::
attr
::
kTarget
);
auto
target
=
f
->
GetAttr
<
Target
>
(
tvm
::
attr
::
kTarget
);
ICHECK
(
target
.
defined
())
ICHECK
(
target
.
defined
())
...
@@ -324,13 +414,25 @@ public:
...
@@ -324,13 +414,25 @@ public:
}
}
private:
private:
Map
<
Var
,
Buffer
>
GetBufferMap
()
const
{
Map
<
Var
,
Buffer
>
buffer_map
;
for
(
const
auto
&
[
var
,
buffers
]
:
buffer_data_to_buffers_
)
{
// Use the first buffer for each var
// TODO(lei): phaseout buffer_map in future.
if
(
!
buffers
.
empty
())
{
buffer_map
.
Set
(
var
,
buffers
[
0
]);
}
}
return
buffer_map
;
}
void
VisitExpr_
(
const
CallNode
*
op
)
final
{
void
VisitExpr_
(
const
CallNode
*
op
)
final
{
IRVisitorWithAnalyzer
::
VisitExpr_
(
op
);
IRVisitorWithAnalyzer
::
VisitExpr_
(
op
);
// Do not analysis the call node to the global function.
// Do not analysis the call node to the global function.
if
(
op
->
op
.
as
<
GlobalVarNode
>
())
if
(
op
->
op
.
as
<
GlobalVarNode
>
())
return
;
return
;
auto
p
=
ParseOperator
(
GetRef
<
Call
>
(
op
),
buffer_data_to_buffer_
);
auto
p
=
ParseOperator
(
tvm
::
ffi
::
GetRef
<
Call
>
(
op
),
GetBufferMap
()
);
if
(
p
.
defined
())
{
if
(
p
.
defined
())
{
for
(
const
auto
&
arg
:
op
->
args
)
{
for
(
const
auto
&
arg
:
op
->
args
)
{
if
(
auto
buffer
=
getBufferFromAccessPtr
(
arg
))
{
if
(
auto
buffer
=
getBufferFromAccessPtr
(
arg
))
{
...
@@ -381,7 +483,7 @@ private:
...
@@ -381,7 +483,7 @@ private:
}
}
// Add the tile operator to infer_list_
// Add the tile operator to infer_list_
infer_list_stmt_
.
push_back
(
GetRef
<
ObjectRef
>
(
op
));
infer_list_stmt_
.
push_back
(
tvm
::
ffi
::
GetRef
<
ObjectRef
>
(
op
));
infer_list_
.
push_back
(
std
::
move
(
p
));
infer_list_
.
push_back
(
std
::
move
(
p
));
}
}
}
}
...
@@ -394,12 +496,18 @@ private:
...
@@ -394,12 +496,18 @@ private:
if
(
call
->
op
.
same_as
(
builtin
::
tvm_access_ptr
()))
{
if
(
call
->
op
.
same_as
(
builtin
::
tvm_access_ptr
()))
{
auto
var_opt
=
call
->
args
[
1
].
as
<
Var
>
();
auto
var_opt
=
call
->
args
[
1
].
as
<
Var
>
();
if
(
!
var_opt
.
has_value
())
{
if
(
!
var_opt
.
has_value
())
{
D
LOG
(
WARNING
)
<<
"[getBufferFromAccessPtr] args[1] is not a Var, type: "
LOG
(
WARNING
)
<<
"[getBufferFromAccessPtr] args[1] is not a Var, type: "
<<
call
->
args
[
1
]
->
GetTypeKey
();
<<
call
->
args
[
1
]
->
GetTypeKey
();
return
std
::
nullopt
;
return
std
::
nullopt
;
}
}
const
auto
&
var
=
var_opt
.
value
();
const
auto
&
var
=
var_opt
.
value
();
return
buffer_data_to_buffer_
[
var
];
if
(
buffer_data_to_buffers_
.
count
(
var
))
{
const
auto
&
buffers
=
buffer_data_to_buffers_
[
var
];
if
(
!
buffers
.
empty
())
{
return
buffers
[
0
];
// Return the first buffer
}
}
return
std
::
nullopt
;
}
else
if
(
call
->
op
.
same_as
(
RegionOp
::
Get
()))
{
}
else
if
(
call
->
op
.
same_as
(
RegionOp
::
Get
()))
{
return
call
->
args
[
0
].
as
<
BufferLoadNode
>
()
->
buffer
;
return
call
->
args
[
0
].
as
<
BufferLoadNode
>
()
->
buffer
;
}
}
...
@@ -416,11 +524,11 @@ private:
...
@@ -416,11 +524,11 @@ private:
void
VisitStmt_
(
const
ForNode
*
op
)
final
{
void
VisitStmt_
(
const
ForNode
*
op
)
final
{
if
(
op
->
kind
==
ForKind
::
kParallel
)
{
if
(
op
->
kind
==
ForKind
::
kParallel
)
{
auto
infer
=
ParallelOp
(
GetRef
<
For
>
(
op
));
auto
infer
=
ParallelOp
(
tvm
::
ffi
::
GetRef
<
For
>
(
op
));
for
(
const
auto
&
[
buffer
,
_
]
:
infer
->
GetIndiceMap
())
{
for
(
const
auto
&
[
buffer
,
_
]
:
infer
->
GetIndiceMap
())
{
addToUseList
(
buffer
);
addToUseList
(
buffer
);
}
}
infer_list_stmt_
.
push_back
(
GetRef
<
ObjectRef
>
(
op
));
infer_list_stmt_
.
push_back
(
tvm
::
ffi
::
GetRef
<
ObjectRef
>
(
op
));
infer_list_
.
push_back
(
std
::
move
(
infer
));
infer_list_
.
push_back
(
std
::
move
(
infer
));
thread_var_vec_
.
push_back
(
thread_var_
);
thread_var_vec_
.
push_back
(
thread_var_
);
if
(
thread_var_
.
defined
()
&&
if
(
thread_var_
.
defined
()
&&
...
@@ -442,21 +550,55 @@ private:
...
@@ -442,21 +550,55 @@ private:
void
VisitStmt_
(
const
BlockNode
*
op
)
final
{
void
VisitStmt_
(
const
BlockNode
*
op
)
final
{
for
(
auto
buffer
:
op
->
alloc_buffers
)
{
for
(
auto
buffer
:
op
->
alloc_buffers
)
{
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
if
(
buffer_data_to_buffers_
.
count
(
buffer
->
data
))
{
auto
buffers
=
buffer_data_to_buffers_
[
buffer
->
data
];
buffers
.
push_back
(
buffer
);
buffer_data_to_buffers_
.
Set
(
buffer
->
data
,
buffers
);
}
else
{
buffer_data_to_buffers_
.
Set
(
buffer
->
data
,
{
buffer
});
}
}
}
// First, visit the block body to collect all buffers from
// BufferLoad/BufferStore
IRVisitorWithAnalyzer
::
VisitStmt_
(
op
);
// After visiting, apply layouts to all collected buffers
if
(
op
->
annotations
.
count
(
attr
::
kLayoutMap
))
{
if
(
op
->
annotations
.
count
(
attr
::
kLayoutMap
))
{
// Check if the layout map is Map<Var, Layout>
// Check if the layout map is Map<Var, Layout>
auto
map
=
auto
map
=
op
->
annotations
.
Get
(
attr
::
kLayoutMap
)
->
as
<
Map
<
Var
,
Layout
>>
().
value
();
op
->
annotations
.
Get
(
attr
::
kLayoutMap
)
->
as
<
Map
<
Var
,
Layout
>>
().
value
();
for
(
const
auto
&
[
var
,
layout
]
:
map
)
{
for
(
const
auto
&
[
var
,
layout
]
:
map
)
{
ICHECK
(
buffer_data_to_buffer_
.
count
(
var
))
ICHECK
(
buffer_data_to_buffer
s
_
.
count
(
var
))
<<
"buffer "
<<
var
<<
" is not found in the block"
;
<<
"buffer "
<<
var
<<
" is not found in the block"
;
auto
buffer
=
buffer_data_to_buffer_
[
var
];
const
auto
&
buffers
=
buffer_data_to_buffers_
[
var
];
ICHECK
(
StructuralEqual
()(
layout
->
InputShape
(),
buffer
->
shape
));
ICHECK
(
!
buffers
.
empty
())
<<
"buffer list for "
<<
var
<<
" is empty"
;
annotated_layout_map_
.
Set
(
buffer
,
layout
);
// Apply layout to all buffers associated with this var
for
(
const
auto
&
buffer
:
buffers
)
{
// Reshape the layout to match the buffer's shape
// Check if shapes are structurally equal
bool
shapes_equal
=
layout
->
InputShape
().
size
()
==
buffer
->
shape
.
size
();
if
(
shapes_equal
)
{
for
(
size_t
i
=
0
;
i
<
layout
->
InputShape
().
size
();
++
i
)
{
if
(
!
analyzer_
.
CanProveEqual
(
layout
->
InputShape
()[
i
],
buffer
->
shape
[
i
]))
{
shapes_equal
=
false
;
break
;
}
}
}
if
(
shapes_equal
)
{
annotated_layout_map_
.
Set
(
buffer
,
layout
);
}
else
{
auto
reshaped_layout
=
layout
->
Reshape
(
buffer
->
shape
,
&
analyzer_
);
annotated_layout_map_
.
Set
(
buffer
,
reshaped_layout
);
}
}
}
}
}
}
IRVisitorWithAnalyzer
::
VisitStmt_
(
op
);
}
}
void
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
void
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
...
@@ -470,7 +612,67 @@ private:
...
@@ -470,7 +612,67 @@ private:
IRVisitorWithAnalyzer
::
VisitStmt_
(
op
);
IRVisitorWithAnalyzer
::
VisitStmt_
(
op
);
}
}
Map
<
Var
,
Buffer
>
buffer_data_to_buffer_
;
void
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
// Collect buffer from BufferLoad
if
(
op
->
buffer
.
defined
()
&&
op
->
buffer
->
data
.
defined
())
{
if
(
buffer_data_to_buffers_
.
count
(
op
->
buffer
->
data
))
{
// Check if this buffer is already in the list
auto
buffers
=
buffer_data_to_buffers_
[
op
->
buffer
->
data
];
bool
found
=
false
;
for
(
const
auto
&
buf
:
buffers
)
{
if
(
buf
.
same_as
(
op
->
buffer
))
{
found
=
true
;
break
;
}
}
if
(
!
found
)
{
buffers
.
push_back
(
op
->
buffer
);
buffer_data_to_buffers_
.
Set
(
op
->
buffer
->
data
,
buffers
);
DLOG
(
INFO
)
<<
"[LayoutInference] BufferLoad: added buffer "
<<
op
->
buffer
<<
" buffer.get() = "
<<
op
->
buffer
.
get
()
<<
" data = "
<<
op
->
buffer
->
data
.
get
();
}
}
else
{
buffer_data_to_buffers_
.
Set
(
op
->
buffer
->
data
,
{
op
->
buffer
});
DLOG
(
INFO
)
<<
"[LayoutInference] BufferLoad: new buffer "
<<
op
->
buffer
<<
" buffer.get() = "
<<
op
->
buffer
.
get
()
<<
" data = "
<<
op
->
buffer
->
data
.
get
();
}
}
IRVisitorWithAnalyzer
::
VisitExpr_
(
op
);
}
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
// Collect buffer from BufferStore
if
(
op
->
buffer
.
defined
()
&&
op
->
buffer
->
data
.
defined
())
{
if
(
buffer_data_to_buffers_
.
count
(
op
->
buffer
->
data
))
{
// Check if this buffer is already in the list
auto
buffers
=
buffer_data_to_buffers_
[
op
->
buffer
->
data
];
bool
found
=
false
;
for
(
const
auto
&
buf
:
buffers
)
{
if
(
buf
.
same_as
(
op
->
buffer
))
{
found
=
true
;
break
;
}
}
if
(
!
found
)
{
buffers
.
push_back
(
op
->
buffer
);
buffer_data_to_buffers_
.
Set
(
op
->
buffer
->
data
,
buffers
);
DLOG
(
INFO
)
<<
"[LayoutInference] BufferStore: added buffer "
<<
op
->
buffer
<<
" buffer.get() = "
<<
op
->
buffer
.
get
()
<<
" data = "
<<
op
->
buffer
->
data
.
get
();
}
}
else
{
buffer_data_to_buffers_
.
Set
(
op
->
buffer
->
data
,
{
op
->
buffer
});
DLOG
(
INFO
)
<<
"[LayoutInference] BufferStore: new buffer "
<<
op
->
buffer
<<
" buffer.get() = "
<<
op
->
buffer
.
get
()
<<
" data = "
<<
op
->
buffer
->
data
.
get
();
}
}
IRVisitorWithAnalyzer
::
VisitStmt_
(
op
);
}
Map
<
Var
,
Array
<
Buffer
>>
buffer_data_to_buffers_
;
std
::
vector
<
ObjectRef
>
infer_list_stmt_
;
std
::
vector
<
ObjectRef
>
infer_list_stmt_
;
std
::
vector
<
TileOperator
>
infer_list_
;
std
::
vector
<
TileOperator
>
infer_list_
;
std
::
unordered_map
<
Buffer
,
std
::
vector
<
int
>
,
ObjectPtrHash
,
ObjectPtrEqual
>
std
::
unordered_map
<
Buffer
,
std
::
vector
<
int
>
,
ObjectPtrHash
,
ObjectPtrEqual
>
...
@@ -513,12 +715,33 @@ private:
...
@@ -513,12 +715,33 @@ private:
if
(
infer_indices
.
empty
())
if
(
infer_indices
.
empty
())
continue
;
continue
;
// Union all infer_list_ indices that share the same
b
uffer
// Union all infer_list_ indices that share the same
B
uffer
object
int
first_idx
=
infer_indices
[
0
];
int
first_idx
=
infer_indices
[
0
];
for
(
size_t
i
=
1
;
i
<
infer_indices
.
size
();
i
++
)
{
for
(
size_t
i
=
1
;
i
<
infer_indices
.
size
();
i
++
)
{
uf
.
Union
(
first_idx
,
infer_indices
[
i
]);
uf
.
Union
(
first_idx
,
infer_indices
[
i
]);
}
}
}
}
// Additionally, union across buffers that share the same underlying
// buffer->data (Var). This handles cases like reshape where multiple
// Buffer objects alias the same storage.
for
(
const
auto
&
[
var
,
buffers
]
:
buffer_data_to_buffers_
)
{
std
::
vector
<
int
>
merged
;
for
(
const
auto
&
buf
:
buffers
)
{
auto
it
=
use_list_
.
find
(
buf
);
if
(
it
!=
use_list_
.
end
())
{
const
auto
&
vec
=
it
->
second
;
merged
.
insert
(
merged
.
end
(),
vec
.
begin
(),
vec
.
end
());
}
}
if
(
merged
.
size
()
>
1
)
{
std
::
sort
(
merged
.
begin
(),
merged
.
end
());
merged
.
erase
(
std
::
unique
(
merged
.
begin
(),
merged
.
end
()),
merged
.
end
());
int
first
=
merged
[
0
];
for
(
size_t
i
=
1
;
i
<
merged
.
size
();
++
i
)
{
uf
.
Union
(
first
,
merged
[
i
]);
}
}
}
std
::
unordered_map
<
int
,
std
::
vector
<
int
>>
components
;
std
::
unordered_map
<
int
,
std
::
vector
<
int
>>
components
;
for
(
int
i
=
0
;
i
<
infer_list_
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
infer_list_
.
size
();
i
++
)
{
int
root
=
uf
.
Find
(
i
);
int
root
=
uf
.
Find
(
i
);
...
@@ -597,7 +820,9 @@ private:
...
@@ -597,7 +820,9 @@ private:
}
}
}
}
// Update the best plan if this one uses fewer registers
// Update the best plan if this one uses fewer registers
if
(
reg_num
<
min_reg_num
)
{
if
(
reg_num
<
min_reg_num
||
(
reg_num
==
min_reg_num
&&
attempt_infer_root
<
min_reg_num_infer_root
))
{
best_infer_list
=
best_infer_list
=
BackupInferList
();
// Use backup to avoid moving out infer_list_
BackupInferList
();
// Use backup to avoid moving out infer_list_
best_layout_map
=
tmp_layout_map
;
best_layout_map
=
tmp_layout_map
;
...
@@ -711,8 +936,8 @@ private:
...
@@ -711,8 +936,8 @@ private:
.
value
();
.
value
();
For
for_node
=
Downcast
<
For
>
(
IRMutatorWithAnalyzer
::
VisitStmt_
(
op
));
For
for_node
=
Downcast
<
For
>
(
IRMutatorWithAnalyzer
::
VisitStmt_
(
op
));
if
(
result_
.
for_map
.
count
(
GetRef
<
For
>
(
op
)))
{
if
(
result_
.
for_map
.
count
(
tvm
::
ffi
::
GetRef
<
For
>
(
op
)))
{
auto
root
=
GetRef
<
For
>
(
op
);
auto
root
=
tvm
::
ffi
::
GetRef
<
For
>
(
op
);
// This check is a workaround to support T.Parallel for local buffers.
// This check is a workaround to support T.Parallel for local buffers.
// For example:
// For example:
// for i in T.Parallel(1024):
// for i in T.Parallel(1024):
...
@@ -787,7 +1012,18 @@ private:
...
@@ -787,7 +1012,18 @@ private:
}
}
});
});
if
(
has_non_local
&&
!
has_reducer
)
{
// If a cast operation exists, vectorization may still be required
bool
has_cast_operations
=
false
;
PostOrderVisit
(
for_node
->
body
,
[
&
](
const
ObjectRef
&
obj
)
{
if
(
const
auto
*
store
=
obj
.
as
<
BufferStoreNode
>
())
{
// Check if this is a non-reducer store with Cast operation
if
(
store
->
value
.
as
<
CastNode
>
())
{
has_cast_operations
=
true
;
}
}
});
if
((
has_non_local
||
has_cast_operations
)
&&
!
has_reducer
)
{
for_node
=
VectorizeLoop
(
for_node
);
for_node
=
VectorizeLoop
(
for_node
);
}
}
...
@@ -831,10 +1067,10 @@ tvm::transform::Pass LayoutInference() {
...
@@ -831,10 +1067,10 @@ tvm::transform::Pass LayoutInference() {
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.LayoutInference"
,
{});
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.LayoutInference"
,
{});
}
}
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.LayoutInference"
,
LayoutInference
);
refl
::
GlobalDef
().
def
(
"tl.transform.LayoutInference"
,
LayoutInference
);
}
);
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/transform/layout_reducer.cc
View file @
bbbf4207
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "../layout/layout.h"
#include "../layout/layout.h"
#include "../op/fill.h"
#include "../op/fill.h"
#include "../op/finalize_reducer.h"
#include "../op/finalize_reducer.h"
#include "../op/region.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "layout_reducer.h"
#include "layout_reducer.h"
...
@@ -275,17 +276,34 @@ private:
...
@@ -275,17 +276,34 @@ private:
auto
op
=
op_ref
.
CopyOnWrite
();
auto
op
=
op_ref
.
CopyOnWrite
();
if
(
op
->
op
.
same_as
(
Fill
::
Get
()))
{
if
(
op
->
op
.
same_as
(
Fill
::
Get
()))
{
ICHECK
(
!
op
->
args
.
empty
());
ICHECK
(
!
op
->
args
.
empty
());
if
(
auto
arg0_call
=
op
->
args
[
0
].
as
<
Call
>
();
if
(
auto
arg0_call
=
op
->
args
[
0
].
as
<
Call
>
())
{
arg0_call
&&
// Case 1: tl.region(...) — extract buffer var from its first arg
arg0_call
.
value
()
->
op
.
same_as
(
builtin
::
tvm_access_ptr
()))
{
if
(
arg0_call
.
value
()
->
op
.
same_as
(
RegionOp
::
Get
()))
{
ICHECK
(
arg0_call
.
value
()
->
args
.
size
()
>
1
);
ICHECK
(
!
arg0_call
.
value
()
->
args
.
empty
());
if
(
auto
var
=
arg0_call
.
value
()
->
args
[
1
].
as
<
Var
>
();
if
(
auto
bl
=
arg0_call
.
value
()
->
args
[
0
].
as
<
BufferLoadNode
>
())
{
var
&&
reducer_info_map_
.
count
(
var
.
value
()))
{
Var
var
=
bl
->
buffer
->
data
;
ICHECK
(
inside_reducer_range_
.
count
(
var
.
value
())
==
0
)
if
(
reducer_info_map_
.
count
(
var
))
{
<<
"T.fill on reducer must be enclosed with a T.finalize_reducer "
ICHECK
(
inside_reducer_range_
.
count
(
var
)
==
0
)
"before next."
;
<<
"T.fill on reducer must be enclosed with a "
inside_reducer_range_
.
Set
(
var
.
value
(),
"T.finalize_reducer "
reducer_info_map_
.
Get
(
var
.
value
()).
value
());
"before next."
;
inside_reducer_range_
.
Set
(
var
,
reducer_info_map_
.
Get
(
var
).
value
());
}
}
}
// Case 2: builtin.tvm_access_ptr(...) — existing path
else
if
(
arg0_call
.
value
()
->
op
.
same_as
(
builtin
::
tvm_access_ptr
()))
{
ICHECK
(
arg0_call
.
value
()
->
args
.
size
()
>
1
);
if
(
auto
var
=
arg0_call
.
value
()
->
args
[
1
].
as
<
Var
>
();
var
&&
reducer_info_map_
.
count
(
var
.
value
()))
{
ICHECK
(
inside_reducer_range_
.
count
(
var
.
value
())
==
0
)
<<
"T.fill on reducer must be enclosed with a "
"T.finalize_reducer "
"before next."
;
inside_reducer_range_
.
Set
(
var
.
value
(),
reducer_info_map_
.
Get
(
var
.
value
()).
value
());
}
}
}
}
}
}
else
if
(
op
->
op
.
same_as
(
FinalizeReducerOp
::
Get
()))
{
}
else
if
(
op
->
op
.
same_as
(
FinalizeReducerOp
::
Get
()))
{
...
@@ -362,10 +380,10 @@ tvm::transform::Pass LayoutReducer() {
...
@@ -362,10 +380,10 @@ tvm::transform::Pass LayoutReducer() {
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.LayoutReducer"
,
{});
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.LayoutReducer"
,
{});
}
}
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.LayoutReducer"
,
LayoutReducer
);
refl
::
GlobalDef
().
def
(
"tl.transform.LayoutReducer"
,
LayoutReducer
);
}
);
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/transform/layout_reducer.h
View file @
bbbf4207
...
@@ -66,17 +66,17 @@ struct ReducerInfoNode : Object {
...
@@ -66,17 +66,17 @@ struct ReducerInfoNode : Object {
ReducerInfoNode
()
=
default
;
ReducerInfoNode
()
=
default
;
ReducerInfoNode
(
const
String
&
op_str
,
const
String
&
rep_str
);
ReducerInfoNode
(
const
String
&
op_str
,
const
String
&
rep_str
);
static
constexpr
const
char
*
_type_key
=
"tl.ReducerInfo"
;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL
(
"tl.ReducerInfo"
,
ReducerInfoNode
,
Object
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
ReducerInfoNode
,
Object
);
};
};
struct
ReducerInfo
:
ObjectRef
{
struct
ReducerInfo
:
ObjectRef
{
public:
public:
TVM_DLL
ReducerInfo
(
const
String
&
op_str
,
const
String
&
rep_str
)
{
TVM_DLL
ReducerInfo
(
const
String
&
op_str
,
const
String
&
rep_str
)
{
data_
=
make_object
<
ReducerInfoNode
>
(
op_str
,
rep_str
);
data_
=
tvm
::
ffi
::
make_object
<
ReducerInfoNode
>
(
op_str
,
rep_str
);
}
}
TVM_DEFINE_OBJECT_REF_METHODS
(
ReducerInfo
,
ObjectRef
,
ReducerInfoNode
);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE
(
ReducerInfo
,
ObjectRef
,
ReducerInfoNode
);
};
};
namespace
attr
{
namespace
attr
{
...
...
src/transform/legalize_negative_index.cc
0 → 100644
View file @
bbbf4207
/*!
* \file legalize_negative_index.cc
* \brief Legalize negative indices in buffer load expressions.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_map>
#include <vector>
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
namespace
tvm
{
namespace
tl
{
using
namespace
tir
;
using
arith
::
IRVisitorWithAnalyzer
;
enum
class
IndexSignState
{
kNonNegative
,
kNegative
,
kUnknown
};
/*!
 * \brief Visitor that records, for every BufferLoad, the sign of each index.
 *
 * For each load it classifies every index axis as kNonNegative, kNegative,
 * or kUnknown and stores the per-axis vector in `*result_` — but only when
 * at least one axis is negative or unknown; loads whose indices are all
 * provably non-negative need no rewrite and are not recorded.
 */
class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {
public:
  /*! \brief `result` receives the per-load sign states; owned by the caller. */
  explicit NegativeIndexAnalyzer(
      std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
          *result)
      : result_(result) {}

  void VisitExpr_(const BufferLoadNode *op) final {
    auto load = tvm::ffi::GetRef<BufferLoad>(op);
    std::vector<IndexSignState> states;
    states.reserve(op->indices.size());
    // Only record loads that actually contain a negative/unknown axis.
    bool needs_record = false;
    for (size_t i = 0; i < op->indices.size(); ++i) {
      PrimExpr simplified = analyzer_.Simplify(op->indices[i]);
      // Handle scalar indices with the standard analyzer
      if (simplified.dtype().lanes() == 1) {
        if (analyzer_.CanProve(simplified >= 0)) {
          states.push_back(IndexSignState::kNonNegative);
          continue;
        }
        if (analyzer_.CanProve(simplified < 0)) {
          states.push_back(IndexSignState::kNegative);
          needs_record = true;
          continue;
        }
        states.push_back(IndexSignState::kUnknown);
        needs_record = true;
        DLOG(WARNING)
            << "LegalizeNegativeIndex: cannot prove non-negative index "
            << simplified << " for buffer " << load->buffer->name << " (axis "
            << i << ").";
        continue;
      }
      // Vector indices: try to reason about non-negativity/negativity
      // Common patterns are Ramp(base, stride, lanes) and Broadcast(value,
      // lanes).
      IndexSignState vec_state = IndexSignState::kUnknown;
      if (const auto *ramp = simplified.as<RampNode>()) {
        // BUGFIX: ramp->lanes is a PrimExpr and need not be a constant
        // (e.g. scalable vectors). The previous code dereferenced
        // as_const_int() unconditionally, which is UB when it returns
        // nullptr. Treat a non-constant lane count as kUnknown instead.
        if (const int64_t *lanes_ptr = as_const_int(ramp->lanes)) {
          // Compute a safe lower/upper bound for the vector lanes
          // lower_bound = base_min + min(0, stride_min) * (lanes - 1)
          // upper_bound = base_max + max(0, stride_max) * (lanes - 1)
          int lanes = static_cast<int>(*lanes_ptr);
          auto base_bound = analyzer_.const_int_bound(ramp->base);
          auto stride_bound = analyzer_.const_int_bound(ramp->stride);
          int64_t base_min = base_bound->min_value;
          int64_t base_max = base_bound->max_value;
          int64_t s_min = stride_bound->min_value;
          int64_t s_max = stride_bound->max_value;
          // Guard against overflow is not strictly necessary here because
          // bounds may be +/-inf represented by sentinel values.
          // NOTE(review): multiplying the +/-inf sentinels below can still
          // overflow int64 in principle — presumably acceptable in practice;
          // confirm against arith::ConstIntBound's sentinel definitions.
          int64_t lower = base_min;
          if (s_min < 0)
            lower += s_min * (lanes - 1);
          int64_t upper = base_max;
          if (s_max > 0)
            upper += s_max * (lanes - 1);
          if (lower >= 0) {
            vec_state = IndexSignState::kNonNegative;
          } else if (upper < 0) {
            vec_state = IndexSignState::kNegative;
          } else {
            vec_state = IndexSignState::kUnknown;
          }
        }
      } else if (const auto *bc = simplified.as<BroadcastNode>()) {
        auto v = analyzer_.Simplify(bc->value);
        if (analyzer_.CanProve(v >= 0)) {
          vec_state = IndexSignState::kNonNegative;
        } else if (analyzer_.CanProve(v < 0)) {
          vec_state = IndexSignState::kNegative;
        } else {
          // Try const bound if proof unavailable
          auto vb = analyzer_.const_int_bound(v);
          if (vb->min_value >= 0) {
            vec_state = IndexSignState::kNonNegative;
          } else if (vb->max_value < 0) {
            vec_state = IndexSignState::kNegative;
          } else {
            vec_state = IndexSignState::kUnknown;
          }
        }
      }
      if (vec_state == IndexSignState::kNonNegative) {
        states.push_back(IndexSignState::kNonNegative);
        continue;
      }
      if (vec_state == IndexSignState::kNegative) {
        states.push_back(IndexSignState::kNegative);
        needs_record = true;
        continue;
      }
      states.push_back(IndexSignState::kUnknown);
      needs_record = true;
      DLOG(WARNING)
          << "LegalizeNegativeIndex: cannot prove non-negative index "
          << simplified << " for buffer " << load->buffer->name << " (axis "
          << i << ").";
    }
    if (needs_record) {
      (*result_)[op] = std::move(states);
    }
    IRVisitorWithAnalyzer::VisitExpr_(op);
  }

private:
  // Non-owning; maps each offending load to one sign state per index axis.
  std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
      *result_;
};
/*!
 * \brief Mutator that folds buffer extents into provably-negative indices.
 *
 * For every BufferLoad recorded by NegativeIndexAnalyzer, each axis marked
 * kNegative is replaced by `shape[axis] + index` (simplified); axes marked
 * kNonNegative or kUnknown are left untouched.
 */
class NegativeIndexRewriter : public arith::IRMutatorWithAnalyzer {
public:
  /*!
   * \brief Apply the rewrite to a PrimFunc.
   * \param func   The function to rewrite (taken by value, mutated via COW).
   * \param states Per-load index-sign classification from the analyzer.
   * \return The rewritten function (unchanged if the body is undefined).
   */
  static PrimFunc
  Apply(PrimFunc func,
        const std::unordered_map<const BufferLoadNode *,
                                 std::vector<IndexSignState>> &states) {
    arith::Analyzer analyzer;
    NegativeIndexRewriter rewriter(&analyzer, states);
    if (!func->body.defined()) {
      return func;
    }
    PrimFuncNode *mutable_func = func.CopyOnWrite();
    mutable_func->body = rewriter.VisitStmt(mutable_func->body);
    return func;
  }

private:
  NegativeIndexRewriter(
      arith::Analyzer *analyzer,
      const std::unordered_map<const BufferLoadNode *,
                               std::vector<IndexSignState>> &states)
      : arith::IRMutatorWithAnalyzer(analyzer), states_(states) {}

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    // Mutate children first; the state map is keyed on the ORIGINAL node.
    BufferLoad updated =
        Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op));
    auto entry = states_.find(op);
    if (entry == states_.end()) {
      return updated;
    }
    const auto &signs = entry->second;
    auto new_indices = updated->indices;
    ICHECK_EQ(signs.size(), new_indices.size())
        << "State vector size mismatch for buffer load "
        << updated->buffer->name;
    bool rewritten = false;
    for (size_t axis = 0; axis < new_indices.size(); ++axis) {
      if (signs[axis] != IndexSignState::kNegative) {
        continue;
      }
      // index is provably < 0: legalize it as extent + index.
      PrimExpr dim = updated->buffer->shape[axis];
      new_indices.Set(axis, analyzer_->Simplify(dim + new_indices[axis]));
      rewritten = true;
    }
    if (rewritten) {
      return BufferLoad(updated->buffer, new_indices);
    }
    return updated;
  }

  // Per-load sign classification, owned by the caller of Apply().
  const std::unordered_map<const BufferLoadNode *,
                           std::vector<IndexSignState>> &states_;
};
/*!
 * \brief Legalize provably-negative BufferLoad indices in a PrimFunc.
 *
 * Two phases: first a visitor classifies the sign of every load index,
 * then a mutator rewrites each provably-negative axis to extent + index.
 *
 * \param func The function to transform.
 * \return The transformed function; returned unchanged when the body is
 *         undefined or no negative/unknown index was found.
 */
PrimFunc LegalizeNegativeIndex(PrimFunc func) {
  if (!func->body.defined()) {
    return func;
  }
  // Phase 1: classify the sign of every buffer-load index.
  std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
      sign_info;
  NegativeIndexAnalyzer collector(&sign_info);
  collector(func->body);
  // Fast path: nothing negative or unknown was recorded.
  if (sign_info.empty()) {
    return func;
  }
  // Phase 2: fold buffer extents into the provably-negative indices.
  return NegativeIndexRewriter::Apply(std::move(func), sign_info);
}
/*!
 * \brief Build the tl.LegalizeNegativeIndex PrimFunc pass.
 * \return A pass that runs LegalizeNegativeIndex on every PrimFunc.
 */
tvm::transform::Pass LegalizeNegativeIndexPass() {
  using namespace tir::transform;
  // Per-function transform: module and pass context are unused.
  auto transform_fn = [](PrimFunc f, const IRModule &, PassContext) {
    return LegalizeNegativeIndex(std::move(f));
  };
  return CreatePrimFuncPass(transform_fn, 0, "tl.LegalizeNegativeIndex", {});
}
// Register the pass with the FFI so it is callable from the frontend as
// "tl.transform.LegalizeNegativeIndex".
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.LegalizeNegativeIndex",
                        LegalizeNegativeIndexPass);
}
}
// namespace tl
}
// namespace tvm
src/transform/legalize_safe_memory_access.cc
View file @
bbbf4207
...
@@ -38,7 +38,7 @@ private:
...
@@ -38,7 +38,7 @@ private:
StmtVisitor
::
VisitStmt
(
op
->
body
);
StmtVisitor
::
VisitStmt
(
op
->
body
);
if
(
!
has_child_for_
)
{
if
(
!
has_child_for_
)
{
leaf_for_nodes
.
push_back
(
GetRef
<
For
>
(
op
));
leaf_for_nodes
.
push_back
(
tvm
::
ffi
::
GetRef
<
For
>
(
op
));
}
}
parent_has_child_for_
=
parent_has_child_for
;
parent_has_child_for_
=
parent_has_child_for
;
...
@@ -378,11 +378,11 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() {
...
@@ -378,11 +378,11 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() {
}
}
// Register the pass globally so it can be used in the compilation pipeline
// Register the pass globally so it can be used in the compilation pipeline
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.LegalizeSafeMemoryAccess"
,
refl
::
GlobalDef
().
def
(
"tl.transform.LegalizeSafeMemoryAccess"
,
LegalizeSafeMemoryAccess
);
LegalizeSafeMemoryAccess
);
}
);
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
src/transform/legalize_vectorized_loop.cc
View file @
bbbf4207
...
@@ -89,11 +89,11 @@ tvm::transform::Pass LegalizeVectorizedLoop() {
...
@@ -89,11 +89,11 @@ tvm::transform::Pass LegalizeVectorizedLoop() {
}
}
// Register the pass globally so it can be used in the compilation pipeline
// Register the pass globally so it can be used in the compilation pipeline
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.LegalizeVectorizedLoop"
,
refl
::
GlobalDef
().
def
(
"tl.transform.LegalizeVectorizedLoop"
,
LegalizeVectorizedLoop
);
LegalizeVectorizedLoop
);
}
);
}
}
// namespace tl
}
// namespace tl
}
// namespace tvm
}
// namespace tvm
Prev
1
…
4
5
6
7
8
9
10
11
12
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment