OpenDAS / tilelang / Commits

Commit bbbf4207 (unverified)
Authored Nov 14, 2025 by guchaoyang; committed by GitHub on Nov 14, 2025
Parents: 8f4628e0, 5eb30a4f

    Merge branch 'main' into dcu

Changes: 286 files in total. This page shows 20 changed files, with 758 additions and 161 deletions (+758, -161).
src/transform/atomicadd_vectorize.cc                    +1    -2
src/transform/cluster_planning.cc                       +7    -4
src/transform/common/loop_parallel_transform_utils.h    +2    -2
src/transform/common/loop_vectorization_utils.h         +28   -28
src/transform/config_index_bitwidth.cc                  +7    -7
src/transform/eliminate_storage_sync_for_mbarrier.cc    +4    -6
src/transform/flatten_buffer.cc                         +7    -7
src/transform/frontend_legalize.cc                      +2    -2
src/transform/if_stmt_binding.cc                        +3    -3
src/transform/inject_assumes.cc                         +2    -2
src/transform/inject_fence_proxy.cc                     +4    -3
src/transform/inject_pipeline.cc                        +52   -16
src/transform/inject_ptx_async_copy.cc                  +2    -2
src/transform/inject_tma_barrier.cc                     +100  -27
src/transform/layout_inference.cc                       +264  -28
src/transform/layout_reducer.cc                         +31   -13
src/transform/layout_reducer.h                          +4    -4
src/transform/legalize_negative_index.cc (new file)     +233  -0
src/transform/legalize_safe_memory_access.cc            +3    -3
src/transform/legalize_vectorized_loop.cc               +2    -2
src/transform/atomicadd_vectorize.cc

@@ -249,7 +249,6 @@ private:
       new_args.push_back(dst_node);
       new_args.push_back(value_node);
     }
     new_args.push_back(memory_order);
     Call new_call = ...
src/transform/cluster_planning.cc

@@ -10,6 +10,8 @@
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>

+#include "../support/ffi_aliases.h"
+
 namespace tvm {
 namespace tir {

@@ -66,7 +68,8 @@ public:
     }
     if (mem_reuse_max > 0) {
-      std::string tag_str = cluster_tag; // Convert to std::string
+      std::string tag_str =
+          static_cast<std::string>(cluster_tag); // Convert to std::string
       if (tag_str.rfind("blockIdx", 0) == 0) { // starts with "blockIdx"
         tag_str = "clusterIdx" + tag_str.substr(strlen("blockIdx"));

@@ -74,7 +77,7 @@ public:
         // Unexpected format — maybe just prefix
         tag_str = "clusterIdx" + tag_str;
       }
-      cluster_tag = tvm::ffi::String(tag_str); // Convert back
+      cluster_tag = String(tag_str); // Convert back
       return WithAttr(f, cluster_tag, Integer(cluster_size_));
     } else {
       return f;

@@ -122,10 +125,10 @@ tvm::transform::Pass ClusterPlanning() {
   return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {});
 }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.ClusterPlanning", ClusterPlanning);
-});
+}

 } // namespace transform
 } // namespace tir
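A note on the tag rewrite above: the hunk relies on the pre-C++20 "starts_with" idiom, `str.rfind(prefix, 0) == 0`, to map a launch-dimension tag such as "blockIdx.x" onto "clusterIdx.x". The following is a minimal standalone sketch of that string manipulation (plain C++, not tilelang code; RewriteTag is an illustrative name):

#include <cstring>
#include <iostream>
#include <string>

std::string RewriteTag(std::string tag_str) {
  if (tag_str.rfind("blockIdx", 0) == 0) { // starts with "blockIdx"
    // Swap the prefix, keeping the ".x"/".y"/".z" suffix intact.
    tag_str = "clusterIdx" + tag_str.substr(strlen("blockIdx"));
  } else {
    // Unexpected format: just prepend, mirroring the fallback branch.
    tag_str = "clusterIdx" + tag_str;
  }
  return tag_str;
}

int main() {
  std::cout << RewriteTag("blockIdx.x") << "\n"; // prints clusterIdx.x
}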
src/transform/common/loop_parallel_transform_utils.h

@@ -41,7 +41,7 @@ public:
       return StmtMutator::VisitStmt_(op);
     // Collect loop variables and ranges
-    auto for_node = GetRef<For>(op);
+    auto for_node = tvm::ffi::GetRef<For>(op);
     Array<Var> loop_vars;
     Array<PrimExpr> loop_extents;
     Stmt body = op->body;

@@ -81,7 +81,7 @@ public:
     // post order visit the index
     PostOrderVisit(index, [&](const ObjectRef &obj) {
       if (const VarNode *v = obj.as<VarNode>()) {
-        used_vars.insert(GetRef<Var>(v));
+        used_vars.insert(tvm::ffi::GetRef<Var>(v));
       }
     });
     if (used_vars.empty()) {
src/transform/common/loop_vectorization_utils.h

@@ -211,7 +211,7 @@ public:
     PrimExpr a = this->VisitExpr(op->a);
     PrimExpr b = this->VisitExpr(op->b);
     if (a.same_as(op->a) && b.same_as(op->b)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
       bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();

@@ -265,7 +265,7 @@ public:
   PrimExpr VisitExpr_(const NotNode *op) final {
     PrimExpr a = this->VisitExpr(op->a);
     if (a.same_as(op->a)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       return !(a);
     }

@@ -306,10 +306,10 @@ public:
     PrimExpr value = this->VisitExpr(op->value);
     if (value.dtype().is_scalable_or_fixed_length_vector()) {
       need_scalarize_ = true;
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     }
     if (value.same_as(op->value)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       return Broadcast(op->value, op->lanes);
     }

@@ -321,7 +321,7 @@ public:
     PrimExpr f = this->VisitExpr(op->false_value);
     if (cond.same_as(op->condition) && t.same_as(op->true_value) &&
         f.same_as(op->false_value)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
       int t_lanes = t.dtype().get_lanes_or_vscale_factor();

@@ -339,7 +339,7 @@ public:
   PrimExpr VisitExpr_(const CastNode *op) final {
     PrimExpr value = this->VisitExpr(op->value);
     if (value.same_as(op->value)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       if (value.dtype().is_scalable_vector()) {
         return Cast(op->dtype.with_scalable_vscale_factor(

@@ -352,20 +352,20 @@ public:
   }
   PrimExpr VisitExpr_(const FloatImmNode *op) final {
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
   PrimExpr VisitExpr_(const IntImmNode *op) final {
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
   PrimExpr VisitExpr_(const StringImmNode *op) final {
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
   // Variable
   PrimExpr VisitExpr_(const VarNode *op) final {
-    Var var = GetRef<Var>(op);
+    Var var = tvm::ffi::GetRef<Var>(op);
     if (var.same_as(var_)) {
       return ramp_;

@@ -382,13 +382,13 @@ public:
     PrimExpr cond = this->VisitExpr(op->args[0]);
     if (cond.dtype().is_scalable_or_fixed_length_vector()) {
       need_scalarize_ = true;
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     }
     PrimExpr t = this->VisitExpr(op->args[1]);
     PrimExpr f = this->VisitExpr(op->args[2]);
     if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) &&
         f.same_as(op->args[2])) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       int t_lanes = t.dtype().get_lanes_or_vscale_factor();
       int f_lanes = f.dtype().get_lanes_or_vscale_factor();

@@ -410,7 +410,7 @@ public:
     ICHECK(op->op.same_as(builtin::reinterpret()));
     PrimExpr value = this->VisitExpr(op->args[0]);
     if (value.same_as(op->args[0])) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       int lanes = value.dtype().get_lanes_or_vscale_factor();
       if (value.dtype().is_scalable_vector()) {

@@ -455,12 +455,12 @@ public:
       auto new_arg = this->VisitExpr(arg);
       if (new_arg.dtype().is_scalable_or_fixed_length_vector()) {
         need_scalarize_ = true;
-        return GetRef<PrimExpr>(op);
+        return tvm::ffi::GetRef<PrimExpr>(op);
       }
       new_args.push_back(new_arg);
     }
     if (op->args.same_as(new_args)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       return Call(op->dtype, op->op, new_args);
     }

@@ -469,7 +469,7 @@ public:
     Array<PrimExpr> new_args = MutateArray(op->args, &lane);
     // normal code path.
     if (op->args.same_as(new_args)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       return Call(op->dtype.with_lanes(lane), op->op, new_args);
     }

@@ -477,7 +477,7 @@ public:
   }
   // BufferLoad
   PrimExpr VisitExpr_(const BufferLoadNode *op) final {
-    auto load = GetRef<BufferLoad>(op);
+    auto load = tvm::ffi::GetRef<BufferLoad>(op);
     auto fmutate = [this](const PrimExpr &index) {
       return this->VisitExpr(index);

@@ -514,7 +514,7 @@ public:
     let_binding_[op->var] = op->var;
     PrimExpr body = this->VisitExpr(op->body);
     if (value.same_as(op->value) && body.same_as(op->body)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       return Let(op->var, value, body);
     }

@@ -522,7 +522,7 @@ public:
   }
   // BufferStore
   Stmt VisitStmt_(const BufferStoreNode *op) final {
-    auto store = GetRef<BufferStore>(op);
+    auto store = tvm::ffi::GetRef<BufferStore>(op);
     auto fmutate = [this](const PrimExpr &index) {
       return this->VisitExpr(index);

@@ -585,11 +585,11 @@ public:
     ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector());
     PrimExpr extent = this->VisitExpr(op->extent);
     if (extent.dtype().is_scalable_or_fixed_length_vector()) {
-      return Scalarize(GetRef<Stmt>(op));
+      return Scalarize(tvm::ffi::GetRef<Stmt>(op));
     }
     Stmt body = this->VisitStmt(op->body);
     if (extent.same_as(op->extent) && body.same_as(op->body)) {
-      return GetRef<Stmt>(op);
+      return tvm::ffi::GetRef<Stmt>(op);
     } else {
       return For(op->loop_var, op->min, extent, op->kind, body,
                  op->thread_binding, op->annotations);

@@ -600,7 +600,7 @@ public:
     ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
     PrimExpr condition = this->VisitExpr(op->condition);
     if (condition.dtype().is_scalable_or_fixed_length_vector()) {
-      return Scalarize(GetRef<Stmt>(op));
+      return Scalarize(tvm::ffi::GetRef<Stmt>(op));
     }
     Stmt then_case = this->VisitStmt(op->then_case);
     Optional<Stmt> else_case = std::nullopt;

@@ -609,7 +609,7 @@ public:
     }
     if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
         else_case.same_as(op->else_case)) {
-      return GetRef<Stmt>(op);
+      return tvm::ffi::GetRef<Stmt>(op);
     } else {
       return IfThenElse(condition, then_case, else_case);
     }

@@ -634,7 +634,7 @@ public:
     let_binding_[op->var] = op->var;
     Stmt body = this->VisitStmt(op->body);
     if (value.same_as(op->value) && body.same_as(op->body)) {
-      return GetRef<Stmt>(op);
+      return tvm::ffi::GetRef<Stmt>(op);
     } else {
       return LetStmt(op->var, value, body);
     }

@@ -647,7 +647,7 @@ public:
     if (condition.dtype().is_scalable_or_fixed_length_vector()) {
       LOG(WARNING) << "Cannot handle vector extent in alloc of "
                    << op->buffer_var->name_hint;
-      return Scalarize(GetRef<Stmt>(op));
+      return Scalarize(tvm::ffi::GetRef<Stmt>(op));
     }
     // Mutate the extents

@@ -657,7 +657,7 @@ public:
       if (new_ext.dtype().is_scalable_or_fixed_length_vector()) {
         LOG(WARNING) << "Cannot handle vector extent in alloc of "
                      << op->buffer_var->name_hint;
-        return Scalarize(GetRef<Stmt>(op));
+        return Scalarize(tvm::ffi::GetRef<Stmt>(op));
       }
       extents.push_back(new_ext);
     }

@@ -738,7 +738,7 @@ private:
     PrimExpr a = this->VisitExpr(op->a);
     PrimExpr b = this->VisitExpr(op->b);
     if (a.same_as(op->a) && b.same_as(op->b)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       int a_lanes = a.dtype().get_lanes_or_vscale_factor();
       int b_lanes = b.dtype().get_lanes_or_vscale_factor();

@@ -754,7 +754,7 @@ private:
     PrimExpr a = this->VisitExpr(op->a);
     PrimExpr b = this->VisitExpr(op->b);
     if (a.same_as(op->a) && b.same_as(op->b)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       int a_lanes = a.dtype().get_lanes_or_vscale_factor();
       int b_lanes = b.dtype().get_lanes_or_vscale_factor();
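All 28 edits in this header are the same mechanical change: the unqualified GetRef is replaced by tvm::ffi::GetRef inside the copy-on-write visitor idiom, where a mutator returns the original node whenever no child changed so unchanged subtrees keep sharing memory. Below is a minimal standalone sketch of that idiom in plain C++ (illustrative names, no TVM dependency), where returning `e` unchanged plays the role of `return tvm::ffi::GetRef<PrimExpr>(op);`:

#include <iostream>
#include <memory>

struct Expr {
  virtual ~Expr() = default;
};
struct IntImm : Expr {
  int value;
  explicit IntImm(int v) : value(v) {}
};
struct Add : Expr {
  std::shared_ptr<Expr> a, b;
  Add(std::shared_ptr<Expr> a, std::shared_ptr<Expr> b)
      : a(std::move(a)), b(std::move(b)) {}
};

// Only allocate a new node when a child actually changed; otherwise hand
// back the original node, preserving sharing (the same_as checks above).
std::shared_ptr<Expr> Mutate(const std::shared_ptr<Expr> &e) {
  if (auto add = std::dynamic_pointer_cast<Add>(e)) {
    auto a = Mutate(add->a);
    auto b = Mutate(add->b);
    if (a == add->a && b == add->b) {
      return e; // both children unchanged: reuse the original node
    }
    return std::make_shared<Add>(a, b);
  }
  return e; // leaves are immutable, always reused
}

int main() {
  auto tree = std::make_shared<Add>(std::make_shared<IntImm>(1),
                                    std::make_shared<IntImm>(2));
  auto out = Mutate(tree);
  std::cout << std::boolalpha << (out == tree) << "\n"; // true: no copy made
}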
src/transform/config_index_bitwidth.cc

@@ -38,7 +38,7 @@ protected:
     if (is_enabled_ && op->dtype.is_int() && op->dtype.bits() < 64) {
       return IntImm(DataType::Int(_index_bitwidth_), op->value);
     }
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
   PrimExpr VisitExpr_(const CastNode *op) final {

@@ -88,23 +88,23 @@ private:
   PrimExpr VisitExpr_(const VarNode *op) final {
     if (op->dtype.is_int() && op->dtype.bits() < 64) {
-      return cast(DataType::Int(64), GetRef<Var>(op));
+      return cast(DataType::Int(64), tvm::ffi::GetRef<Var>(op));
     }
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
   PrimExpr VisitExpr_(const IntImmNode *op) final {
     if (op->dtype.is_int() && op->dtype.bits() < 64) {
       return IntImm(DataType::Int(64), op->value);
     }
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
   PrimExpr VisitExpr_(const CastNode *op) final {
     if (op->dtype.is_int() && op->dtype.bits() < 64) {
       return cast(DataType::Int(64), op->value);
     }
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
   Stmt VisitStmt_(const BufferStoreNode *op) final {

@@ -183,11 +183,11 @@ tvm::transform::Pass ConfigIndexBitwidth() {
   return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {});
 }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.ConfigIndexBitwidth",
                         ConfigIndexBitwidth);
-});
+}

 } // namespace tl
 } // namespace tvm
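The visitors in this pass widen sub-64-bit integer index expressions to Int(64). The motivation is that a flattened index like row * stride + col can exceed the 32-bit range on large buffers. A standalone illustration (plain C++, not part of the pass; unsigned arithmetic is used so the wraparound is well defined):

#include <cstdint>
#include <iostream>

int main() {
  uint32_t row = 70000, stride = 70000, col = 5;
  // 32-bit arithmetic wraps modulo 2^32 and silently corrupts the index.
  uint32_t narrow = row * stride + col;                     // 605032709
  // Widening one operand first, as the Int(64) casts do, keeps it exact.
  int64_t wide = static_cast<int64_t>(row) * stride + col;  // 4900000005
  std::cout << narrow << " vs " << wide << "\n";
}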
src/transform/eliminate_storage_sync_for_mbarrier.cc

@@ -35,9 +35,7 @@ public:
   Stmt VisitStmt_(const AttrStmtNode *op) final {
     if (op->attr_key == "thread_extent") {
-      const VarNode *var = nullptr;
-      if (op->node->IsInstance<VarNode>()) {
-        var = op->node.as<VarNode>();
+      if (const auto *var = op->node.as<VarNode>()) {
         if (var->name_hint == "threadIdx.x") {
           thread_extent_ = op;
         }

@@ -82,7 +80,7 @@ public:
   Stmt VisitStmt_(const ForNode *op) final {
-    PostOrderVisit(GetRef<For>(op), [&](const ObjectRef &node) {
+    PostOrderVisit(tvm::ffi::GetRef<For>(op), [&](const ObjectRef &node) {
       if (const auto *call = node.as<CallNode>()) {
         if (call->op.same_as(create_list_of_mbarrier()) ||
             call->op.same_as(mbarrier_wait_parity()) ||

@@ -116,11 +114,11 @@ tvm::transform::Pass EliminateStorageSyncForMBarrier() {
   {});
 }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.EliminateStorageSyncForMBarrier",
                         EliminateStorageSyncForMBarrier);
-});
+}

 } // namespace transform
 } // namespace tl
src/transform/flatten_buffer.cc

@@ -75,23 +75,23 @@ private:
   PrimExpr VisitExpr_(const VarNode *op) final {
     if (op->dtype.is_int() && op->dtype.bits() < 64) {
-      return cast(DataType::Int(64), GetRef<Var>(op));
+      return cast(DataType::Int(64), tvm::ffi::GetRef<Var>(op));
     }
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
   PrimExpr VisitExpr_(const IntImmNode *op) final {
     if (op->dtype.is_int() && op->dtype.bits() < 64) {
       return IntImm(DataType::Int(64), op->value);
     }
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
   PrimExpr VisitExpr_(const CastNode *op) final {
     if (op->dtype.is_int() && op->dtype.bits() < 64) {
       return cast(DataType::Int(64), op->value);
     }
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
   Stmt VisitStmt_(const BufferStoreNode *op) final {

@@ -115,7 +115,7 @@ private:
         << "All MatchBufferRegion should be removed in "
            "tir.transform.LowerMatchBuffer.";
-    Block block = GetRef<Block>(op);
+    Block block = tvm::ffi::GetRef<Block>(op);
     Array<Buffer> alloc_buffers = op->alloc_buffers;
     alloc_buffers.MutateByApply(

@@ -385,10 +385,10 @@ tvm::transform::Pass FlattenBuffer() {
   return CreatePrimFuncPass(pass_func, 0, "tl.FlattenBuffer", {});
 }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.FlattenBuffer", FlattenBuffer);
-});
+}

 } // namespace tl
 } // namespace tvm
src/transform/frontend_legalize.cc

@@ -89,10 +89,10 @@ Pass LetInline() {
   return CreatePrimFuncPass(pass_func, 0, "tl.LetInline", {});
 }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.LetInline", LetInline);
-});
+}

 } // namespace tl
 } // namespace tvm
src/transform/if_stmt_binding.cc

@@ -33,7 +33,7 @@ private:
     auto then_case = VisitStmt(op->then_case);
     Optional<Stmt> else_case = op->else_case;
     if (else_case.defined()) {
-      return GetRef<Stmt>(op);
+      return tvm::ffi::GetRef<Stmt>(op);
     }
     ICHECK(then_case.defined()) << "then_case must be defined";
     ICHECK(!else_case.defined()) << "else_case must be undefined";

@@ -81,10 +81,10 @@ tvm::transform::Pass IfStmtBinding() {
   return CreatePrimFuncPass(pass_func, 0, "tl.IfStmtBinding", {});
 }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.IfStmtBinding", IfStmtBinding);
-});
+}

 } // namespace tl
 } // namespace tvm
src/transform/inject_assumes.cc

@@ -156,9 +156,9 @@ tvm::transform::Pass InjectAssumes() {
   return CreatePrimFuncPass(pass_func, 0, "tl.InjectAssumes", {});
 }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.InjectAssumes", InjectAssumes);
-});
+}

 } // namespace tvm::tl
src/transform/inject_fence_proxy.cc

@@ -108,7 +108,8 @@ bool IsKnownGeneric(const CallNode *call) {
     return false;
   }
   return call->op.same_as(ptx_ldmatrix()) ||
          call->op.same_as(ptx_stmatrix()) ||
-         call->op.same_as(initialize_descriptor());
+         call->op.same_as(initialize_wgmma_descriptor()) ||
+         call->op.same_as(initialize_tcgen05_descriptor());
 }

 ProxyKind ProxyFromAttrValue(const ObjectRef &value) {

@@ -319,10 +320,10 @@ tvm::transform::Pass InjectFenceProxy() {
   {});
 }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.InjectFenceProxy", InjectFenceProxy);
-});
+}

 } // namespace tl
 } // namespace tvm
src/transform/inject_pipeline.cc

@@ -37,9 +37,14 @@
 namespace tvm {
 namespace tl {

 using namespace tir;
+using namespace ffi;
+
+struct LetWrapper {
+  Var var;
+  PrimExpr value;
+};

 /*!
  * \brief Create a block and infer the access region with the given body.
  *

@@ -233,10 +238,12 @@ class PipelineRewriter : public StmtExprMutator {
 public:
   PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
                    const Array<Buffer> &pipeline_allocs,
-                   const For &pipeline_loop, const PipelineInfo &pipeline_info)
+                   const For &pipeline_loop, const PipelineInfo &pipeline_info,
+                   const std::vector<LetWrapper> &loop_var_let_wrappers)
       : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
         pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
-        pipeline_info_(pipeline_info) {}
+        pipeline_info_(pipeline_info),
+        loop_var_let_wrappers_(loop_var_let_wrappers) {}

   Stmt BuildPipeline() {
     // Step 1: Analyze accesses to the buffers in the pipeline and compute the

@@ -459,7 +466,8 @@ private:
    * \return The resized buffer.
    */
   Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
-    ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
+    ObjectPtr<BufferNode> new_buffer =
+        tvm::ffi::make_object<BufferNode>(*(buffer.get()));
     new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
     if (!new_buffer->strides.empty()) {
       ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());

@@ -676,6 +684,20 @@ private:
     new_block = Downcast<Block>(Substitute(
         new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
+    // If there were Let-wrappers outside the original pipeline body that
+    // depended on the pipeline loop var, push them into each rewritten
+    // block with the correct per-block substitution.
+    if (!loop_var_let_wrappers_.empty()) {
+      BlockNode *n = new_block.CopyOnWrite();
+      Stmt inner = n->body;
+      for (const auto &lw : loop_var_let_wrappers_) {
+        PrimExpr substituted = Substitute(
+            lw.value, {{pipeline_loop_->loop_var, normalized_access_index}});
+        inner = LetStmt(lw.var, substituted, inner);
+      }
+      n->body = inner;
+    }
     if (pipeline_info_[block].async) {
       auto &local_state = async_states_local[stage];
       local_state.producer_head = normalized_access_index;

@@ -737,6 +759,7 @@ private:
   Map<Buffer, Buffer> buffer_remap_;
   Array<Block> ordered_stmts_;
   std::map<int, AsyncStateGlobal> async_states;
+  std::vector<LetWrapper> loop_var_let_wrappers_;
 };

 /*!

@@ -864,8 +887,9 @@ private:
     const SeqStmtNode *pipeline_body_seq = nullptr;
     std::vector<std::function<Stmt(Stmt)>> rewrap_fns;
+    std::vector<LetWrapper> loop_var_let_wrappers;
     auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) {
-      ObjectRef node = attr->node;
+      Any node = attr->node;
       String attr_key = attr->attr_key;
       PrimExpr value = attr->value;
       Span span = attr->span;

@@ -896,6 +920,16 @@ private:
         continue;
       }
       if (const auto *let_stmt = current.as<LetStmtNode>()) {
+        // If this Let value uses the pipeline loop var, record it and push
+        // inside each rewritten block later so the loop var can be
+        // substituted with the correct per-iteration index. Otherwise, keep
+        // it as a normal wrapper.
+        bool uses_loop_var = UsesVar(
+            let_stmt->value,
+            [v = op->loop_var.get()](const VarNode *vn) { return vn == v; });
+        if (uses_loop_var) {
+          loop_var_let_wrappers.push_back({let_stmt->var, let_stmt->value});
+        } else {
           Var var = let_stmt->var;
           PrimExpr value = let_stmt->value;
           Span span = let_stmt->span;

@@ -904,6 +938,7 @@ private:
                 span](Stmt body) -> Stmt {
             return LetStmt(var, value, body, span);
           });
+        }
         current = let_stmt->body;
         continue;
       }

@@ -981,7 +1016,8 @@ private:
     // Step 4: Rewrite the pipeline body.
     Stmt pipeline =
         PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
-                         GetRef<For>(op), pipeline_info)
+                         tvm::ffi::GetRef<For>(op), pipeline_info,
+                         loop_var_let_wrappers)
             .BuildPipeline();
     auto apply_wrappers = [&](Stmt stmt) {
       for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) {

@@ -1072,11 +1108,11 @@ tir::transform::Pass InjectSoftwarePipeline() {
   return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {});
 }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.InjectSoftwarePipeline",
                         InjectSoftwarePipeline);
-});
+}

 } // namespace tl
 } // namespace tvm
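The LetWrapper plumbing above addresses one specific situation: a binding like `let t = f(i)` hoisted out of a software-pipelined loop still references the loop variable `i`, so it cannot simply be re-wrapped around the whole pipeline; it must be re-materialized inside each rewritten block with `i` replaced by that block's normalized access index. A standalone sketch of the idea (plain C++, illustrative names; a std::function stands in for a PrimExpr over the loop var):

#include <functional>
#include <iostream>
#include <vector>

struct LetWrapper {
  const char *var;
  std::function<int(int)> value; // the hoisted binding's right-hand side
};

int main() {
  std::vector<LetWrapper> wrappers = {{"t", [](int i) { return 2 * i + 1; }}};
  // Each pipelined copy of the block substitutes its own index into the
  // hoisted binding, mirroring Substitute(lw.value, {{loop_var, index}}).
  for (int normalized_index : {0, 1, 2}) {
    for (const auto &lw : wrappers) {
      std::cout << "block " << normalized_index << ": " << lw.var << " = "
                << lw.value(normalized_index) << "\n";
    }
  }
}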
src/transform/inject_ptx_async_copy.cc

@@ -232,10 +232,10 @@ tvm::transform::Pass InjectPTXAsyncCopy() {
   return CreatePrimFuncPass(pass_func, 0, "tl.InjectPTXAsyncCopy", {});
 }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.InjectPTXAsyncCopy",
                         InjectPTXAsyncCopy);
-});
+}

 } // namespace tl
 } // namespace tvm
src/transform/inject_tma_barrier.cc

@@ -204,9 +204,9 @@ private:
   void VisitStmt_(const EvaluateNode *op) final {
     if (const auto *call = op->value.as<CallNode>()) {
       if (call->op.same_as(tma_load()) ||
           call->op.same_as(tma_load_im2col())) {
-        pending_tma_ops_.push_back(GetRef<Call>(call));
+        pending_tma_ops_.push_back(tvm::ffi::GetRef<Call>(call));
       } else if (call->op.same_as(mbarrier_expect_tx())) {
-        pending_tma_ops_.push_back(GetRef<Call>(call));
+        pending_tma_ops_.push_back(tvm::ffi::GetRef<Call>(call));
       } else if (call->op.same_as(builtin::ptx_arrive_barrier())) {
         PrimExpr barrier_id = call->args[0];
         for (const auto &tma_call : pending_tma_ops_) {

@@ -295,13 +295,15 @@ public:
   void VisitExpr_(const CallNode *op) final {
     if (op->op.same_as(mbarrier_expect_tx())) {
-      PrimExpr e =
-          tma_op_to_barrier_id_[GetRef<Call>(op)].as<CallNode>()->args[0];
-      auto int_set = arith::EvalSet(e, var_int_set_);
-      expect_.push_back(if_depth_ == 1);
-      sequence.push_back(0);
-      int_sets_.push_back(int_set);
-      expect_tx_count_ += 1;
+      auto call_ref = tvm::ffi::GetRef<Call>(op);
+      if (tma_op_to_barrier_id_.count(call_ref)) {
+        PrimExpr e = tma_op_to_barrier_id_[call_ref].as<CallNode>()->args[0];
+        auto int_set = arith::EvalSet(e, var_int_set_);
+        expect_.push_back(if_depth_ == 1);
+        sequence.push_back(0);
+        int_sets_.push_back(int_set);
+        expect_tx_count_ += 1;
+      }
     } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
       sequence.push_back(1);
     } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {

@@ -336,32 +338,61 @@ public:
 class BarrierCreationRewriter : public StmtExprMutator {
 public:
   BarrierCreationRewriter(std::vector<int> restore_barrier_ids,
-                          PrimExpr producer_thread_extent)
+                          PrimExpr producer_thread_extent,
+                          int ensure_min_count = 0,
+                          PrimExpr default_barrier_thread_count = 1)
       : restore_barrier_ids_(std::move(restore_barrier_ids)),
-        producer_thread_extent_(std::move(producer_thread_extent)) {}
+        producer_thread_extent_(std::move(producer_thread_extent)),
+        ensure_min_count_(ensure_min_count),
+        default_barrier_thread_count_(
+            std::move(default_barrier_thread_count)) {}

   PrimExpr VisitExpr_(const CallNode *op) {
     if (op->op.same_as(create_list_of_mbarrier())) {
-      std::vector<bool> tmp_(op->args.size(), false);
-      Array<PrimExpr> new_args;
+      size_t cur_n = op->args.size();
+      size_t need_n =
+          std::max<size_t>(cur_n, static_cast<size_t>(ensure_min_count_));
+      // Mark barriers to restore across the full needed length, not just the
+      // original length, so newly appended entries can be restored as well.
+      std::vector<bool> replace(need_n, false);
       for (auto &id : restore_barrier_ids_) {
-        tmp_[id] = true;
+        if (id >= 0 && static_cast<size_t>(id) < replace.size()) {
+          replace[id] = true;
+        }
       }
-      for (size_t i{0}; i < op->args.size(); ++i) {
-        if (tmp_[i]) {
+      Array<PrimExpr> new_args;
+      new_args.reserve(need_n);
+      // Preserve/override existing entries
+      for (size_t i{0}; i < cur_n; ++i) {
+        if (replace[i]) {
           new_args.push_back(producer_thread_extent_);
         } else {
           new_args.push_back(op->args[i]);
         }
       }
+      // Append additional barriers if required
+      for (size_t i = cur_n; i < need_n; ++i) {
+        if (replace[i]) {
+          new_args.push_back(producer_thread_extent_);
+        } else {
+          new_args.push_back(default_barrier_thread_count_);
+        }
+      }
       return Call(op->dtype, op->op, new_args);
     } else {
       return StmtExprMutator::VisitExpr_(op);
     }
   }

 private:
   std::vector<int> restore_barrier_ids_;
   PrimExpr producer_thread_extent_;
+  int ensure_min_count_{0};
+  PrimExpr default_barrier_thread_count_{1};
 };

 // we trust mbarrier_wait_parity to be correct

@@ -398,15 +429,38 @@ public:
                       collector.barrier_id_to_range(),
                       has_create_list_of_mbarrier);
     f.CopyOnWrite()->body = rewriter(f->body);
+    // Compute the minimum number of barriers actually referenced in the body
+    // after TMA barrier rewrites (e.g., get_mbarrier(0) inserted for TMA).
+    struct GetMbarrierMaxIdxCollector : public StmtExprVisitor {
+      int max_idx{-1};
+      void VisitExpr_(const CallNode *op) final {
+        if (op->op.same_as(get_mbarrier())) {
+          if (op->args.size() == 1) {
+            if (const auto *imm = op->args[0].as<IntImmNode>()) {
+              max_idx = std::max(max_idx, static_cast<int>(imm->value));
+            }
+          }
+        }
+        StmtExprVisitor::VisitExpr_(op);
+      }
+    };
+    GetMbarrierMaxIdxCollector max_idx_collector;
+    max_idx_collector(f->body);
+    int ensure_min_count = max_idx_collector.max_idx + 1; // 0-based -> count
+    // For simple TMA-only producers, default barrier arrive count should be 1
+    // (only the elected leader performs the TMA arrive/expect).
     auto barrier_creation_rewriter = BarrierCreationRewriter(
-        rewriter.restore_barrier_ids_, rewriter.producer_thread_extent_);
+        rewriter.restore_barrier_ids_, rewriter.producer_thread_extent_,
+        ensure_min_count, Integer(1));
     f.CopyOnWrite()->body = barrier_creation_rewriter(f->body);
     return f;
   }

 private:
   Stmt VisitStmt_(const BlockNode *op) {
-    auto block = GetRef<Block>(op);
+    auto block = tvm::ffi::GetRef<Block>(op);
     if (!has_create_list_of_mbarrier_ && !barrier_id_to_range_.empty() &&
         op->name_hint == MainBlockName) {
       ICHECK(false) << "Please declare create_list_of_mbarrier.";

@@ -452,10 +506,27 @@ private:
   PrimExpr VisitExpr_(const CallNode *op) {
     if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
-      // check this must be in the tma_op_to_barrier_id_
-      ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
-          << "tma_load must be in the tma_op_to_barrier_id_";
-      auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)];
+      auto call_ref = tvm::ffi::GetRef<Call>(op);
+      if (!tma_op_to_barrier_id_.count(call_ref)) {
+        // For 1D TMA loads, promote raw integer barrier id to get_mbarrier(id)
+        // so codegen can emit mbarrier[index]. This handles degenerate
+        // producer-only kernels where no arrive() is seen and mapping is
+        // empty.
+        auto arg0 = op->args[0].as<Call>();
+        bool is_1d_tma_load =
+            arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
+            !arg0.value()->op.same_as(create_tma_im2col_descriptor());
+        if (is_1d_tma_load && op->args.size() >= 3) {
+          if (const auto *imm = op->args[2].as<IntImmNode>()) {
+            Array<PrimExpr> new_args = op->args;
+            new_args.Set(2, Call(DataType::Handle(), get_mbarrier(),
+                                 {IntImm(DataType::Int(32),
+                                         static_cast<int>(imm->value))}));
+            return Call(op->dtype, op->op, new_args);
+          }
+        }
+        return IRMutatorWithAnalyzer::VisitExpr_(op);
+      }
+      auto barrier_id = tma_op_to_barrier_id_[call_ref];
       auto new_args = op->args;
       auto arg0 = op->args[0].as<Call>();
       auto is_1d_tma_load =

@@ -468,9 +539,11 @@ private:
       }
       return Call(op->dtype, op->op, new_args);
     } else if (op->op.same_as(mbarrier_expect_tx())) {
-      ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
-          << "mbarrier_expect_tx must be in the tma_op_to_barrier_id_";
-      auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)];
+      auto call_ref = tvm::ffi::GetRef<Call>(op);
+      if (!tma_op_to_barrier_id_.count(call_ref)) {
+        return IRMutatorWithAnalyzer::VisitExpr_(op);
+      }
+      auto barrier_id = tma_op_to_barrier_id_[call_ref];
       auto new_args = op->args;
       new_args.Set(0, barrier_id);
       if (!has_warp_specialization_)

@@ -522,10 +595,10 @@ tvm::transform::Pass InjectTmaBarrier() {
   return CreatePrimFuncPass(pass_func, 0, "tl.InjectTmaBarrier", {});
 }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.InjectTmaBarrier", InjectTmaBarrier);
-});
+}

 } // namespace tl
 } // namespace tvm
std
::
move
(
restore_barrier_ids
)),
producer_thread_extent_
(
std
::
move
(
producer_thread_extent
))
{}
producer_thread_extent_
(
std
::
move
(
producer_thread_extent
)),
ensure_min_count_
(
ensure_min_count
),
default_barrier_thread_count_
(
std
::
move
(
default_barrier_thread_count
))
{
}
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
{
if
(
op
->
op
.
same_as
(
create_list_of_mbarrier
()))
{
std
::
vector
<
bool
>
tmp_
(
op
->
args
.
size
(),
false
);
Array
<
PrimExpr
>
new_args
;
size_t
cur_n
=
op
->
args
.
size
();
size_t
need_n
=
std
::
max
<
size_t
>
(
cur_n
,
static_cast
<
size_t
>
(
ensure_min_count_
));
// Mark barriers to restore across the full needed length, not just the
// original length, so newly appended entries can be restored as well.
std
::
vector
<
bool
>
replace
(
need_n
,
false
);
for
(
auto
&
id
:
restore_barrier_ids_
)
{
tmp_
[
id
]
=
true
;
if
(
id
>=
0
&&
static_cast
<
size_t
>
(
id
)
<
replace
.
size
())
{
replace
[
id
]
=
true
;
}
}
for
(
size_t
i
{
0
};
i
<
op
->
args
.
size
();
++
i
)
{
if
(
tmp_
[
i
])
{
Array
<
PrimExpr
>
new_args
;
new_args
.
reserve
(
need_n
);
// Preserve/override existing entries
for
(
size_t
i
{
0
};
i
<
cur_n
;
++
i
)
{
if
(
replace
[
i
])
{
new_args
.
push_back
(
producer_thread_extent_
);
}
else
{
new_args
.
push_back
(
op
->
args
[
i
]);
}
}
// Append additional barriers if required
for
(
size_t
i
=
cur_n
;
i
<
need_n
;
++
i
)
{
if
(
replace
[
i
])
{
new_args
.
push_back
(
producer_thread_extent_
);
}
else
{
new_args
.
push_back
(
default_barrier_thread_count_
);
}
}
return
Call
(
op
->
dtype
,
op
->
op
,
new_args
);
}
else
{
return
StmtExprMutator
::
VisitExpr_
(
op
);
}
}
private:
std
::
vector
<
int
>
restore_barrier_ids_
;
PrimExpr
producer_thread_extent_
;
int
ensure_min_count_
{
0
};
PrimExpr
default_barrier_thread_count_
{
1
};
};
// we trust mbarrier_wait_parity to be correct
...
...
@@ -398,15 +429,38 @@ public:
collector
.
barrier_id_to_range
(),
has_create_list_of_mbarrier
);
f
.
CopyOnWrite
()
->
body
=
rewriter
(
f
->
body
);
// Compute the minimum number of barriers actually referenced in the body
// after TMA barrier rewrites (e.g., get_mbarrier(0) inserted for TMA).
struct
GetMbarrierMaxIdxCollector
:
public
StmtExprVisitor
{
int
max_idx
{
-
1
};
void
VisitExpr_
(
const
CallNode
*
op
)
final
{
if
(
op
->
op
.
same_as
(
get_mbarrier
()))
{
if
(
op
->
args
.
size
()
==
1
)
{
if
(
const
auto
*
imm
=
op
->
args
[
0
].
as
<
IntImmNode
>
())
{
max_idx
=
std
::
max
(
max_idx
,
static_cast
<
int
>
(
imm
->
value
));
}
}
}
StmtExprVisitor
::
VisitExpr_
(
op
);
}
};
GetMbarrierMaxIdxCollector
max_idx_collector
;
max_idx_collector
(
f
->
body
);
int
ensure_min_count
=
max_idx_collector
.
max_idx
+
1
;
// 0-based -> count
// For simple TMA-only producers, default barrier arrive count should be 1
// (only the elected leader performs the TMA arrive/expect).
auto
barrier_creation_rewriter
=
BarrierCreationRewriter
(
rewriter
.
restore_barrier_ids_
,
rewriter
.
producer_thread_extent_
);
rewriter
.
restore_barrier_ids_
,
rewriter
.
producer_thread_extent_
,
ensure_min_count
,
Integer
(
1
));
f
.
CopyOnWrite
()
->
body
=
barrier_creation_rewriter
(
f
->
body
);
return
f
;
}
private:
Stmt
VisitStmt_
(
const
BlockNode
*
op
)
{
auto
block
=
GetRef
<
Block
>
(
op
);
auto
block
=
tvm
::
ffi
::
GetRef
<
Block
>
(
op
);
if
(
!
has_create_list_of_mbarrier_
&&
!
barrier_id_to_range_
.
empty
()
&&
op
->
name_hint
==
MainBlockName
)
{
ICHECK
(
false
)
<<
"Please declare create_list_of_mbarrier."
;
...
...
@@ -452,10 +506,27 @@ private:
PrimExpr
VisitExpr_
(
const
CallNode
*
op
)
{
if
(
op
->
op
.
same_as
(
tma_load
())
||
op
->
op
.
same_as
(
tma_load_im2col
()))
{
// check this must be in the tma_op_to_barrier_id_
ICHECK
(
tma_op_to_barrier_id_
.
count
(
GetRef
<
Call
>
(
op
)))
<<
"tma_load must be in the tma_op_to_barrier_id_"
;
auto
barrier_id
=
tma_op_to_barrier_id_
[
GetRef
<
Call
>
(
op
)];
auto
call_ref
=
tvm
::
ffi
::
GetRef
<
Call
>
(
op
);
if
(
!
tma_op_to_barrier_id_
.
count
(
call_ref
))
{
// For 1D TMA loads, promote raw integer barrier id to get_mbarrier(id)
// so codegen can emit mbarrier[index]. This handles degenerate
// producer-only kernels where no arrive() is seen and mapping is empty.
auto
arg0
=
op
->
args
[
0
].
as
<
Call
>
();
bool
is_1d_tma_load
=
arg0
&&
!
arg0
.
value
()
->
op
.
same_as
(
create_tma_descriptor
())
&&
!
arg0
.
value
()
->
op
.
same_as
(
create_tma_im2col_descriptor
());
if
(
is_1d_tma_load
&&
op
->
args
.
size
()
>=
3
)
{
if
(
const
auto
*
imm
=
op
->
args
[
2
].
as
<
IntImmNode
>
())
{
Array
<
PrimExpr
>
new_args
=
op
->
args
;
new_args
.
Set
(
2
,
Call
(
DataType
::
Handle
(),
get_mbarrier
(),
{
IntImm
(
DataType
::
Int
(
32
),
static_cast
<
int
>
(
imm
->
value
))}));
return
Call
(
op
->
dtype
,
op
->
op
,
new_args
);
}
}
return
IRMutatorWithAnalyzer
::
VisitExpr_
(
op
);
}
auto
barrier_id
=
tma_op_to_barrier_id_
[
call_ref
];
auto
new_args
=
op
->
args
;
auto
arg0
=
op
->
args
[
0
].
as
<
Call
>
();
auto
is_1d_tma_load
=
...
...
@@ -468,9 +539,11 @@ private:
}
return
Call
(
op
->
dtype
,
op
->
op
,
new_args
);
}
else
if
(
op
->
op
.
same_as
(
mbarrier_expect_tx
()))
{
ICHECK
(
tma_op_to_barrier_id_
.
count
(
GetRef
<
Call
>
(
op
)))
<<
"mbarrier_expect_tx must be in the tma_op_to_barrier_id_"
;
auto
barrier_id
=
tma_op_to_barrier_id_
[
GetRef
<
Call
>
(
op
)];
auto
call_ref
=
tvm
::
ffi
::
GetRef
<
Call
>
(
op
);
if
(
!
tma_op_to_barrier_id_
.
count
(
call_ref
))
{
return
IRMutatorWithAnalyzer
::
VisitExpr_
(
op
);
}
auto
barrier_id
=
tma_op_to_barrier_id_
[
call_ref
];
auto
new_args
=
op
->
args
;
new_args
.
Set
(
0
,
barrier_id
);
if
(
!
has_warp_specialization_
)
...
...
@@ -522,10 +595,10 @@ tvm::transform::Pass InjectTmaBarrier() {
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.InjectTmaBarrier"
,
{});
}
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.InjectTmaBarrier"
,
InjectTmaBarrier
);
}
);
}
}
// namespace tl
}
// namespace tvm
src/transform/layout_inference.cc
View file @
bbbf4207
...
...
@@ -11,6 +11,7 @@
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include <algorithm>
#include <queue>
#include "../layout/utils.h"
...
...
@@ -105,20 +106,60 @@ public:
"required for layout inference."
;
// Run InferLayout
DLOG
(
INFO
)
<<
"[RunInferStep] working on "
<<
cur_infer_id
<<
'\n'
;
auto
updates
=
next
->
InferLayout
(
LayoutInferArgs
{
target_
,
thread_bounds
,
layout_map
,
&
analyzer_
,
buffer_oob
},
level
);
// Process the returned updates
for
(
const
auto
&
[
buffer
,
layout
]
:
updates
)
{
DLOG
(
INFO
)
<<
" consider update "
<<
buffer
<<
" as "
<<
layout
->
DebugOutput
()
<<
'\n'
;
// Basic validity checks
ICHECK
(
buffer
.
defined
())
<<
"InferLayout returned an undefined buffer."
;
ICHECK
(
layout
.
defined
())
<<
"InferLayout returned an undefined layout."
;
// Helper: propagate inferred layout to alias buffers (same data Var)
auto
propagate_alias
=
[
&
](
const
Buffer
&
src_buffer
,
const
Layout
&
src_layout
)
{
if
(
!
buffer_data_to_buffers_
.
count
(
src_buffer
->
data
))
return
;
const
auto
&
siblings
=
buffer_data_to_buffers_
[
src_buffer
->
data
];
for
(
const
auto
&
sib
:
siblings
)
{
if
(
sib
.
same_as
(
src_buffer
))
continue
;
bool
shapes_equal
=
src_layout
->
InputShape
().
size
()
==
sib
->
shape
.
size
();
if
(
shapes_equal
)
{
for
(
size_t
i
=
0
;
i
<
src_layout
->
InputShape
().
size
();
++
i
)
{
if
(
!
analyzer_
.
CanProveEqual
(
src_layout
->
InputShape
()[
i
],
sib
->
shape
[
i
]))
{
shapes_equal
=
false
;
break
;
}
}
}
Layout
target_layout
=
shapes_equal
?
src_layout
:
src_layout
->
Reshape
(
sib
->
shape
,
&
analyzer_
);
if
(
layout_map
.
count
(
sib
))
{
ICHECK
(
target_layout
->
IsEqual
(
layout_map
[
sib
].
get
()))
<<
"Get different layout for alias buffer "
<<
sib
<<
" (data-shared with "
<<
src_buffer
<<
")
\n
current: "
<<
target_layout
->
DebugOutput
()
<<
"
\n
previous: "
<<
layout_map
[
sib
]
->
DebugOutput
();
}
else
{
layout_map
.
Set
(
sib
,
target_layout
);
if
(
update_queue
&&
use_list_
.
count
(
sib
))
{
for
(
int
idx
:
use_list_
[
sib
])
{
if
(
!
in_queue
[
idx
]
&&
idx
!=
cur_infer_id
)
{
in_queue
[
idx
]
=
true
;
q
.
push
(
idx
);
}
}
}
}
}
};
if
(
layout_map
.
count
(
buffer
))
{
// If new layout contains the old one, update map
if
(
buffer
.
scope
()
==
"local.fragment"
&&
...
...
@@ -153,8 +194,8 @@ public:
if
(
ProveFragmentContains
(
src_layout
,
dst_layout
,
indices
,
indices
,
inner_analyzer
))
{
layout_map
.
Set
(
buffer
,
layout
);
DLOG
(
INFO
)
<<
" layout broadcast from "
<<
src_layout
->
DebugOutput
()
<<
", accepted"
<<
'\n'
;
// Propagate to alias buffers as well
propagate_alias
(
buffer
,
layout
)
;
continue
;
}
}
...
...
@@ -163,10 +204,13 @@ public:
<<
"Get different layout for "
<<
buffer
<<
"
\n
current layout: "
<<
layout
->
DebugOutput
()
<<
"
\n
previous layout: "
<<
layout_map
[
buffer
]
->
DebugOutput
();
// Ensure aliases are consistent too
propagate_alias
(
buffer
,
layout
);
}
else
{
// Otherwise, update map
layout_map
.
Set
(
buffer
,
layout
);
DLOG
(
INFO
)
<<
" new layout accepted"
<<
'\n'
;
// Propagate to alias buffers (may enqueue their users)
propagate_alias
(
buffer
,
layout
);
if
(
!
update_queue
)
continue
;
...
...
@@ -272,6 +316,46 @@ public:
// step 3: relax constraints to free and re-run
InferInFreeMode
(
layout_map
,
strict_layout_map
);
// step 4: finalize alias layouts by Var
// For each storage var, if any buffer in the group has a layout,
// propagate (reshape if needed) to the rest to ensure completeness.
for
(
const
auto
&
[
var
,
buffers
]
:
buffer_data_to_buffers_
)
{
// Find a representative with existing layout
Optional
<
Buffer
>
rep
;
Optional
<
Layout
>
rep_layout
;
for
(
const
auto
&
buf
:
buffers
)
{
if
(
layout_map
.
count
(
buf
))
{
rep
=
buf
;
rep_layout
=
layout_map
[
buf
];
break
;
}
}
if
(
!
rep_layout
.
defined
())
continue
;
for
(
const
auto
&
buf
:
buffers
)
{
if
(
!
layout_map
.
count
(
buf
))
{
bool
shapes_equal
=
rep_layout
.
value
()
->
InputShape
().
size
()
==
buf
->
shape
.
size
();
if
(
shapes_equal
)
{
for
(
size_t
i
=
0
;
i
<
rep_layout
.
value
()
->
InputShape
().
size
();
++
i
)
{
if
(
!
analyzer_
.
CanProveEqual
(
rep_layout
.
value
()
->
InputShape
()[
i
],
buf
->
shape
[
i
]))
{
shapes_equal
=
false
;
break
;
}
}
}
Layout
reshaped
=
shapes_equal
?
rep_layout
.
value
()
:
rep_layout
.
value
()
->
Reshape
(
buf
->
shape
,
&
analyzer_
);
layout_map
.
Set
(
buf
,
reshaped
);
}
}
}
// Check that all local.fragment buffers have inferred layouts
for
(
const
auto
&
[
buffer
,
_
]
:
use_list_
)
{
if
(
buffer
.
scope
()
==
"local.fragment"
)
{
...
...
@@ -314,7 +398,13 @@ public:
void
Collect
(
const
PrimFunc
&
f
)
{
for
(
const
auto
&
[
_
,
buffer
]
:
f
->
buffer_map
)
{
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
if
(
buffer_data_to_buffers_
.
count
(
buffer
->
data
))
{
auto
buffers
=
buffer_data_to_buffers_
[
buffer
->
data
];
buffers
.
push_back
(
buffer
);
buffer_data_to_buffers_
.
Set
(
buffer
->
data
,
buffers
);
}
else
{
buffer_data_to_buffers_
.
Set
(
buffer
->
data
,
{
buffer
});
}
}
auto
target
=
f
->
GetAttr
<
Target
>
(
tvm
::
attr
::
kTarget
);
ICHECK
(
target
.
defined
())
...
...
@@ -324,13 +414,25 @@ public:
}
private:
Map
<
Var
,
Buffer
>
GetBufferMap
()
const
{
Map
<
Var
,
Buffer
>
buffer_map
;
for
(
const
auto
&
[
var
,
buffers
]
:
buffer_data_to_buffers_
)
{
// Use the first buffer for each var
// TODO(lei): phaseout buffer_map in future.
if
(
!
buffers
.
empty
())
{
buffer_map
.
Set
(
var
,
buffers
[
0
]);
}
}
return
buffer_map
;
}
void
VisitExpr_
(
const
CallNode
*
op
)
final
{
IRVisitorWithAnalyzer
::
VisitExpr_
(
op
);
// Do not analysis the call node to the global function.
if
(
op
->
op
.
as
<
GlobalVarNode
>
())
return
;
auto
p
=
ParseOperator
(
GetRef
<
Call
>
(
op
),
buffer_data_to_buffer_
);
auto
p
=
ParseOperator
(
tvm
::
ffi
::
GetRef
<
Call
>
(
op
),
GetBufferMap
()
);
if
(
p
.
defined
())
{
for
(
const
auto
&
arg
:
op
->
args
)
{
if
(
auto
buffer
=
getBufferFromAccessPtr
(
arg
))
{
...
...
@@ -381,7 +483,7 @@ private:
}
// Add the tile operator to infer_list_
infer_list_stmt_
.
push_back
(
GetRef
<
ObjectRef
>
(
op
));
infer_list_stmt_
.
push_back
(
tvm
::
ffi
::
GetRef
<
ObjectRef
>
(
op
));
infer_list_
.
push_back
(
std
::
move
(
p
));
}
}
...
...
@@ -394,12 +496,18 @@ private:
if
(
call
->
op
.
same_as
(
builtin
::
tvm_access_ptr
()))
{
auto
var_opt
=
call
->
args
[
1
].
as
<
Var
>
();
if
(
!
var_opt
.
has_value
())
{
D
LOG
(
WARNING
)
<<
"[getBufferFromAccessPtr] args[1] is not a Var, type: "
LOG
(
WARNING
)
<<
"[getBufferFromAccessPtr] args[1] is not a Var, type: "
<<
call
->
args
[
1
]
->
GetTypeKey
();
return
std
::
nullopt
;
}
const
auto
&
var
=
var_opt
.
value
();
return
buffer_data_to_buffer_
[
var
];
if
(
buffer_data_to_buffers_
.
count
(
var
))
{
const
auto
&
buffers
=
buffer_data_to_buffers_
[
var
];
if
(
!
buffers
.
empty
())
{
return
buffers
[
0
];
// Return the first buffer
}
}
return
std
::
nullopt
;
}
else
if
(
call
->
op
.
same_as
(
RegionOp
::
Get
()))
{
return
call
->
args
[
0
].
as
<
BufferLoadNode
>
()
->
buffer
;
}
...
...
@@ -416,11 +524,11 @@ private:
void
VisitStmt_
(
const
ForNode
*
op
)
final
{
if
(
op
->
kind
==
ForKind
::
kParallel
)
{
auto
infer
=
ParallelOp
(
GetRef
<
For
>
(
op
));
auto
infer
=
ParallelOp
(
tvm
::
ffi
::
GetRef
<
For
>
(
op
));
for
(
const
auto
&
[
buffer
,
_
]
:
infer
->
GetIndiceMap
())
{
addToUseList
(
buffer
);
}
infer_list_stmt_
.
push_back
(
GetRef
<
ObjectRef
>
(
op
));
infer_list_stmt_
.
push_back
(
tvm
::
ffi
::
GetRef
<
ObjectRef
>
(
op
));
infer_list_
.
push_back
(
std
::
move
(
infer
));
thread_var_vec_
.
push_back
(
thread_var_
);
if
(
thread_var_
.
defined
()
&&
...
...
@@ -442,21 +550,55 @@ private:
void
VisitStmt_
(
const
BlockNode
*
op
)
final
{
for
(
auto
buffer
:
op
->
alloc_buffers
)
{
buffer_data_to_buffer_
.
Set
(
buffer
->
data
,
buffer
);
if
(
buffer_data_to_buffers_
.
count
(
buffer
->
data
))
{
auto
buffers
=
buffer_data_to_buffers_
[
buffer
->
data
];
buffers
.
push_back
(
buffer
);
buffer_data_to_buffers_
.
Set
(
buffer
->
data
,
buffers
);
}
else
{
buffer_data_to_buffers_
.
Set
(
buffer
->
data
,
{
buffer
});
}
}
// First, visit the block body to collect all buffers from
// BufferLoad/BufferStore
IRVisitorWithAnalyzer
::
VisitStmt_
(
op
);
// After visiting, apply layouts to all collected buffers
if
(
op
->
annotations
.
count
(
attr
::
kLayoutMap
))
{
// Check if the layout map is Map<Var, Layout>
auto
map
=
op
->
annotations
.
Get
(
attr
::
kLayoutMap
)
->
as
<
Map
<
Var
,
Layout
>>
().
value
();
for
(
const
auto
&
[
var
,
layout
]
:
map
)
{
ICHECK
(
buffer_data_to_buffer_
.
count
(
var
))
ICHECK
(
buffer_data_to_buffer
s
_
.
count
(
var
))
<<
"buffer "
<<
var
<<
" is not found in the block"
;
auto
buffer
=
buffer_data_to_buffer_
[
var
];
ICHECK
(
StructuralEqual
()(
layout
->
InputShape
(),
buffer
->
shape
));
const
auto
&
buffers
=
buffer_data_to_buffers_
[
var
];
ICHECK
(
!
buffers
.
empty
())
<<
"buffer list for "
<<
var
<<
" is empty"
;
// Apply layout to all buffers associated with this var
for
(
const
auto
&
buffer
:
buffers
)
{
// Reshape the layout to match the buffer's shape
// Check if shapes are structurally equal
bool
shapes_equal
=
layout
->
InputShape
().
size
()
==
buffer
->
shape
.
size
();
if
(
shapes_equal
)
{
for
(
size_t
i
=
0
;
i
<
layout
->
InputShape
().
size
();
++
i
)
{
if
(
!
analyzer_
.
CanProveEqual
(
layout
->
InputShape
()[
i
],
buffer
->
shape
[
i
]))
{
shapes_equal
=
false
;
break
;
}
}
}
if
(
shapes_equal
)
{
annotated_layout_map_
.
Set
(
buffer
,
layout
);
}
else
{
auto
reshaped_layout
=
layout
->
Reshape
(
buffer
->
shape
,
&
analyzer_
);
annotated_layout_map_
.
Set
(
buffer
,
reshaped_layout
);
}
}
}
}
IRVisitorWithAnalyzer
::
VisitStmt_
(
op
);
}
void
VisitStmt_
(
const
AttrStmtNode
*
op
)
final
{
...
...
@@ -470,7 +612,67 @@ private:
IRVisitorWithAnalyzer
::
VisitStmt_
(
op
);
}
Map
<
Var
,
Buffer
>
buffer_data_to_buffer_
;
void
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
// Collect buffer from BufferLoad
if
(
op
->
buffer
.
defined
()
&&
op
->
buffer
->
data
.
defined
())
{
if
(
buffer_data_to_buffers_
.
count
(
op
->
buffer
->
data
))
{
// Check if this buffer is already in the list
auto
buffers
=
buffer_data_to_buffers_
[
op
->
buffer
->
data
];
bool
found
=
false
;
for
(
const
auto
&
buf
:
buffers
)
{
if
(
buf
.
same_as
(
op
->
buffer
))
{
found
=
true
;
break
;
}
}
if
(
!
found
)
{
buffers
.
push_back
(
op
->
buffer
);
buffer_data_to_buffers_
.
Set
(
op
->
buffer
->
data
,
buffers
);
DLOG
(
INFO
)
<<
"[LayoutInference] BufferLoad: added buffer "
<<
op
->
buffer
<<
" buffer.get() = "
<<
op
->
buffer
.
get
()
<<
" data = "
<<
op
->
buffer
->
data
.
get
();
}
}
else
{
buffer_data_to_buffers_
.
Set
(
op
->
buffer
->
data
,
{
op
->
buffer
});
DLOG
(
INFO
)
<<
"[LayoutInference] BufferLoad: new buffer "
<<
op
->
buffer
<<
" buffer.get() = "
<<
op
->
buffer
.
get
()
<<
" data = "
<<
op
->
buffer
->
data
.
get
();
}
}
IRVisitorWithAnalyzer
::
VisitExpr_
(
op
);
}
void
VisitStmt_
(
const
BufferStoreNode
*
op
)
final
{
// Collect buffer from BufferStore
if
(
op
->
buffer
.
defined
()
&&
op
->
buffer
->
data
.
defined
())
{
if
(
buffer_data_to_buffers_
.
count
(
op
->
buffer
->
data
))
{
// Check if this buffer is already in the list
auto
buffers
=
buffer_data_to_buffers_
[
op
->
buffer
->
data
];
bool
found
=
false
;
for
(
const
auto
&
buf
:
buffers
)
{
if
(
buf
.
same_as
(
op
->
buffer
))
{
found
=
true
;
break
;
}
}
if
(
!
found
)
{
buffers
.
push_back
(
op
->
buffer
);
buffer_data_to_buffers_
.
Set
(
op
->
buffer
->
data
,
buffers
);
DLOG
(
INFO
)
<<
"[LayoutInference] BufferStore: added buffer "
<<
op
->
buffer
<<
" buffer.get() = "
<<
op
->
buffer
.
get
()
<<
" data = "
<<
op
->
buffer
->
data
.
get
();
}
}
else
{
buffer_data_to_buffers_
.
Set
(
op
->
buffer
->
data
,
{
op
->
buffer
});
DLOG
(
INFO
)
<<
"[LayoutInference] BufferStore: new buffer "
<<
op
->
buffer
<<
" buffer.get() = "
<<
op
->
buffer
.
get
()
<<
" data = "
<<
op
->
buffer
->
data
.
get
();
}
}
IRVisitorWithAnalyzer
::
VisitStmt_
(
op
);
}
Map
<
Var
,
Array
<
Buffer
>>
buffer_data_to_buffers_
;
std
::
vector
<
ObjectRef
>
infer_list_stmt_
;
std
::
vector
<
TileOperator
>
infer_list_
;
std
::
unordered_map
<
Buffer
,
std
::
vector
<
int
>
,
ObjectPtrHash
,
ObjectPtrEqual
>
...
...
@@ -513,12 +715,33 @@ private:
if
(
infer_indices
.
empty
())
continue
;
// Union all infer_list_ indices that share the same
b
uffer
// Union all infer_list_ indices that share the same
B
uffer
object
int
first_idx
=
infer_indices
[
0
];
for
(
size_t
i
=
1
;
i
<
infer_indices
.
size
();
i
++
)
{
uf
.
Union
(
first_idx
,
infer_indices
[
i
]);
}
}
// Additionally, union across buffers that share the same underlying
// buffer->data (Var). This handles cases like reshape where multiple
// Buffer objects alias the same storage.
for
(
const
auto
&
[
var
,
buffers
]
:
buffer_data_to_buffers_
)
{
std
::
vector
<
int
>
merged
;
for
(
const
auto
&
buf
:
buffers
)
{
auto
it
=
use_list_
.
find
(
buf
);
if
(
it
!=
use_list_
.
end
())
{
const
auto
&
vec
=
it
->
second
;
merged
.
insert
(
merged
.
end
(),
vec
.
begin
(),
vec
.
end
());
}
}
if
(
merged
.
size
()
>
1
)
{
std
::
sort
(
merged
.
begin
(),
merged
.
end
());
merged
.
erase
(
std
::
unique
(
merged
.
begin
(),
merged
.
end
()),
merged
.
end
());
int
first
=
merged
[
0
];
for
(
size_t
i
=
1
;
i
<
merged
.
size
();
++
i
)
{
uf
.
Union
(
first
,
merged
[
i
]);
}
}
}
std
::
unordered_map
<
int
,
std
::
vector
<
int
>>
components
;
for
(
int
i
=
0
;
i
<
infer_list_
.
size
();
i
++
)
{
int
root
=
uf
.
Find
(
i
);
...
...
@@ -597,7 +820,9 @@ private:
}
}
// Update the best plan if this one uses fewer registers
if
(
reg_num
<
min_reg_num
)
{
if
(
reg_num
<
min_reg_num
||
(
reg_num
==
min_reg_num
&&
attempt_infer_root
<
min_reg_num_infer_root
))
{
best_infer_list
=
BackupInferList
();
// Use backup to avoid moving out infer_list_
best_layout_map
=
tmp_layout_map
;
...
...
@@ -711,8 +936,8 @@ private:
.
value
();
For
for_node
=
Downcast
<
For
>
(
IRMutatorWithAnalyzer
::
VisitStmt_
(
op
));
if
(
result_
.
for_map
.
count
(
GetRef
<
For
>
(
op
)))
{
auto
root
=
GetRef
<
For
>
(
op
);
if
(
result_
.
for_map
.
count
(
tvm
::
ffi
::
GetRef
<
For
>
(
op
)))
{
auto
root
=
tvm
::
ffi
::
GetRef
<
For
>
(
op
);
// This check is a workaround to support T.Parallel for local buffers.
// For example:
// for i in T.Parallel(1024):
...
...
@@ -787,7 +1012,18 @@ private:
}
});
if
(
has_non_local
&&
!
has_reducer
)
{
// If a cast operation exists, vectorization may still be required
bool
has_cast_operations
=
false
;
PostOrderVisit
(
for_node
->
body
,
[
&
](
const
ObjectRef
&
obj
)
{
if
(
const
auto
*
store
=
obj
.
as
<
BufferStoreNode
>
())
{
// Check if this is a non-reducer store with Cast operation
if
(
store
->
value
.
as
<
CastNode
>
())
{
has_cast_operations
=
true
;
}
}
});
if
((
has_non_local
||
has_cast_operations
)
&&
!
has_reducer
)
{
for_node
=
VectorizeLoop
(
for_node
);
}
...
...
@@ -831,10 +1067,10 @@ tvm::transform::Pass LayoutInference() {
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.LayoutInference"
,
{});
}
TVM_FFI_STATIC_INIT_BLOCK
({
TVM_FFI_STATIC_INIT_BLOCK
(
)
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.LayoutInference"
,
LayoutInference
);
}
);
}
}
// namespace tl
}
// namespace tvm
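The core idea in these layout_inference.cc changes is that several Buffer objects can alias one storage var (for example, a reshape view), so inferred layouts are propagated across the alias group and their users are merged into one inference component via union-find. A standalone sketch of that unioning step (plain C++, illustrative names, buffers keyed by strings instead of Vars):

#include <iostream>
#include <map>
#include <numeric>
#include <string>
#include <vector>

struct UnionFind {
  std::vector<int> parent;
  explicit UnionFind(int n) : parent(n) {
    std::iota(parent.begin(), parent.end(), 0);
  }
  int Find(int x) { return parent[x] == x ? x : parent[x] = Find(parent[x]); }
  void Union(int a, int b) { parent[Find(a)] = Find(b); }
};

int main() {
  // use_list: buffer name -> infer-step indices that touch it. Both buffers
  // here alias the same storage var, so all their users must be unioned.
  std::map<std::string, std::vector<int>> use_list = {
      {"A", {0}}, {"A_reshaped", {1, 2}}};
  UnionFind uf(3);
  std::vector<int> merged;
  for (const auto &[name, users] : use_list)
    merged.insert(merged.end(), users.begin(), users.end());
  for (size_t i = 1; i < merged.size(); ++i) uf.Union(merged[0], merged[i]);
  std::cout << (uf.Find(0) == uf.Find(2)) << "\n"; // 1: a single component
}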
src/transform/layout_reducer.cc

@@ -14,6 +14,7 @@
 #include "../layout/layout.h"
 #include "../op/fill.h"
 #include "../op/finalize_reducer.h"
+#include "../op/region.h"
 #include "arith/ir_mutator_with_analyzer.h"
 #include "layout_reducer.h"

@@ -275,17 +276,34 @@ private:
     auto op = op_ref.CopyOnWrite();
     if (op->op.same_as(Fill::Get())) {
       ICHECK(!op->args.empty());
-      if (auto arg0_call = op->args[0].as<Call>();
-          arg0_call &&
-          arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) {
+      if (auto arg0_call = op->args[0].as<Call>()) {
+        // Case 1: tl.region(...) — extract buffer var from its first arg
+        if (arg0_call.value()->op.same_as(RegionOp::Get())) {
+          ICHECK(!arg0_call.value()->args.empty());
+          if (auto bl = arg0_call.value()->args[0].as<BufferLoadNode>()) {
+            Var var = bl->buffer->data;
+            if (reducer_info_map_.count(var)) {
+              ICHECK(inside_reducer_range_.count(var) == 0)
+                  << "T.fill on reducer must be enclosed with a "
+                     "T.finalize_reducer "
+                     "before next.";
+              inside_reducer_range_.Set(var,
+                                        reducer_info_map_.Get(var).value());
+            }
+          }
+        }
+        // Case 2: builtin.tvm_access_ptr(...) — existing path
+        else if (arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) {
           ICHECK(arg0_call.value()->args.size() > 1);
           if (auto var = arg0_call.value()->args[1].as<Var>();
               var && reducer_info_map_.count(var.value())) {
             ICHECK(inside_reducer_range_.count(var.value()) == 0)
-                << "T.fill on reducer must be enclosed with a T.finalize_reducer "
+                << "T.fill on reducer must be enclosed with a "
+                   "T.finalize_reducer "
                    "before next.";
-            inside_reducer_range_.Set(var.value(),
-                                      reducer_info_map_.Get(var.value()).value());
+            inside_reducer_range_.Set(
+                var.value(), reducer_info_map_.Get(var.value()).value());
           }
         }
       }
     } else if (op->op.same_as(FinalizeReducerOp::Get())) {

@@ -362,10 +380,10 @@ tvm::transform::Pass LayoutReducer() {
   return CreatePrimFuncPass(pass_func, 0, "tl.LayoutReducer", {});
 }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.LayoutReducer", LayoutReducer);
-});
+}

 } // namespace tl
 } // namespace tvm
src/transform/layout_reducer.h

@@ -66,17 +66,17 @@ struct ReducerInfoNode : Object {
   ReducerInfoNode() = default;
   ReducerInfoNode(const String &op_str, const String &rep_str);

-  static constexpr const char *_type_key = "tl.ReducerInfo";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ReducerInfoNode, Object);
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReducerInfo", ReducerInfoNode, Object);
 };

 struct ReducerInfo : ObjectRef {
 public:
   TVM_DLL ReducerInfo(const String &op_str, const String &rep_str) {
-    data_ = make_object<ReducerInfoNode>(op_str, rep_str);
+    data_ = tvm::ffi::make_object<ReducerInfoNode>(op_str, rep_str);
   }
-  TVM_DEFINE_OBJECT_REF_METHODS(ReducerInfo, ObjectRef, ReducerInfoNode);
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReducerInfo, ObjectRef,
+                                             ReducerInfoNode);
 };

 namespace attr {
src/transform/legalize_negative_index.cc (new file, 0 → 100644)

/*!
 * \file legalize_negative_index.cc
 * \brief Legalize negative indices in buffer load expressions.
 */
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <unordered_map>
#include <vector>

#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"

namespace tvm {
namespace tl {

using namespace tir;
using arith::IRVisitorWithAnalyzer;

enum class IndexSignState { kNonNegative, kNegative, kUnknown };

class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {
public:
  explicit NegativeIndexAnalyzer(
      std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
          *result)
      : result_(result) {}

  void VisitExpr_(const BufferLoadNode *op) final {
    auto load = tvm::ffi::GetRef<BufferLoad>(op);
    std::vector<IndexSignState> states;
    states.reserve(op->indices.size());
    bool needs_record = false;
    for (size_t i = 0; i < op->indices.size(); ++i) {
      PrimExpr simplified = analyzer_.Simplify(op->indices[i]);
      // Handle scalar indices with the standard analyzer
      if (simplified.dtype().lanes() == 1) {
        if (analyzer_.CanProve(simplified >= 0)) {
          states.push_back(IndexSignState::kNonNegative);
          continue;
        }
        if (analyzer_.CanProve(simplified < 0)) {
          states.push_back(IndexSignState::kNegative);
          needs_record = true;
          continue;
        }
        states.push_back(IndexSignState::kUnknown);
        needs_record = true;
        DLOG(WARNING) << "LegalizeNegativeIndex: cannot prove non-negative "
                         "index "
                      << simplified << " for buffer " << load->buffer->name
                      << " (axis " << i << ").";
        continue;
      }
      // Vector indices: try to reason about non-negativity/negativity
      // Common patterns are Ramp(base, stride, lanes) and Broadcast(value,
      // lanes).
      IndexSignState vec_state = IndexSignState::kUnknown;
      if (const auto *ramp = simplified.as<RampNode>()) {
        // Compute a safe lower/upper bound for the vector lanes
        // lower_bound = base_min + min(0, stride_min) * (lanes - 1)
        // upper_bound = base_max + max(0, stride_max) * (lanes - 1)
        auto base_bound = analyzer_.const_int_bound(ramp->base);
        auto stride_bound = analyzer_.const_int_bound(ramp->stride);
        int lanes = *as_const_int(ramp->lanes);
        int64_t base_min = base_bound->min_value;
        int64_t base_max = base_bound->max_value;
        int64_t s_min = stride_bound->min_value;
        int64_t s_max = stride_bound->max_value;
        // Guard against overflow is not strictly necessary here because
        // bounds may be +/-inf represented by sentinel values.
        int64_t lower = base_min;
        if (s_min < 0)
          lower += s_min * (lanes - 1);
        int64_t upper = base_max;
        if (s_max > 0)
          upper += s_max * (lanes - 1);
        if (lower >= 0) {
          vec_state = IndexSignState::kNonNegative;
        } else if (upper < 0) {
          vec_state = IndexSignState::kNegative;
        } else {
          vec_state = IndexSignState::kUnknown;
        }
      } else if (const auto *bc = simplified.as<BroadcastNode>()) {
        auto v = analyzer_.Simplify(bc->value);
        if (analyzer_.CanProve(v >= 0)) {
          vec_state = IndexSignState::kNonNegative;
        } else if (analyzer_.CanProve(v < 0)) {
          vec_state = IndexSignState::kNegative;
        } else {
          // Try const bound if proof unavailable
          auto vb = analyzer_.const_int_bound(v);
          if (vb->min_value >= 0) {
            vec_state = IndexSignState::kNonNegative;
          } else if (vb->max_value < 0) {
            vec_state = IndexSignState::kNegative;
          } else {
            vec_state = IndexSignState::kUnknown;
          }
        }
      }
      if (vec_state == IndexSignState::kNonNegative) {
        states.push_back(IndexSignState::kNonNegative);
        continue;
      }
      if (vec_state == IndexSignState::kNegative) {
        states.push_back(IndexSignState::kNegative);
        needs_record = true;
        continue;
      }
      states.push_back(IndexSignState::kUnknown);
      needs_record = true;
      DLOG(WARNING) << "LegalizeNegativeIndex: cannot prove non-negative "
                       "index "
                    << simplified << " for buffer " << load->buffer->name
                    << " (axis " << i << ").";
    }
    if (needs_record) {
      (*result_)[op] = std::move(states);
    }
    IRVisitorWithAnalyzer::VisitExpr_(op);
  }

private:
  std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
      *result_;
};

class NegativeIndexRewriter : public arith::IRMutatorWithAnalyzer {
public:
  static PrimFunc
  Apply(PrimFunc func,
        const std::unordered_map<const BufferLoadNode *,
                                 std::vector<IndexSignState>> &states) {
    arith::Analyzer analyzer;
    NegativeIndexRewriter rewriter(&analyzer, states);
    if (!func->body.defined()) {
      return func;
    }
    PrimFuncNode *func_node = func.CopyOnWrite();
    func_node->body = rewriter.VisitStmt(func_node->body);
    return func;
  }

private:
  NegativeIndexRewriter(
      arith::Analyzer *analyzer,
      const std::unordered_map<const BufferLoadNode *,
                               std::vector<IndexSignState>> &states)
      : arith::IRMutatorWithAnalyzer(analyzer), states_(states) {}

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    BufferLoad load =
        Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op));
    auto it = states_.find(op);
    if (it == states_.end()) {
      return load;
    }
    auto indices = load->indices;
    bool changed = false;
    const auto &state_vector = it->second;
    ICHECK_EQ(state_vector.size(), indices.size())
        << "State vector size mismatch for buffer load " << load->buffer->name;
    for (size_t i = 0; i < indices.size(); ++i) {
      if (state_vector[i] != IndexSignState::kNegative) {
        continue;
      }
      PrimExpr extent = load->buffer->shape[i];
      indices.Set(i, analyzer_->Simplify(extent + indices[i]));
      changed = true;
    }
    if (!changed) {
      return load;
    }
    return BufferLoad(load->buffer, indices);
  }

  const std::unordered_map<const BufferLoadNode *,
                           std::vector<IndexSignState>> &states_;
};

PrimFunc LegalizeNegativeIndex(PrimFunc func) {
  if (!func->body.defined()) {
    return func;
  }
  std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
      states;
  NegativeIndexAnalyzer analyzer(&states);
  analyzer(func->body);
  if (states.empty()) {
    return func;
  }
  return NegativeIndexRewriter::Apply(std::move(func), states);
}

tvm::transform::Pass LegalizeNegativeIndexPass() {
  using namespace tir::transform;
  auto pass_func = [](PrimFunc f, const IRModule &, PassContext) {
    return LegalizeNegativeIndex(std::move(f));
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LegalizeNegativeIndex", {});
}

TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.LegalizeNegativeIndex",
                        LegalizeNegativeIndexPass);
}

} // namespace tl
} // namespace tvm
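Two ideas carry this new pass: a provably negative index i on an axis of extent N is rewritten to N + i (Python-style wraparound), and a vector index Ramp(base, stride, lanes) is sign-checked through conservative bounds, e.g. lower = base_min + min(0, stride_min) * (lanes - 1). A standalone sketch of both with concrete numbers (plain C++, illustrative function names):

#include <algorithm>
#include <cstdint>
#include <iostream>

// Mirror of the rewrite: indices.Set(i, extent + indices[i]) when negative.
int64_t LegalizeScalar(int64_t index, int64_t extent) {
  return index < 0 ? extent + index : index;
}

// Mirror of the Ramp lower-bound check in the analyzer.
bool RampProvablyNonNegative(int64_t base_min, int64_t stride_min, int lanes) {
  int64_t lower = base_min + std::min<int64_t>(0, stride_min) * (lanes - 1);
  return lower >= 0;
}

int main() {
  std::cout << LegalizeScalar(-1, 128) << "\n";           // 127
  std::cout << RampProvablyNonNegative(0, -1, 4) << "\n"; // 0: lane may be -3
  std::cout << RampProvablyNonNegative(3, -1, 4) << "\n"; // 1: min lane is 0
}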
src/transform/legalize_safe_memory_access.cc

@@ -38,7 +38,7 @@ private:
     StmtVisitor::VisitStmt(op->body);
     if (!has_child_for_) {
-      leaf_for_nodes.push_back(GetRef<For>(op));
+      leaf_for_nodes.push_back(tvm::ffi::GetRef<For>(op));
     }
     parent_has_child_for_ = parent_has_child_for;

@@ -378,11 +378,11 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() {
 }

 // Register the pass globally so it can be used in the compilation pipeline
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.LegalizeSafeMemoryAccess",
                         LegalizeSafeMemoryAccess);
-});
+}

 } // namespace tl
 } // namespace tvm
src/transform/legalize_vectorized_loop.cc

@@ -89,11 +89,11 @@ tvm::transform::Pass LegalizeVectorizedLoop() {
 }

 // Register the pass globally so it can be used in the compilation pipeline
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.LegalizeVectorizedLoop",
                         LegalizeVectorizedLoop);
-});
+}

 } // namespace tl
 } // namespace tvm