OpenDAS / tilelang / Commits / bbbf4207

Unverified commit bbbf4207, authored Nov 14, 2025 by guchaoyang, committed by GitHub on Nov 14, 2025.

    Merge branch 'main' into dcu

Parents: 8f4628e0, 5eb30a4f
Changes: 286
Showing 20 changed files with 596 additions and 933 deletions (+596, -933).
Changed files shown in this view:

  src/transform/loop_partition.cc                    +1    -1
  src/transform/loop_vectorize.cc                    +3    -2
  src/transform/loop_vectorize_dynamic.cc            +0    -545
  src/transform/lower_device_kernel_launch.cc        +7    -7
  src/transform/lower_device_storage_access_info.cc  +3    -3
  src/transform/lower_hopper_intrin.cc               +4    -4
  src/transform/lower_intrin.cc                      +9    -8
  src/transform/lower_l2_persistent_annotation.cc    +2    -2
  src/transform/lower_opaque_block.cc                +23   -5
  src/transform/lower_shared_barrier.cc              +3    -3
  src/transform/lower_shared_tmem.cc                 +15   -4
  src/transform/lower_thread_allreduce.cc            +3    -2
  src/transform/lower_tile_op.cc                     +32   -69
  src/transform/make_packed_api.cc                   +8    -7
  src/transform/merge_if_stmt.cc                     +2    -2
  src/transform/merge_shared_memory_allocations.cc   +459  -248
  src/transform/multi_version_buffer_rewriter.cc     +5    -4
  src/transform/persist_threadblock.cc               +2    -2
  src/transform/pipeline_planning.cc                 +6    -6
  src/transform/simplify.cc                          +9    -9
src/transform/loop_partition.cc

@@ -173,7 +173,7 @@ private:
     if (as_const_int(analyzer->Simplify(node->extent)) == nullptr) {
       return StmtExprMutator::VisitStmt_(node);
     }
-    For new_for = GetRef<For>(node);
+    For new_for = tvm::ffi::GetRef<For>(node);
     auto for_ptr = new_for.CopyOnWrite();
     for_ptr->annotations.Set(tir::attr::pragma_unroll_explicit, Bool(false));
     for_ptr->kind = ForKind::kUnrolled;
src/transform/loop_vectorize.cc

@@ -240,8 +240,9 @@ int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }
 bool CanProveIndependent(const PrimExpr &expr, Var var,
                          arith::Analyzer *analyzer) {
   // 1. if var doesn't exist, it is independent
-  bool used_var = UsesVar(expr, [&](const VarNode *v) { return GetRef<Var>(v).same_as(var); });
+  bool used_var = UsesVar(expr, [&](const VarNode *v) {
+    return tvm::ffi::GetRef<Var>(v).same_as(var);
+  });
   if (!used_var) {
     return true;
   }
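Many of the single-line edits in this merge follow the same pattern: an unqualified `GetRef<T>(ptr)` is spelled out as `tvm::ffi::GetRef<T>(ptr)`. A minimal sketch of that pattern is below; it assumes the TVM revision used by this branch, where `GetRef` lives in `tvm::ffi`, and the helper name `ToVar` is purely illustrative, not part of the diff.

#include <tvm/tir/expr.h>

// Illustrative helper: convert a raw VarNode pointer back into its reference
// wrapper using the fully qualified spelling adopted throughout this merge
// instead of relying on a `using namespace ffi;` directive.
tvm::tir::Var ToVar(const tvm::tir::VarNode *node) {
  return tvm::ffi::GetRef<tvm::tir::Var>(node);
}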
src/transform/loop_vectorize_dynamic.cc (deleted, 100644 → 0)

The entire file was removed in this merge; its former contents were:

/*!
 * \file loop_vectorize_dynamic.cc
 * \brief A tool to automatically vectorize a for loop with dynamic shape
 * \brief Reference to loop_vectorize.cc and vectorize_loop.cc
 */

#include <cstdint>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <numeric>
#include <utility>

#include "../layout/layout.h"
#include "../layout/utils.h"
#include "../op/builtin.h"
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_vectorization_utils.h"

namespace tvm {
namespace tl {

using namespace tir;
using arith::IRMutatorWithAnalyzer;

struct VectorizePlanResult {
  int vector_size;
  bool dynamic;
  PrimExpr condition;
};

bool IndiceCanVectorizeDynamic(const PrimExpr &expr, Var var,
                               const PrimExpr &iter_var_size,
                               int target_vectorized_size,
                               arith::Analyzer *analyzer) {
  ICHECK(target_vectorized_size >= 1);
  if (target_vectorized_size == 1)
    return true;
  if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size), 0))
    return false;
  Var v0("v0"), v1("v1");
  analyzer->Bind(v0, Range(0, target_vectorized_size));
  analyzer->Bind(v1, Range(0, FloorDiv(iter_var_size, target_vectorized_size)));
  PrimExpr expr_transformed = analyzer->Simplify(
      Substitute(expr, {{var, v0 + v1 * target_vectorized_size}}));
  Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
  PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed);
  auto ramp_node = expr_vectorized.as<RampNode>();
  if (!ramp_node) {
    // Broadcast value
    if (expr_vectorized.dtype().lanes() == 1)
      return true;
    else
      return false;
  } else {
    return is_one(ramp_node->stride);
  }
}

class VectorizePlannerDynamic : public arith::IRVisitorWithAnalyzer {
public:
  VectorizePlannerDynamic(int dynamic_alignment, bool disable_dynamic_tail_split)
      : dynamic_alignment_(dynamic_alignment),
        disable_dynamic_tail_split_(disable_dynamic_tail_split),
        vector_load_bits_max_(128) {
    if (disable_dynamic_tail_split_) {
      vector_size_ = dynamic_alignment_;
    } else {
      vector_size_ = vector_load_bits_max_;
    }
  }

  int Plan(const For &node) {
    this->operator()(node);
    // Always Enable vectorization
    // if (!has_nonlocal_memory_access_) return 1;
    return vector_size_;
  }

  bool GetDynamic() { return dynamic_; }

  PrimExpr GetCondition() { return condition_; }

private:
  void VisitStmt_(const ForNode *node) final {
    inner_for_ = node;
    iter_map_.Set(node->loop_var, Range(node->min, node->extent));
    arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

  void VisitExpr_(const BufferLoadNode *node) final {
    if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
        node->buffer.scope() == "shared.dyn")
      has_nonlocal_memory_access_ = true;
    if (node->buffer->shape.size() == 1) {
      // TODO(lei): This should be improved as
      // constant buffer that tl hack to use as local register.
      auto boundary_check = node->buffer->shape[0].as<IntImmNode>();
      if (boundary_check && boundary_check->value == 1) {
        return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
      }
    }
    UpdateVectorSize(node->indices, node->buffer);
    return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
  }

  void VisitStmt_(const BufferStoreNode *node) final {
    if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
        node->buffer.scope() == "shared.dyn")
      has_nonlocal_memory_access_ = true;
    UpdateVectorSize(node->indices, node->buffer);
    return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

  void VisitStmt_(const IfThenElseNode *node) final {
    CheckConditionVectorized(node->condition);
    return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

  void VisitExpr_(const CallNode *node) final {
    if (node->op == builtin::if_then_else()) {
      CheckConditionVectorized(node->args[0]);
    } else if (node->op == builtin::call_extern()) {
      // do not vectorize extern calls
      vector_size_ = 1;
    }
    return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
  }

  void CheckConditionVectorized(const PrimExpr &cond) {
    // TODO: may perform some checks here
  }

  void UpdateVectorSize(const Array<PrimExpr> &indices, const Buffer &buffer) {
    if (!inner_for_)
      return;
    auto extent_ptr = inner_for_->extent.as<IntImmNode>();
    if (!extent_ptr)
      return;

    const DataType &access_type = buffer->dtype;
    // i // 2, i % 8 can also be vectorized as factor 16
    int max_vector_size = vector_load_bits_max_ / access_type.bits();
    // so we should disable this GCD optimization
    max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
    auto last_dim = buffer->shape.back();
    auto mod_set = analyzer_.modular_set(last_dim);
    // when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
    // conditionally tail vectorize
    if (buffer->shape.back().as<IntImmNode>()) {
      max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
      auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
      // If gcd_base is equal to the last dimension,
      // we should analyze the second-to-last dimension
      // in relation to the last dimension.
      if (gcd_base < Downcast<IntImm>(last_dim)->value) {
        max_vector_size = gcd_base;
      }
      vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);
      PrimExpr elem_offset = 0;
      PrimExpr stride = 1;
      for (int i = indices.size() - 1; i >= 0; --i) {
        elem_offset = elem_offset + indices[i] * stride;
        stride = stride * buffer->shape[i];
      }
      while (!IndiceCanVectorizeDynamic(elem_offset, inner_for_->loop_var,
                                        inner_for_->extent, vector_size_,
                                        &analyzer_)) {
        vector_size_ /= 2;
      }
    } else {
      // dynamic shape load: get the vectorization condition
      dynamic_ = true;
      if (!disable_dynamic_tail_split_ &&
          vector_size_ >= vector_load_bits_max_ / buffer->dtype.bits()) {
        vector_size_ = vector_load_bits_max_ / buffer->dtype.bits();
      }
      PrimExpr offset = buffer.OffsetOf(indices).back();
      // condition for alignment, maybe useless
      condition_ = (FloorMod(offset, vector_size_) == 0);
    }
  }

  // Use dynamic alignment from pass config
  int vector_load_bits_max_;
  int dynamic_alignment_;
  bool disable_dynamic_tail_split_;
  int vector_size_;
  const ForNode *inner_for_{};
  Map<Var, Range> iter_map_;
  bool has_nonlocal_memory_access_ = false;
  // conditionally vectorize
  bool dynamic_ = false;
  PrimExpr condition_;
};

class VectorizedBodyMutator : public StmtExprMutator {
public:
  VectorizedBodyMutator(Var inner_var, int vector_size,
                        std::vector<PrimExpr> conditions)
      : inner_var_(std::move(inner_var)), vector_size_(vector_size),
        conditions_(std::move(conditions)) {}

private:
  PrimExpr VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(builtin::if_then_else())) {
      // TODO: Currently not ramp, but only reserve the "then" part (because
      // conditions are move outside this vectorized loop)
      PrimExpr ifexpr = op->args[0];
      PrimExpr thenexpr = op->args[1];
      bool flag = false;
      for (auto &cond : conditions_) {
        if (ifexpr.get() == cond.get()) {
          flag = true;
        }
      }
      if (flag) {
        return thenexpr;
      } else {
        return GetRef<PrimExpr>(op);
      }
    } else {
      return GetRef<PrimExpr>(op);
    }
  }

  Var inner_var_;
  int vector_size_;
  std::vector<PrimExpr> conditions_;
};

class VectorizedConditionExtractor : public StmtExprVisitor {
public:
  VectorizedConditionExtractor() = default;

  std::vector<PrimExpr> GetConditions(const Stmt &body) {
    this->VisitStmt(body);
    return conditions_;
  }

private:
  void VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(builtin::if_then_else())) {
      PrimExpr cond = op->args[0];
      conditions_.emplace_back(cond);
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const IfThenElseNode *node) final {
    conditions_.emplace_back(node->condition);
    StmtExprVisitor::VisitStmt_(node);
  }

  std::vector<PrimExpr> conditions_;
};

// backward-compatibility: extracter -> extractor
using VectorizedConditionExtracter = VectorizedConditionExtractor;

class NestedLoopChecker : public StmtExprVisitor {
public:
  NestedLoopChecker() : loop_num_(0) {}

  int GetNestLoopNum(const Stmt &body) {
    this->VisitStmt(body);
    return loop_num_;
  }

private:
  void VisitStmt_(const ForNode *node) final {
    loop_num_++;
    StmtExprVisitor::VisitStmt_(node);
  }

  int loop_num_;
};

// Modify every subexpression in the condition
class VectorizedConditionMutator : public StmtExprMutator {
public:
  VectorizedConditionMutator(Var inner_var, int extent)
      : inner_var_(std::move(inner_var)), vector_size_(extent) {}

private:
  PrimExpr VisitExpr_(const GENode *node) final {
    PrimExpr lhs = StmtExprMutator::VisitExpr(node->a);
    PrimExpr rhs = StmtExprMutator::VisitExpr(node->b);
    auto span = node->span;
    Map<Var, PrimExpr> vmap_lhs, vmap_rhs;
    vmap_lhs.Set(inner_var_, 0);
    PrimExpr lhs_bound = Substitute(lhs, vmap_lhs);
    vmap_rhs.Set(inner_var_, vector_size_ - 1);
    PrimExpr rhs_bound = Substitute(rhs, vmap_rhs);
    return GE(lhs_bound, rhs_bound, span);
  }

  PrimExpr VisitExpr_(const GTNode *node) final {
    PrimExpr lhs = StmtExprMutator::VisitExpr(node->a);
    PrimExpr rhs = StmtExprMutator::VisitExpr(node->b);
    auto span = node->span;
    Map<Var, PrimExpr> vmap_lhs, vmap_rhs;
    vmap_lhs.Set(inner_var_, 0);
    PrimExpr lhs_bound = Substitute(lhs, vmap_lhs);
    vmap_rhs.Set(inner_var_, vector_size_ - 1);
    PrimExpr rhs_bound = Substitute(rhs, vmap_rhs);
    return GT(lhs_bound, rhs_bound, span);
  }

  PrimExpr VisitExpr_(const LENode *node) final {
    PrimExpr lhs = StmtExprMutator::VisitExpr(node->a);
    PrimExpr rhs = StmtExprMutator::VisitExpr(node->b);
    auto span = node->span;
    Map<Var, PrimExpr> vmap_lhs, vmap_rhs;
    vmap_lhs.Set(inner_var_, vector_size_ - 1);
    PrimExpr lhs_bound = Substitute(lhs, vmap_lhs);
    vmap_rhs.Set(inner_var_, 0);
    PrimExpr rhs_bound = Substitute(rhs, vmap_rhs);
    return LE(lhs_bound, rhs_bound, span);
  }

  PrimExpr VisitExpr_(const LTNode *node) final {
    PrimExpr lhs = StmtExprMutator::VisitExpr(node->a);
    PrimExpr rhs = StmtExprMutator::VisitExpr(node->b);
    auto span = node->span;
    Map<Var, PrimExpr> vmap_lhs, vmap_rhs;
    vmap_lhs.Set(inner_var_, vector_size_ - 1);
    PrimExpr lhs_bound = Substitute(lhs, vmap_lhs);
    vmap_rhs.Set(inner_var_, 0);
    PrimExpr rhs_bound = Substitute(rhs, vmap_rhs);
    return LT(lhs_bound, rhs_bound, span);
  }

  Var inner_var_;
  int vector_size_;
};

class VectorizeRewriterDynamic : public StmtExprMutator {
public:
  VectorizeRewriterDynamic(const VectorizePlanResult &plan,
                           bool disable_dynamic_tail_split)
      : vector_size_(plan.vector_size), condition_(plan.condition),
        dynamic_(plan.dynamic),
        disable_dynamic_tail_split_(disable_dynamic_tail_split) {}

private:
  Stmt VisitStmt_(const ForNode *node) final {
    // Get pass config `tl.disable_dynamic_tail_split`
    tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
    Optional<Bool> opt_disable_dynamic_tail_split =
        ctxt->GetConfig(kDisableDynamicTailSplit, Optional<Bool>());
    bool disable_dynamic_tail_split =
        opt_disable_dynamic_tail_split.value_or(Bool(false));

    inner_for_ = node;
    auto ret = StmtExprMutator::VisitStmt_(node);
    if (inner_for_ != node) {
      return ret;
    }
    For fnode = ret.as<For>().value();
    auto old_var = fnode->loop_var;
    if (!fnode->extent.as<IntImmNode>()) {
      return ret;
    }
    int extent = Downcast<IntImm>(fnode->extent)->value;
    if (!dynamic_) {
      return fnode;
    }
    if (!disable_dynamic_tail_split) {
      // To handle the fact that cp.async only support address aligned with
      // access size
      vector_size_ = 1;
    }
    ICHECK(extent % vector_size_ == 0)
        << "extent: " << extent << " vector_size_: " << vector_size_;
    ICHECK(is_zero(fnode->min));
    Var inner_var = Var("vec");
    Var outer_var = Var(old_var->name_hint);
    Map<Var, PrimExpr> vmap;
    vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
    Stmt body = Substitute(fnode->body, vmap);

    VectorizedConditionExtractor extractor;
    std::vector<PrimExpr> conditions = extractor.GetConditions(body);
    VectorizedConditionMutator condition_mutator(inner_var, vector_size_);

    // Adaptively set vectorized variable to the min/max value of the extent
    PrimExpr condition_bound;
    if (!conditions.empty()) {
      condition_bound = condition_mutator(conditions[0]);
      for (int i = 1; i < conditions.size(); ++i) {
        condition_bound = condition_bound && condition_mutator(conditions[i]);
      }
    }

    if (!disable_dynamic_tail_split) {
      // If dynamic_tail_split is true, we will vectorize the loop with
      // if-then-else conditions modify body in the vectorized loop
      VectorizedBodyMutator mutator(inner_var, vector_size_, conditions);
      Stmt vectorize_body = mutator(body);

      // add condition ifthenelse here
      For vectorize_for =
          For(inner_var, 0, vector_size_, ForKind::kVectorized, vectorize_body);
      For serial_for = For(inner_var, 0, vector_size_, ForKind::kSerial, body);
      if (!conditions.empty()) {
        body = IfThenElse(condition_bound, vectorize_for, serial_for);
      } else {
        body = vectorize_for;
      }
      body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
                 fnode->thread_binding, fnode->annotations, fnode->span);
      return body;
    } else {
      // If dynamic_tail_split is false, we will directly vectorize the loop
      // without dynamic tail split and if_then_else, which may lead to error
      VectorizedBodyMutator mutator(inner_var, vector_size_, conditions);
      Stmt vectorize_body = mutator(body);
      For vectorize_for =
          For(inner_var, 0, vector_size_, ForKind::kVectorized, vectorize_body);
      body = For(outer_var, 0, extent / vector_size_, fnode->kind, vectorize_for,
                 fnode->thread_binding, fnode->annotations, fnode->span);
      return body;
    }
  }

  const ForNode *inner_for_{};
  int vector_size_;
  const PrimExpr condition_;
  const bool dynamic_;
  const bool disable_dynamic_tail_split_;
};

VectorizePlanResult
GetVectorizePlanResultDynamic(const For &loop, int dynamic_alignment,
                              bool disable_dynamic_tail_split) {
  VectorizePlannerDynamic planner(dynamic_alignment, disable_dynamic_tail_split);
  int vector_size = planner.Plan(loop);
  bool dynamic = planner.GetDynamic();
  PrimExpr condition = planner.GetCondition();
  return {vector_size, dynamic, condition};
}

class LoopVectorizerDynamic : public IRMutatorWithAnalyzer {
public:
  static Stmt Substitute(Stmt stmt, bool disable_dynamic_tail_split,
                         int dynamic_alignment) {
    arith::Analyzer analyzer;
    LoopVectorizerDynamic substituter(&analyzer, disable_dynamic_tail_split,
                                      dynamic_alignment);
    stmt = substituter.VisitStmt(stmt);
    return stmt;
  }

private:
  LoopVectorizerDynamic(arith::Analyzer *analyzer,
                        bool disable_dynamic_tail_split, int dynamic_alignment)
      : arith::IRMutatorWithAnalyzer(analyzer),
        disable_dynamic_tail_split_(disable_dynamic_tail_split),
        dynamic_alignment_(dynamic_alignment) {}

  Stmt VisitStmt_(const ForNode *op) final {
    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    VectorizePlanResult res{vector_load_bits_max_, false, 0};
    res = GetVectorizePlanResultDynamic(for_node, dynamic_alignment_,
                                        disable_dynamic_tail_split_);
    NestedLoopChecker checker;
    int nest_num = checker.GetNestLoopNum(for_node);
    if (nest_num > 1 || for_node->kind == ForKind::kVectorized) {
      // only rewrite the innermost
      // non-vectorized loop
      return for_node;
    }
    auto rewriter = VectorizeRewriterDynamic(res, disable_dynamic_tail_split_);
    return Downcast<For>(rewriter(for_node));
  }

  const int vector_load_bits_max_ = 128;
  int dynamic_alignment_;
  bool disable_dynamic_tail_split_;
};

class VectorizeSkipperDynamic : public StmtMutator {
public:
  Stmt VisitStmt_(const ForNode *op) final {
    Stmt stmt = StmtMutator::VisitStmt_(op);
    op = stmt.as<ForNode>();
    if (op->kind == ForKind::kVectorized) {
      return For(op->loop_var, op->min, op->extent, ForKind::kSerial, op->body);
    } else {
      return stmt;
    }
  }
};

tvm::transform::Pass LoopVectorizeDynamic() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
    bool disable_dynamic_tail_split =
        ctx->GetConfig<Bool>(kDisableDynamicTailSplit, Bool(true)).value();
    int dynamic_alignment =
        (int)(ctx->GetConfig<Integer>(kDynamicAlignment, Integer(8))
                  .value_or(Integer(8))
                  ->value);
    // Ensure tl.dynamic_alignment is a power of 2
    if (disable_dynamic_tail_split &&
        ((dynamic_alignment & (dynamic_alignment - 1)) != 0)) {
      LOG(FATAL) << "tl.dynamic_alignment must be a power of 2, but got "
                 << dynamic_alignment;
    }
    auto *n = f.CopyOnWrite();
    n->body = LoopVectorizerDynamic::Substitute(
        std::move(n->body), disable_dynamic_tail_split, dynamic_alignment);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.LoopVectorizeDynamic", {});
}

// Register the pass globally so it can be used in the compilation pipeline
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.LoopVectorizeDynamic", LoopVectorizeDynamic);
});

} // namespace tl
} // namespace tvm
src/transform/lower_device_kernel_launch.cc

@@ -36,7 +36,7 @@ namespace tvm {
 namespace tl {
 using namespace tir;
 using namespace ffi;
 namespace {
 struct KernelInfo {
   // The device on which the PrimFunc runs

@@ -372,8 +372,8 @@ tvm::transform::Pass LowerDeviceKernelLaunch() {
       IRModule updates;
       for (const auto &[gvar, base_func] : mod->functions) {
         if (auto *ptr = base_func.as<PrimFuncNode>()) {
-          auto prim_func = mutator.RewriteKernelLaunchSite(gvar, GetRef<PrimFunc>(ptr));
+          auto prim_func = mutator.RewriteKernelLaunchSite(
+              gvar, tvm::ffi::GetRef<PrimFunc>(ptr));
           if (!prim_func.same_as(base_func)) {
             updates->Add(gvar, prim_func);
           }

@@ -388,8 +388,8 @@ tvm::transform::Pass LowerDeviceKernelLaunch() {
       IRModule updates;
       for (const auto &[gvar, base_func] : mod->functions) {
         if (auto *ptr = base_func.as<PrimFuncNode>()) {
-          auto prim_func = mutator.UpdateKernelAttributes(gvar, GetRef<PrimFunc>(ptr));
+          auto prim_func = mutator.UpdateKernelAttributes(
+              gvar, tvm::ffi::GetRef<PrimFunc>(ptr));
           if (!prim_func.same_as(base_func)) {
             updates->Add(gvar, prim_func);
           }

@@ -407,11 +407,11 @@ tvm::transform::Pass LowerDeviceKernelLaunch() {
                            "tl.LowerDeviceKernelLaunch", {});
 }
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.LowerDeviceKernelLaunch",
                         LowerDeviceKernelLaunch);
-});
+}
 } // namespace transform
 } // namespace tl
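The registration change at the bottom of this file repeats in nearly every pass in this commit: `TVM_FFI_STATIC_INIT_BLOCK({ ... });` becomes `TVM_FFI_STATIC_INIT_BLOCK() { ... }`. A minimal sketch of the new form is below; it assumes TVM's FFI reflection registry as used in the diff, and the pass name `MyExamplePass` is hypothetical.

#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>

// Hypothetical pass factory, used only to illustrate the registration style.
tvm::transform::Pass MyExamplePass();

// New style adopted by this merge: the macro takes no argument and is
// followed by an ordinary block body instead of a braced macro argument.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.MyExamplePass", MyExamplePass);
}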
src/transform/lower_device_storage_access_info.cc

@@ -45,7 +45,7 @@ public:
   Stmt VisitStmt_(const AllocateNode *op) final {
     auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var));
     if (!scope.tag.empty() && scope.tag != ".dyn" && scope.tag != ".var" &&
-        scope.tag != ".barrier" && scope.tag != ".descriptor") {
+        scope.tag != ".barrier" && scope.tag.find(".descriptor") != 0) {
       auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var));
       ICHECK(info.defined()) << "Cannot find memory info of " << scope.to_string();

@@ -143,11 +143,11 @@ Pass LowerDeviceStorageAccessInfo() {
                            {});
 }
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.LowerDeviceStorageAccessInfo",
                         LowerDeviceStorageAccessInfo);
-});
+}
 } // namespace transform
 } // namespace tl
src/transform/lower_hopper_intrin.cc

@@ -113,14 +113,14 @@ public:
     if (call->op.same_as(create_tma_descriptor()) ||
        call->op.same_as(create_tma_im2col_descriptor())) {
       Var var;
-      auto iter = desc_map_.find(GetRef<Call>(call));
+      auto iter = desc_map_.find(tvm::ffi::GetRef<Call>(call));
       if (iter != desc_map_.end()) {
         var = iter->second;
       } else {
         String name = call->args[2].as<Var>().value()->name_hint;
         var = Var(name + "_desc",
                   PointerType(PrimType(cuTensorMapType()), "grid_constant"));
-        desc_map_[GetRef<Call>(call)] = var;
+        desc_map_[tvm::ffi::GetRef<Call>(call)] = var;
         prefetch_calls_.push_back(Evaluate(
             Call(DataType::Handle(), builtin::call_extern(),
                  {StringImm("tl::prefetch_tma_descriptor"), var})));

@@ -161,10 +161,10 @@ tvm::transform::Pass LowerHopperIntrin() {
   return CreatePrimFuncPass(pass_func, 0, "tl.LowerHopperIntrin", {});
 }
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.LowerHopperIntrin", LowerHopperIntrin);
-});
+}
 #endif // (CUDA_MAJOR_VERSION >= 12)
 } // namespace tl
src/transform/lower_intrin.cc

@@ -37,6 +37,7 @@
 namespace tvm {
 namespace tl {
 using namespace tir;
+using namespace ffi;
 class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
 public:

@@ -70,9 +71,9 @@ public:
   PrimExpr VisitExpr_(const CallNode *op) final {
     if (auto *ptr_op = op->op.as<OpNode>()) {
       for (const auto &f_attr_map : attr_maps_) {
-        FLowerGeneral f = f_attr_map.get(GetRef<Op>(ptr_op), nullptr);
+        FLowerGeneral f = f_attr_map.get(tvm::ffi::GetRef<Op>(ptr_op), nullptr);
         if (f != nullptr) {
-          PrimExpr e = GetRef<PrimExpr>(op);
+          PrimExpr e = tvm::ffi::GetRef<PrimExpr>(op);
           PrimExpr r = f(e);
           ICHECK(r.defined()) << "intrinsic rule must always return valid Expr";
           if (!r.same_as(e)) {

@@ -99,7 +100,7 @@ public:
   // We use floordiv for integer analysis,
   // but will need to lower them to native truncdiv instructions
   PrimExpr VisitExpr_(const FloorDivNode *op) final {
-    auto e = GetRef<PrimExpr>(op);
+    auto e = tvm::ffi::GetRef<PrimExpr>(op);
     PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
     op = ret.as<FloorDivNode>();
     if (op == nullptr)

@@ -305,7 +306,7 @@ public:
     using namespace arith;
     PVar<PrimExpr> x, y;
     PVar<IntImm> c;
-    auto e = GetRef<PrimExpr>(op);
+    auto e = tvm::ffi::GetRef<PrimExpr>(op);
     if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 &&
         analyzer_->CanProveGreaterEqual(y.Eval(), 0)) {
       return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval());

@@ -316,7 +317,7 @@ public:
   PrimExpr VisitExpr_(const EQNode *op) final {
     using namespace arith;
     PVar<PrimExpr> x, y;
-    auto e = GetRef<PrimExpr>(op);
+    auto e = tvm::ffi::GetRef<PrimExpr>(op);
     if ((floormod(x, y) == 0).Match(e)) {
       return VisitExpr((truncmod(x, y) == 0).Eval());
     }

@@ -326,7 +327,7 @@ public:
   PrimExpr VisitExpr_(const NENode *op) final {
     using namespace arith;
     PVar<PrimExpr> x, y;
-    auto e = GetRef<PrimExpr>(op);
+    auto e = tvm::ffi::GetRef<PrimExpr>(op);
     if ((floormod(x, y) != 0).Match(e)) {
       return VisitExpr((truncmod(x, y) != 0).Eval());
     }

@@ -413,10 +414,10 @@ tir::transform::Pass LowerIntrin() {
   return CreatePrimFuncPass(pass_func, 0, "tl.LowerIntrin", {});
 }
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.LowerIntrin", LowerIntrin);
-});
+}
 } // namespace transform
src/transform/lower_l2_persistent_annotation.cc

@@ -98,10 +98,10 @@ tvm::transform::Pass LowerL2Persistent() {
   return CreatePrimFuncPass(pass_func, 0, "tl.LowerL2Persistent", {});
 }
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.LowerL2Persistent", LowerL2Persistent);
-});
+}
 } // namespace tl
 } // namespace tvm
src/transform/lower_opaque_block.cc

@@ -119,7 +119,7 @@ private:
     // Step 1. Update unit loop info.
     PrimExpr min = this->VisitExpr(op->min);
     PrimExpr extent = this->VisitExpr(op->extent);
-    if (is_one(extent) && op->annotations.empty()) {
+    if (is_one(extent) && IsEffectivelyEmptyAnnotation(op->annotations)) {
      // handling unit loop
       unit_loop_vars_[op->loop_var] = min;
     }

@@ -135,7 +135,8 @@ private:
       ICHECK(op->thread_binding.defined());
       String thread_tag = op->thread_binding.value()->thread_tag;
       body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body);
-    } else if (is_one(extent) && op->annotations.empty()) {
+    } else if (is_one(extent) &&
+               IsEffectivelyEmptyAnnotation(op->annotations)) {
       // Case 2. Unit loop
       return body;
     } else {

@@ -150,8 +151,25 @@ private:
     return body;
   }
+  // Treat annotations as empty if they are truly empty or contain only
+  // the unroll hint `pragma_unroll_explicit`. This allows unit-length
+  // loops produced by unroll pragmas to be simplified away.
+  bool IsEffectivelyEmptyAnnotation(const Map<String, ffi::Any> &annotations) const {
+    if (annotations.empty()) {
+      return true;
+    }
+    if (annotations.size() == 1) {
+      auto it = annotations.find(tir::attr::pragma_unroll_explicit);
+      if (it != annotations.end()) {
+        return true;
+      }
+    }
+    return false;
+  }
+
   PrimExpr VisitExpr_(const VarNode *op) final {
-    Var var = GetRef<Var>(op);
+    Var var = tvm::ffi::GetRef<Var>(op);
     auto it = unit_loop_vars_.find(var);
     if (it == unit_loop_vars_.end()) {
       return var;

@@ -286,10 +304,10 @@ tir::transform::Pass LowerOpaqueBlock() {
   return CreatePrimFuncPass(pass_func, 0, "tl.LowerOpaqueBlock", {});
 }
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.LowerOpaqueBlock", LowerOpaqueBlock);
-});
+}
 } // namespace tl
 } // namespace tvm
src/transform/lower_shared_barrier.cc

@@ -32,7 +32,7 @@ private:
       : disable_shuffle_elect_(disable_shuffle_elect) {}
   Stmt VisitStmt_(const BlockNode *op) final {
-    Block block = GetRef<Block>(op);
+    Block block = tvm::ffi::GetRef<Block>(op);
     Array<Buffer> alloc_buffers = op->alloc_buffers;
     // Record the mapping from buffer data var to buffer for later lookup

@@ -204,10 +204,10 @@ tvm::transform::Pass LowerSharedBarrier() {
   return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedBarrier", {});
 }
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.LowerSharedBarrier", LowerSharedBarrier);
-});
+}
 } // namespace transform
 } // namespace tl
src/transform/lower_shared_tmem.cc

@@ -30,7 +30,7 @@ public:
 private:
   Stmt VisitStmt_(const BlockNode *op) final {
-    Block block = GetRef<Block>(op);
+    Block block = tvm::ffi::GetRef<Block>(op);
     Array<Buffer> alloc_buffers = op->alloc_buffers;
     if (op->annotations.count(attr::kLayoutMap)) {
       auto layout_map = op->annotations.Get(attr::kLayoutMap);

@@ -88,6 +88,8 @@ private:
     Array<Var> new_data_vars;
     for (auto buffer : tmem_buffers) {
       auto data = buffer->data;
+      if (var_remap_.count(data))
+        continue;
       auto new_data =
           Var(data->name_hint, PointerType(PrimType(tmem_dtype_), "shared"));
       var_remap_.Set(data, new_data);

@@ -107,6 +109,7 @@ private:
                  buffer->buffer_type);
       new_buffers.push_back(new_buffer);
       buffer_remap_.Set(buffer, new_buffer);
+      buffer_data_to_buffer_.Set(new_data, new_buffer);
     }
     // remove the tmem buffers

@@ -255,7 +258,15 @@ private:
                   op->dtype, op->op,
                   {op->args[0], new_data, op->args[2], op->args[3], op->args[4]});
     }
-    return StmtExprMutator::VisitExpr_(op);
+    auto expr = StmtExprMutator::VisitExpr_(op);
+    return expr;
+  }
+
+  PrimExpr VisitExpr_(const VarNode *op) final {
+    Var var = tvm::ffi::GetRef<Var>(op);
+    if (var_remap_.count(var)) {
+      return var_remap_[var];
+    }
+    return var;
   }
   Stmt VisitStmt_(const AttrStmtNode *op) final {

@@ -300,10 +311,10 @@ tvm::transform::Pass LowerSharedTmem() {
   return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedTmem", {});
 }
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.LowerSharedTmem", LowerSharedTmem);
-});
+}
 } // namespace transform
 } // namespace tl
src/transform/lower_thread_allreduce.cc

@@ -39,6 +39,7 @@
 namespace tvm {
 namespace tl {
 using namespace tir;
+using namespace ffi;
 using runtime::StorageRank;
 using runtime::StorageScope;

@@ -944,11 +945,11 @@ tvm::transform::Pass LowerThreadAllreduce() {
   return CreatePrimFuncPass(pass_func, 0, "tl.LowerThreadAllreduce", {});
 }
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.LowerThreadAllreduce",
                         LowerThreadAllreduce);
-});
+}
 } // namespace transform
 } // namespace tl
src/transform/lower_tile_op.cc (mode changed 100755 → 100644)

@@ -10,6 +10,7 @@
 #include <tvm/tir/transform.h>
 #include <tvm/tir/utils.h>
 #include <unordered_map>
+#include <vector>
 #include "../layout/layout.h"
 #include "../layout/utils.h"

@@ -103,55 +104,6 @@ private:
   Map<Buffer, Layout> layout_remap_;
 };
-class BufferGemmCollector : public StmtExprVisitor {
-public:
-  BufferGemmCollector() { Clear(); }
-  void Clear() { buffer_var_gemm_.clear(); }
-  void Collect(const Stmt &stmt) { VisitStmt(stmt); }
-  Array<Var> GetBufferVarGemm() { return buffer_var_gemm_; }
-
-private:
-  void VisitStmt_(const EvaluateNode *op) {
-    const CallNode *call_node = op->value.as<CallNode>();
-    // Value of EvaluateNode may not be a call
-    if (!call_node) {
-      return;
-    }
-    auto call = Downcast<Call>(call_node);
-    if (call->op.same_as(Gemm::Get())) {
-      auto srcA_buffer_access_ptr = Downcast<Call>(call->args[0]);
-      ICHECK(srcA_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
-      auto srcA_buffer_var = Downcast<Var>(srcA_buffer_access_ptr->args[1]);
-      auto srcB_buffer_access_ptr = Downcast<Call>(call->args[1]);
-      ICHECK(srcB_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
-      auto srcB_buffer_var = Downcast<Var>(srcB_buffer_access_ptr->args[1]);
-      auto dst_buffer_access_ptr = Downcast<Call>(call->args[2]);
-      ICHECK(dst_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
-      auto dst_buffer_var = Downcast<Var>(dst_buffer_access_ptr->args[1]);
-      buffer_var_gemm_.push_back(srcA_buffer_var);
-      buffer_var_gemm_.push_back(srcB_buffer_var);
-      buffer_var_gemm_.push_back(dst_buffer_var);
-    } else if (call->op.same_as(GemmSP::Get())) {
-      auto srcA_buffer_access_ptr = Downcast<Call>(call->args[0]);
-      ICHECK(srcA_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
-      auto srcA_buffer_var = Downcast<Var>(srcA_buffer_access_ptr->args[1]);
-      auto srcB_buffer_access_ptr = Downcast<Call>(call->args[1]);
-      ICHECK(srcB_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
-      auto srcB_buffer_var = Downcast<Var>(srcB_buffer_access_ptr->args[1]);
-      auto dst_buffer_access_ptr = Downcast<Call>(call->args[2]);
-      ICHECK(dst_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
-      auto dst_buffer_var = Downcast<Var>(dst_buffer_access_ptr->args[1]);
-      buffer_var_gemm_.push_back(srcA_buffer_var);
-      buffer_var_gemm_.push_back(srcB_buffer_var);
-      buffer_var_gemm_.push_back(dst_buffer_var);
-    }
-  }
-  Array<Var> buffer_var_gemm_;
-};
 /*!
  * \brief A class that rewrites buffer references in a statement based on a

@@ -253,11 +205,6 @@ public:
     auto target = f->GetAttr<Target>(tvm::attr::kTarget);
     ICHECK(target.defined()) << "LowerTileOpPass: Require the target attribute";
     substituter.target_ = target.value();
-    // For TMA 1D, we should collect the buffers which are not used in GEMM and
-    // do not need swizzle
-    BufferGemmCollector collector;
-    collector.Collect(f->body);
-    substituter.buffer_var_gemm_ = collector.GetBufferVarGemm();
     PrimFuncNode *fptr = f.CopyOnWrite();
     fptr->body = substituter.VisitStmt(f->body);
     fptr->body =

@@ -301,6 +248,9 @@ private:
         layout_map_.Set(buffer, layout);
       }
     }
+    // Begin a new workspace collection frame for this block scope
+    workspace_stack_.emplace_back();
     auto block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
     auto block_ptr = block.CopyOnWrite();
     for (size_t i = 0; i < block->alloc_buffers.size(); i++) {

@@ -309,9 +259,13 @@ private:
         block_ptr->alloc_buffers.Set(i, buffer_remap_[buffer]);
       }
     }
-    for (const auto &buffer : workspaces_)
-      block_ptr->alloc_buffers.push_back(buffer);
-    workspaces_.clear();
+    // Attach any workspaces requested within this block to its alloc_buffers
+    if (!workspace_stack_.empty()) {
+      for (const auto &buffer : workspace_stack_.back()) {
+        block_ptr->alloc_buffers.push_back(buffer);
+      }
+      workspace_stack_.pop_back();
+    }
     return block;
   }

@@ -435,7 +389,7 @@ private:
       return expr;
     }
     if (const auto *var_node = expr.as<VarNode>()) {
-      Var var = GetRef<Var>(var_node);
+      Var var = tvm::ffi::GetRef<Var>(var_node);
       auto it = let_bindings_.find(var);
       if (it != let_bindings_.end()) {
         return it->second;

@@ -611,7 +565,7 @@ private:
       let_bindings_.erase(op->var);
     }
     if (value.same_as(op->value) && body.same_as(op->body)) {
-      return GetRef<Stmt>(op);
+      return tvm::ffi::GetRef<Stmt>(op);
     } else {
       auto n = this->CopyOnWrite(op);
       n->value = value;

@@ -652,13 +606,22 @@ private:
     if (call && call->op.as<GlobalVarNode>())
       return Downcast<Evaluate>(IRMutatorWithAnalyzer::VisitStmt_(op));
-    auto tile_op = ParseOperator(GetRef<Stmt>(op), buffer_data_to_buffer_);
+    auto tile_op =
+        ParseOperator(tvm::ffi::GetRef<Stmt>(op), buffer_data_to_buffer_);
     if (!tile_op.defined())
       return IRMutatorWithAnalyzer::VisitStmt_(op);
     AddWorkspaceCallback callback = [this](int num_elem, DataType dtype) {
       auto workspace =
           decl_buffer({PrimExpr(num_elem)}, dtype, "workspace", "shared.dyn");
-      workspaces_.push_back(workspace);
+      // Record workspace under the innermost block scope so its lifetime
+      // covers the statements that requested it and does not sink into
+      // subsequently created inner blocks (e.g., GEMM macro blocks).
+      if (!workspace_stack_.empty()) {
+        workspace_stack_.back().push_back(workspace);
+      } else {
+        // Fallback: create a temporary frame (should be rare)
+        workspace_stack_.emplace_back(Array<Buffer>{workspace});
+      }
       return workspace.access_ptr(2); // write
     };

@@ -676,9 +639,9 @@ private:
       thread_bounds = Range::FromMinExtent(0, 1);
     }
-    auto lowered = tile_op->Lower(LowerArgs{target_, thread_bounds,
-                                            thread_var_->var, callback,
-                                            layout_map_, buffer_remap_,
-                                            buffer_var_gemm_},
-                                  analyzer_);
+    auto lowered =
+        tile_op->Lower(LowerArgs{target_, thread_bounds, thread_var_->var,
+                                 callback, layout_map_, buffer_remap_},
+                       analyzer_);
     return IRMutatorWithAnalyzer::VisitStmt(lowered);
   }

@@ -706,7 +669,8 @@ private:
   IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                 IterVarType::kDataPar);
   size_t thread_block_size_ = 0;
-  Array<Buffer> workspaces_;
+  // Stack of per-Block workspace buffers gathered while visiting children
+  std::vector<Array<Buffer>> workspace_stack_;
   // For ptx Node, we need to remap the buffer and indices
   // By access CallNode instead of BufferLoad Node.
   bool is_ptx_{false};

@@ -716,7 +680,6 @@ private:
   std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
   Map<Var, Var> var_remap_;
   bool has_tma_{false};
-  Array<Var> buffer_var_gemm_;
 };
 namespace transform {

@@ -730,10 +693,10 @@ tvm::transform::Pass LowerTileOp() {
   return CreatePrimFuncPass(pass_func, 0, "tl.LowerTileOp", {});
 }
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.LowerTileOp", LowerTileOp);
-});
+}
 } // namespace transform
 } // namespace tl
src/transform/make_packed_api.cc

@@ -42,6 +42,7 @@
 namespace tvm {
 namespace tl {
 using namespace tir;
+using namespace ffi;
 static constexpr const char *kDeviceContextVar = "device_api_context";
 namespace {

@@ -168,7 +169,7 @@ private:
     auto node = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
     if (auto *gvar_ptr = node->op.as<GlobalVarNode>()) {
-      auto gvar = GetRef<GlobalVar>(gvar_ptr);
+      auto gvar = tvm::ffi::GetRef<GlobalVar>(gvar_ptr);
       if (auto symbol = packed_func_methods.Get(gvar)) {
         Array<PrimExpr> cpacked_args;
         cpacked_args.push_back(tir::StringImm(symbol.value()));

@@ -220,7 +221,7 @@ Optional<String> RequiresPackedAPI(const PrimFunc &func) {
   // Internal function calls do not need the PackedFunc API
   auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-  if (!global_symbol.defined()) {
+  if (!global_symbol) {
     return std::nullopt;
   }

@@ -229,7 +230,7 @@ Optional<String> RequiresPackedAPI(const PrimFunc &func) {
 PrimFunc MakePackedAPI(PrimFunc func) {
   auto global_symbol = RequiresPackedAPI(func);
-  if (!global_symbol.defined()) {
+  if (!global_symbol) {
     return func;
   }
   std::string name_hint = global_symbol.value();

@@ -406,7 +407,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
                    StringImm(name_hint + "_compute_"), body);
   // Set device context
   if (vmap.count(device_id.get())) {
-    ObjectRef node = String("default");
+    auto node = String("default");
     seq_check.push_back(AttrStmt(node, tir::attr::device_id, device_id, nop));
     seq_check.push_back(AttrStmt(node, tir::attr::device_type, device_type, nop));

@@ -432,7 +433,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
     auto shape_vectorize_expr = [&]() -> PrimExpr {
       PrimExpr result = IntImm(kv.second->DefaultIndexType(), 1);
       result = result * vectorize_dim;
-      result = FloorMod(result, dynamic_alignment);
+      result = FloorMod(result, IntImm(result->dtype, dynamic_alignment));
       return result;
     }();
     shape_checks.emplace_back(AssertStmt(

@@ -513,11 +514,11 @@ tvm::transform::Pass MakePackedAPI() {
   return tvm::transform::CreateModulePass(pass_func, 0, "tl.MakePackedAPI", {});
 }
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.MakePackedAPI",
                         []() { return MakePackedAPI(); });
-});
+}
 } // namespace tl
 } // namespace tvm
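One functional change in make_packed_api.cc above is that the alignment constant passed to `FloorMod` is now materialised with the dtype of the shape expression. A small sketch of the idea follows; the helper name `AlignmentRemainder` is hypothetical and the snippet is only an illustration of the dtype-matching pattern, not the pass's exact code.

#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

// Build `result % alignment` with both operands sharing the same dtype, as the
// diff above does, so the TIR binary op never mixes integer widths (shape
// expressions are typically int64 while a bare int literal defaults to int32).
tvm::PrimExpr AlignmentRemainder(const tvm::PrimExpr &result, int dynamic_alignment) {
  return tvm::tir::FloorMod(result, tvm::IntImm(result->dtype, dynamic_alignment));
}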
src/transform/merge_if_stmt.cc

@@ -98,10 +98,10 @@ tvm::transform::Pass MergeIfStmt() {
   return CreatePrimFuncPass(pass_func, 0, "tl.MergeIfStmt", {});
 }
-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.MergeIfStmt", MergeIfStmt);
-});
+}
 } // namespace tl
 } // namespace tvm
src/transform/merge_shared_memory_allocations.cc
View file @
bbbf4207
...
...
@@ -31,6 +31,12 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <algorithm>
#include <functional>
#include <limits>
#include <optional>
#include <queue>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <utility>
...
...
@@ -38,7 +44,6 @@
#include "../op/builtin.h"
#include "../target/utils.h"
#include "runtime/thread_storage_scope.h"
#include "support/arena.h"
#include "tir/transforms/ir_utils.h"
#include "tvm/tir/function.h"
...
...
@@ -141,6 +146,8 @@ public:
void
VisitStmt_
(
const
AllocateNode
*
op
)
final
{
size_t
level
=
scope_
.
size
();
const
VarNode
*
buf
=
op
->
buffer_var
.
get
();
// Record the allocation site and depth so liveness can reason about the
// original scope.
alloc_info_
[
buf
].
alloc
=
op
;
alloc_info_
[
buf
].
level
=
level
;
StmtExprVisitor
::
VisitStmt_
(
op
);
...
...
@@ -155,7 +162,7 @@ public:
auto
it
=
alloc_info_
.
find
(
buf
);
if
(
it
!=
alloc_info_
.
end
()
&&
it
->
second
.
alloc
)
{
ICHECK_LT
(
it
->
second
.
level
,
scope_
.
size
());
if
(
IsAppropriateSharedMemory
(
GetRef
<
Var
>
(
buf
)))
{
if
(
IsAppropriateSharedMemory
(
tvm
::
ffi
::
GetRef
<
Var
>
(
buf
)))
{
// set into scope_.size() - 1 for aggressive memory reuse
auto
enable_aggressive_merge
=
enable_aggressive_merge_
;
if
(
enable_aggressive_merge
)
{
...
...
@@ -194,17 +201,23 @@ public:
const
VarNode
*
buf
=
op
->
buffer
->
data
.
get
();
auto
it
=
alloc_info_
.
find
(
buf
);
if
(
it
!=
alloc_info_
.
end
()
&&
it
->
second
.
alloc
)
{
// Allow buffer access at the same level or deeper scope
// Changed from < to <= to handle cases where buffer is accessed
// in expressions at the same scope level where it's allocated
// Earlier we required `alloc_level < scope_.size()`, assuming every load
// would occur strictly inside a nested scope. In practice the lowering
// pipeline may materialise reads in the very same frame that owns the
// allocation (e.g. when the buffer value is passed directly to a call),
// which used to trigger the CHECK. Treat same-level accesses as valid so
// the merged allocator can reason about their lifetime correctly.
ICHECK_LE
(
it
->
second
.
level
,
scope_
.
size
())
<<
"Load memory in places other than store."
;
if
(
IsAppropriateSharedMemory
(
GetRef
<
Var
>
(
buf
)))
{
if
(
IsAppropriateSharedMemory
(
tvm
::
ffi
::
GetRef
<
Var
>
(
buf
)))
{
auto
enable_aggressive_merge
=
enable_aggressive_merge_
;
if
(
enable_aggressive_merge
)
{
scope_
[
scope_
.
size
()
-
1
].
touched
.
push_back
(
buf
);
}
else
{
// When accessing at the same level, use that level
// When the access happens in the same scope frame as the allocation
// we attribute it to that frame instead of the outer parent. This
// keeps the liveness window tight while still accounting for nested
// scopes that legitimately touch the buffer deeper in the tree.
size_t
access_level
=
std
::
min
(
it
->
second
.
level
,
scope_
.
size
()
-
1
);
scope_
[
access_level
].
touched
.
push_back
(
buf
);
}
...
...
@@ -216,14 +229,17 @@ public:
// Directly reference to the variable count as a read.
auto
it
=
alloc_info_
.
find
(
buf
);
if
(
it
!=
alloc_info_
.
end
()
&&
it
->
second
.
alloc
)
{
// Allow buffer access at the same level or deeper scope
// Same rationale as the BufferLoad path above: direct references can be
// emitted at the allocation level after flattening, so accept them and
// record the touch for liveness planning.
ICHECK_LE
(
it
->
second
.
level
,
scope_
.
size
());
if
(
IsAppropriateSharedMemory
(
GetRef
<
Var
>
(
buf
)))
{
if
(
IsAppropriateSharedMemory
(
tvm
::
ffi
::
GetRef
<
Var
>
(
buf
)))
{
auto
enable_aggressive_merge
=
enable_aggressive_merge_
;
if
(
enable_aggressive_merge
)
{
scope_
[
scope_
.
size
()
-
1
].
touched
.
push_back
(
buf
);
}
else
{
// When accessing at the same level, use that level
// Attribute same-level uses to the allocation frame, mirroring the
// BufferLoad handling to keep reuse decisions consistent.
size_t
access_level
=
std
::
min
(
it
->
second
.
level
,
scope_
.
size
()
-
1
);
scope_
[
access_level
].
touched
.
push_back
(
buf
);
}
...
...
@@ -245,6 +261,8 @@ public:
scope_
.
pop_back
();
int64_t
end_index
=
static_cast
<
int64_t
>
(
linear_seq_
.
size
());
ICHECK_GT
(
end_index
,
begin_index
);
// The paired entries serve as scope sentinels once we flatten the
// control-flow tree.
e
.
scope_pair_offset
=
begin_index
-
end_index
;
linear_seq_
.
push_back
(
e
);
// record the pointer to end index.
...
...
@@ -336,9 +354,30 @@ public:
}
private:
// Helper to record alignment for a shared/shared.dyn Var under alignment
// scope
void
MarkSharedVarIfNeeded
(
const
VarNode
*
op
)
{
if
(
!
op
||
!
under_alignment_scope_
)
return
;
auto
ptr_type
=
op
->
type_annotation
.
as
<
PointerTypeNode
>
();
if
(
!
ptr_type
)
return
;
auto
scope
=
GetPtrStorageScope
(
tvm
::
ffi
::
GetRef
<
Var
>
(
op
));
if
(
scope
==
"shared"
||
scope
==
"shared.dyn"
)
{
auto
target
=
Target
::
Current
();
ICHECK
(
target
.
defined
())
<<
"Target is not defined"
;
const
int
alignment
=
TargetIsHopper
(
target
)
?
1024
:
16
;
shmem_alignment_map_
[
op
]
=
alignment
;
}
}
void
VisitExpr_
(
const
CallNode
*
op
)
{
if
(
op
->
op
.
same_as
(
tl
::
tl_gemm
())
||
op
->
op
.
same_as
(
tl
::
tl_gemm_sp
())
||
op
->
op
.
same_as
(
tl
::
tma_load
())
||
op
->
op
.
same_as
(
tl
::
tma_store
()))
{
op
->
op
.
same_as
(
tl
::
tma_load
())
||
op
->
op
.
same_as
(
tl
::
tma_store
())
||
op
->
op
.
same_as
(
tl
::
initialize_wgmma_descriptor
())
||
op
->
op
.
same_as
(
tl
::
initialize_tcgen05_descriptor
()))
{
// These intrinsics introduce stricter SMEM alignment requirements; mark
// the subtree.
under_alignment_scope_
=
true
;
StmtExprVisitor
::
VisitExpr_
(
op
);
under_alignment_scope_
=
false
;
...
...
@@ -348,15 +387,16 @@ private:
}
void
VisitExpr_
(
const
VarNode
*
op
)
{
auto
ptr_type
=
op
->
type_annotation
.
as
<
PointerTypeNode
>
();
if
(
ptr_type
&&
under_alignment_scope_
)
{
auto
scope
=
GetPtrStorageScope
(
GetRef
<
Var
>
(
op
));
if
(
scope
==
"shared"
||
scope
==
"shared.dyn"
)
{
auto
target
=
Target
::
Current
();
ICHECK
(
target
.
defined
())
<<
"Target is not defined"
;
const
int
alignment
=
TargetIsHopper
(
target
)
?
1024
:
16
;
shmem_alignment_map_
[
op
]
=
alignment
;
MarkSharedVarIfNeeded
(
op
);
StmtExprVisitor
::
VisitExpr_
(
op
);
}
void
VisitExpr_
(
const
BufferLoadNode
*
op
)
{
// If we encounter address_of(BufferLoad(...)) or any direct BufferLoad
// within an alignment scope, make sure we mark the underlying shared var.
if
(
op
&&
under_alignment_scope_
)
{
const
VarNode
*
data_var
=
op
->
buffer
->
data
.
get
();
MarkSharedVarIfNeeded
(
data_var
);
}
StmtExprVisitor
::
VisitExpr_
(
op
);
}
...
...
@@ -394,6 +434,8 @@ public:
enable_aggressive_merge
,
verbose
);
finder
(
stmt
);
shmem_alignment_map_
=
SharedMemoryAlignmentPlanner
::
Plan
(
stmt
);
// First compute liveness over the flattened schedule, then feed it into the
// arena packer.
this
->
LivenessAnalysis
(
finder
.
linear_seq_
,
finder
.
stmt_attrs_
);
this
->
PlanMemory
(
finder
.
linear_seq_
,
finder
.
stmt_attrs_
);
}
...
...
@@ -403,65 +445,6 @@ private:
if
(
op
->
attr_key
==
tir
::
attr
::
thread_extent
&&
!
allocated_
)
{
// Allocate one dynamic shared memory allocation at the beginning of
// thread scope
int
max_layer_num
=
0
;
std
::
vector
<
const
StorageEntry
*>
all_entry
;
for
(
const
auto
&
e
:
const_free_map_
)
{
all_entry
.
push_back
(
e
.
second
);
}
for
(
const
StorageEntry
*
e
:
sym_free_list_
)
{
all_entry
.
push_back
(
e
);
}
// Sort the storage entries in descending order of their total allocation
// size (in bits). This ensures that larger allocations are placed first,
// which can help minimize fragmentation and improve memory packing
// efficiency when merging shared memory buffers.
std
::
sort
(
all_entry
.
begin
(),
all_entry
.
end
(),
[](
const
StorageEntry
*
a
,
const
StorageEntry
*
b
)
{
return
a
->
const_nbits
>
b
->
const_nbits
;
});
for
(
const
StorageEntry
*
e
:
all_entry
)
{
max_layer_num
=
std
::
max
(
max_layer_num
,
static_cast
<
int
>
(
e
->
allocs
.
size
()));
}
// calculate align for each layer of each storage entry.
std
::
vector
<
int
>
align
(
max_layer_num
,
0
);
for
(
const
StorageEntry
*
e
:
all_entry
)
{
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
e
->
allocs
.
size
());
i
++
)
{
for
(
const
VarNode
*
buffer
:
e
->
allocs
[
i
])
{
const
AllocateNode
*
alloc
=
shmem_allocs_
[
buffer
];
align
[
i
]
=
std
::
max
(
align
[
i
],
alloc
->
dtype
.
bytes
()
*
alloc
->
dtype
.
lanes
());
align
[
i
]
=
std
::
max
(
align
[
i
],
align_bytes_
);
}
}
}
for
(
const
StorageEntry
*
e
:
all_entry
)
{
PrimExpr
max_inner_offset
=
0
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
e
->
allocs
.
size
());
i
++
)
{
PrimExpr
inner_offset
=
0
;
for
(
const
VarNode
*
buffer
:
e
->
allocs
[
i
])
{
const
AllocateNode
*
alloc
=
shmem_allocs_
[
buffer
];
auto
alignment
=
align
[
i
];
// Modern nvidia architecture performs hardware swizzling (hopper
// wgmma/tma for example) requires dynamic shared memory address to
// be aligned to 1024 bytes For other devices, we align to 16 bytes
if
(
shmem_alignment_map_
.
find
(
buffer
)
!=
shmem_alignment_map_
.
end
())
{
alignment
=
std
::
max
(
align
[
i
],
shmem_alignment_map_
[
buffer
]);
}
PrimExpr
start_offset
=
merged_alloc_size_
+
inner_offset
;
PrimExpr
aligned_offset
=
indexdiv
(
start_offset
+
alignment
-
1
,
alignment
)
*
alignment
;
buffer_byte_offsets_
[
buffer
]
=
aligned_offset
;
inner_offset
=
aligned_offset
-
merged_alloc_size_
+
alloc
->
extents
[
0
]
*
alloc
->
dtype
.
bytes
()
*
alloc
->
dtype
.
lanes
();
}
max_inner_offset
=
max
(
max_inner_offset
,
inner_offset
);
}
merged_alloc_size_
+=
max_inner_offset
;
}
if
(
verbose_
)
{
...
...
@@ -626,18 +609,199 @@ private:
using
StmtEntry
=
SharedMemLinearAccessPatternFinder
::
StmtEntry
;
using
StmtAttr
=
SharedMemLinearAccessPatternFinder
::
StmtAttr
;
struct
StorageEntry
{
// The constant size of the buffer in bits, only used if it is constant
uint64_t
const_nbits
{
0
};
// Allocs that shares this entry.
// The inner vector means a "layer"
// For example, it we need to allocate C in the memory of A and B:
// | A: 4096 bytes | B: 4096 bytes |
// | C: 8192 bytes |
// Then the allocs = {{A, B}, {C}}
std
::
vector
<
std
::
vector
<
const
VarNode
*>>
allocs
;
// Metadata about a single shared-memory allocation prior to merging. This
// is used to build lifetimes, alignment requirements, and final offsets.
struct
BufInfo
{
const
VarNode
*
var
{
nullptr
};
std
::
string
name
;
PrimExpr
size_expr
;
std
::
optional
<
int64_t
>
const_size_bytes
;
// in bytes if compile-time known.
int
alignment
{
0
};
// required byte alignment.
int
start
{
0
};
// first statement index touching the buf.
int
end
{
0
};
// one-past-last statement index.
DataType
size_dtype
{
DataType
::
Int
(
32
)};
};
// Interval describing the liveness window of a (constant-sized) allocation.
struct
Interval
{
int
start
{
0
};
int
end
{
0
};
size_t
size_bytes
{
0
};
int
alignment
{
0
};
const
VarNode
*
var
{
nullptr
};
};
// Result of a linear-scan arena packing. Offsets contain the byte offset for
// each constant-sized buffer, arena_size is the total constant footprint.
struct
ArenaPlan
{
size_t
arena_size
{
0
};
std
::
unordered_map
<
const
VarNode
*
,
size_t
>
offsets
;
};
static
size_t
AlignUpSize
(
size_t
value
,
size_t
alignment
)
{
if
(
alignment
==
0
)
{
return
value
;
}
size_t
remainder
=
value
%
alignment
;
if
(
remainder
==
0
)
{
return
value
;
}
return
value
+
(
alignment
-
remainder
);
}
  struct FreeBlock {
    size_t offset{0};
    size_t size{0};
  };

  class FreeList {
  public:
    std::optional<size_t> Allocate(size_t need, size_t alignment) {
      // Best-fit search: pick the slot that wastes the least space after
      // alignment.
      int best = -1;
      size_t best_waste = std::numeric_limits<size_t>::max();
      for (int i = 0, n = static_cast<int>(blocks_.size()); i < n; ++i) {
        size_t aligned = AlignUpSize(blocks_[i].offset, alignment);
        size_t head = aligned - blocks_[i].offset;
        if (head <= blocks_[i].size && (blocks_[i].size - head) >= need) {
          size_t waste = blocks_[i].size - head - need;
          if (waste < best_waste) {
            best_waste = waste;
            best = i;
          }
        }
      }
      if (best < 0) {
        return std::nullopt;
      }
      FreeBlock blk = blocks_[best];
      size_t aligned = AlignUpSize(blk.offset, alignment);
      size_t head = aligned - blk.offset;
      size_t tail = blk.size - head - need;
      blocks_.erase(blocks_.begin() + best);
      if (head) {
        blocks_.push_back({blk.offset, head});
      }
      if (tail) {
        blocks_.push_back({aligned + need, tail});
      }
      Normalize();
      return aligned;
    }

    void Free(size_t offset, size_t size) {
      if (size == 0)
        return;
      blocks_.push_back({offset, size});
      Normalize();
    }

  private:
    void Normalize() {
      if (blocks_.empty())
        return;
      std::sort(blocks_.begin(), blocks_.end(),
                [](const FreeBlock &a, const FreeBlock &b) {
                  return a.offset < b.offset;
                });
      std::vector<FreeBlock> merged;
      merged.reserve(blocks_.size());
      for (const FreeBlock &blk : blocks_) {
        if (merged.empty()) {
          merged.push_back(blk);
          continue;
        }
        FreeBlock &last = merged.back();
        size_t last_end = last.offset + last.size;
        if (blk.offset <= last_end) {
          size_t blk_end = blk.offset + blk.size;
          if (blk_end > last_end) {
            last.size = blk_end - last.offset;
          }
        } else {
          merged.push_back(blk);
        }
      }
      blocks_ = std::move(merged);
    }

    std::vector<FreeBlock> blocks_;
  };
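The FreeList above recycles byte ranges inside the merged arena with a best-fit policy (least leftover space after alignment wins) and re-coalesces adjacent blocks in Normalize. A self-contained, deliberately simplified sketch of the same idea (no alignment handling, no coalescing; SimpleFreeList and the offsets in the comments are mine, not the patch's):

#include <cstddef>
#include <cstdio>
#include <limits>
#include <optional>
#include <vector>

struct Block { std::size_t offset, size; };

struct SimpleFreeList {
  std::vector<Block> blocks;
  // Best fit: choose the free block with the least leftover space.
  std::optional<std::size_t> Allocate(std::size_t need) {
    int best = -1;
    std::size_t best_waste = std::numeric_limits<std::size_t>::max();
    for (int i = 0; i < static_cast<int>(blocks.size()); ++i) {
      if (blocks[i].size >= need && blocks[i].size - need < best_waste) {
        best_waste = blocks[i].size - need;
        best = i;
      }
    }
    if (best < 0) return std::nullopt;
    std::size_t at = blocks[best].offset;
    blocks[best].offset += need;   // shrink the chosen block in place
    blocks[best].size -= need;
    return at;
  }
  void Free(std::size_t offset, std::size_t size) { blocks.push_back({offset, size}); }
};

int main() {
  SimpleFreeList fl;
  fl.Free(0, 12288);                    // seed: one 12 KiB arena
  std::size_t a = *fl.Allocate(4096);   // -> 0
  std::size_t b = *fl.Allocate(8192);   // -> 4096
  fl.Free(a, 4096);                     // A's lifetime ends, [0, 4096) is free again
  std::size_t c = *fl.Allocate(2048);   // best fit reuses the freed hole -> 0
  std::printf("a=%zu b=%zu c=%zu\n", a, b, c);
  return 0;
}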
  struct ActiveInterval {
    int end{0};
    size_t offset{0};
    size_t size{0};
    const VarNode *var{nullptr};
    bool operator>(const ActiveInterval &other) const { return end > other.end; }
  };

  static ArenaPlan LinearScanPack(std::vector<Interval> intervals) {
    // Process intervals in program order so lifetimes correspond to the
    // linearised CFG.
    std::sort(intervals.begin(), intervals.end(),
              [](const Interval &lhs, const Interval &rhs) {
                if (lhs.start != rhs.start) {
                  return lhs.start < rhs.start;
                }
                if (lhs.size_bytes != rhs.size_bytes) {
                  return lhs.size_bytes > rhs.size_bytes;
                }
                return lhs.var < rhs.var;
              });
    std::priority_queue<ActiveInterval, std::vector<ActiveInterval>,
                        std::greater<ActiveInterval>>
        active;
    FreeList freelist;
    size_t arena_top = 0;
    std::unordered_map<const VarNode *, size_t> offsets;
    // Expire intervals that end before or at program counter `pc`.
    auto retire = [&](int pc) {
      while (!active.empty() && active.top().end <= pc) {
        const ActiveInterval top = active.top();
        active.pop();
        freelist.Free(top.offset, top.size);
      }
    };
    for (const Interval &interval : intervals) {
      retire(interval.start);
      size_t offset = 0;
      // Try to recycle previously freed memory first; fall back to bumping the
      // arena.
      if (auto slot = freelist.Allocate(interval.size_bytes, interval.alignment)) {
        offset = slot.value();
      } else {
        offset = AlignUpSize(arena_top, interval.alignment);
        arena_top = offset + interval.size_bytes;
      }
      active.push(ActiveInterval{interval.end, offset, interval.size_bytes,
                                 interval.var});
      offsets[interval.var] = offset;
    }
    return ArenaPlan{arena_top, std::move(offsets)};
  }
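LinearScanPack is a classic linear-scan packer: intervals are visited in program order, buffers whose lifetime has ended are returned to the free list, and a new buffer either reuses a hole or bumps the arena top. A self-contained walk-through (simplified: exact-fit holes via a multimap, no alignment; all names and sizes here are illustrative) showing how a buffer whose lifetime starts after another one ends can reuse its offset:

#include <cstddef>
#include <cstdio>
#include <functional>
#include <map>
#include <queue>
#include <tuple>
#include <vector>

struct Iv { int start, end; std::size_t size; const char *name; };

int main() {
  // Lifetimes as half-open [start, end) statement indices, as in PlanMemory.
  std::vector<Iv> ivs = {{0, 2, 4096, "A"}, {1, 3, 8192, "B"}, {2, 4, 4096, "C"}};
  using Active = std::tuple<int, std::size_t, std::size_t>;  // (end, offset, size)
  std::priority_queue<Active, std::vector<Active>, std::greater<Active>> live;
  std::multimap<std::size_t, std::size_t> holes;  // freed size -> offset
  std::size_t arena_top = 0;
  for (const Iv &iv : ivs) {
    // Retire buffers whose lifetime ended before this one starts.
    while (!live.empty() && std::get<0>(live.top()) <= iv.start) {
      holes.emplace(std::get<2>(live.top()), std::get<1>(live.top()));
      live.pop();
    }
    std::size_t offset;
    auto it = holes.lower_bound(iv.size);  // smallest hole that still fits
    if (it != holes.end()) {
      offset = it->second;
      holes.erase(it);
    } else {
      offset = arena_top;
      arena_top += iv.size;
    }
    live.emplace(iv.end, offset, iv.size);
    std::printf("%s -> offset %zu\n", iv.name, offset);
  }
  // Prints A -> 0, B -> 4096, C -> 0; arena = 12288 instead of the naive 16384.
  std::printf("arena size = %zu\n", arena_top);
  return 0;
}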
  PrimExpr AlignPrimExpr(const PrimExpr &value, int alignment) const {
    if (alignment <= 1) {
      return value;
    }
    DataType dtype = value.dtype();
    ICHECK(dtype.is_int() || dtype.is_uint())
        << "Expected integer dtype for alignment, but got " << dtype;
    PrimExpr align_expr = make_const(dtype, alignment);
    PrimExpr adjust = make_const(dtype, alignment - 1);
    return indexdiv(value + adjust, align_expr) * align_expr;
  }

  // Event entry in liveness analysis
  struct EventEntry {
    // variables we generate
...
...
@@ -905,173 +1069,228 @@ private:
  void PlanMemory(const std::vector<StmtEntry> &seq,
                  const std::unordered_map<const Object *, StmtAttr> &stmt_attrs) {
    std::unordered_set<const VarNode *> inplace_flag;
    buffer_byte_offsets_.clear();
    (void)stmt_attrs;
    if (shmem_allocs_.empty()) {
      merged_alloc_size_ = make_const(DataType::Int(64), 0);
      return;
    }
    // Discover the first and last touch for every allocation.
    std::unordered_map<const VarNode *, int> start_index;
    std::unordered_map<const VarNode *, int> end_index;
    for (size_t i = 0; i < seq.size(); ++i) {
      auto it = event_map_.find(seq[i].stmt);
      // scope_pair_offset <= 0 means it is either
      // - leaf stmt(offset = 0)
      // - end of scope(offset < 0)
      // In both cases, we need to handle the kill event correctly
      auto is_leaf_alloc = [&](const VarNode *var) {
        return seq[i].scope_pair_offset == 0 &&
               std::find(it->second.gen.begin(), it->second.gen.end(), var) !=
                   it->second.gen.end();
      };
      if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
        if (it == event_map_.end())
          continue;
        for (const VarNode *var : it->second.gen) {
          start_index.emplace(var, static_cast<int>(i));
        }
        for (const VarNode *var : it->second.kill) {
          if (!is_leaf_alloc(var))
            this->Free(var);
          end_index[var] = std::max(end_index[var], static_cast<int>(i) + 1);
        }
      }
      // scope_pair_offset >= 0 means it is either
      // - leaf stmt(offset = 0)
      // - beginning of scope(offset < 0)
      // In both cases, we need to handle the gen event correctly
      if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
        for (const VarNode *var : it->second.gen) {
          ICHECK(shmem_allocs_.count(var));
          const AllocateNode *alloc = shmem_allocs_[var];
          StorageEntry *dst_entry = FindAlloc(alloc);
          alloc_map_[var] = dst_entry;
    const int seq_len = static_cast<int>(seq.size());
    for (const auto &kv : start_index) {
      if (!end_index.count(kv.first)) {
        end_index[kv.first] = seq_len;
      }
    }
      if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
        for (const VarNode *var : it->second.kill) {
          if (is_leaf_alloc(var))
            this->Free(var);
    std::vector<BufInfo> buf_infos;
    buf_infos.reserve(shmem_allocs_.size());
    // Build a BufInfo for all allocations that participate in liveness.
    for (const auto &kv : shmem_allocs_) {
      const VarNode *var = kv.first;
      auto start_it = start_index.find(var);
      if (start_it == start_index.end()) {
        continue;
      }
      BufInfo info;
      info.var = var;
      info.name = var->name_hint;
      info.start = start_it->second;
      info.end = std::max(end_index[var], info.start + 1);
      info.alignment = align_bytes_;
      auto align_it = shmem_alignment_map_.find(var);
      if (align_it != shmem_alignment_map_.end()) {
        info.alignment = std::max(info.alignment, align_it->second);
      }
      const AllocateNode *alloc = kv.second;
      int64_t bytes_per_elem =
          static_cast<int64_t>(alloc->dtype.bytes() * alloc->dtype.lanes());
      DataType size_dtype = DataType::Int(32);
      if (!alloc->extents.empty()) {
        size_dtype = alloc->extents[0].dtype();
      }
      if (!size_dtype.is_int() && !size_dtype.is_uint()) {
        size_dtype = DataType::Int(32);
      }
  /*!
   * \brief Allocate new storage entry.
   * \param op the allocate node
   * \param the size of the allocation in bits
   * \return the new storage entry
   */
  StorageEntry *NewAlloc(const AllocateNode *op, size_t const_nbits) {
    ICHECK(op != nullptr);
    // Reuse not successful, allocate a new buffer.
    StorageEntry *entry = arena_.make<StorageEntry>();
    entry->allocs.push_back({op->buffer_var.get()});
    entry->const_nbits = const_nbits;
    return entry;
      PrimExpr size_expr = make_const(size_dtype, bytes_per_elem);
      for (const PrimExpr &extent : alloc->extents) {
        PrimExpr e = extent;
        if (e.dtype() != size_dtype) {
          e = cast(size_dtype, e);
        }
  /*!
   * @brief Locate or create a storage entry from free lists to satisfy an
   * AllocateNode.
   *
   * Finds a reusable StorageEntry for the given AllocateNode (constant or
   * symbolic size) using two-tiered strategies:
   * - For constant-size allocations (>0): prefer a free entry that is >=
   *   required size; if none, coalesce smaller free constant-size entries until
   *   the sum meets the request and return a new StorageEntry representing the
   *   merged space. Very small constant allocations (<= 32 bits) are not reused
   *   and will allocate a fresh entry.
   * - For symbolic-size (unknown at compile time): pick and remove an arbitrary
   *   entry from the symbolic free list.
   *
   * If no suitable free entry is found, a fresh StorageEntry is created via
   * NewAlloc.
   *
   * @param op Pointer to the AllocateNode to satisfy. Must be non-null.
   * @return StorageEntry* A storage entry that will hold the allocation (may be
   * newly created).
   */
  StorageEntry *FindAlloc(const AllocateNode *op) {
    ICHECK(op != nullptr);
    // skip plan for local variable,
    // compiler can do a better job with register allocation.
    const uint64_t match_range = 16;
    uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
    uint64_t const_nbits =
        static_cast<uint64_t>(op->ConstantAllocationSize() * op_elem_bits);
    // disable reuse of small arrays, they will be lowered to registers in LLVM
    // This rules only apply if we are using non special memory
    if (const_nbits > 0 && const_nbits <= 32) {
      return NewAlloc(op, const_nbits);
    }
    if (const_nbits != 0) {
      // constant allocation.
      auto begin = const_free_map_.lower_bound(0);
      auto mid = const_free_map_.lower_bound(const_nbits);
      auto end = const_free_map_.upper_bound(const_nbits * match_range);
      // Start looking at the buffer that is bigger than the required size
      // first. If we find one, directly allocate the buffer in its location and
      // remove its entry in the free list
      for (auto it = mid; it != end; ++it) {
        StorageEntry *e = it->second;
        e->const_nbits = std::max(const_nbits, e->const_nbits);
        const_free_map_.erase(it);
        it->second->allocs.push_back({op->buffer_var.get()});
        return e;
      }
      // Then start looking at smaller buffers.
      // Keep collecting the buffer until the sum of their size exceeds the
      // buffer to allocate and finally free all these entry in the free list
      std::vector<std::multimap<uint64_t, StorageEntry *>::iterator> delete_it;
      // the alloc list for the new entry
      std::vector<std::vector<const VarNode *>> reuse_allocs;
      uint64_t mem_ct = 0;
      for (auto it = mid; it != begin;) {
        --it;
        delete_it.push_back(it);
        mem_ct += it->second->const_nbits;
        int n = it->second->allocs.size();
        if (n > static_cast<int>(reuse_allocs.size())) {
          reuse_allocs.resize(n, {});
        }
        for (int i = 0; i < n; i++) {
          for (const VarNode *alloc : it->second->allocs[i]) {
            reuse_allocs[i].push_back(alloc);
          }
        }
        if (mem_ct >= const_nbits) {
          break;
        size_expr = size_expr * e;
        }
      info.size_dtype = size_dtype;
      info.size_expr = size_expr;
      int64_t const_extent = alloc->ConstantAllocationSize();
      if (const_extent >= 0) {
        info.const_size_bytes = const_extent * bytes_per_elem;
      }
      reuse_allocs.push_back({op->buffer_var.get()});
      if (mem_ct != 0) {
        StorageEntry *e = arena_.make<StorageEntry>();
        e->const_nbits = std::max(const_nbits, mem_ct);
        e->allocs = reuse_allocs;
        for (auto it : delete_it) {
          const_free_map_.erase(it);
      buf_infos.push_back(std::move(info));
        }
        return e;
    // Stable order so the later passes have deterministic behaviour.
    std::sort(buf_infos.begin(), buf_infos.end(),
              [](const BufInfo &a, const BufInfo &b) {
                if (a.start != b.start)
                  return a.start < b.start;
                if (a.end != b.end)
                  return a.end < b.end;
                return a.name < b.name;
              });
    std::vector<Interval> intervals;
    intervals.reserve(buf_infos.size());
    for (const BufInfo &info : buf_infos) {
      if (!info.const_size_bytes.has_value())
        continue;
      // Only constant-sized buffers participate in the arena packing because
      // dynamic sizes must be placed sequentially later.
      Interval interval;
      interval.start = info.start;
      interval.end = info.end;
      interval.size_bytes = static_cast<size_t>(
          std::max<int64_t>(0, info.const_size_bytes.value()));
      interval.alignment = info.alignment;
      interval.var = info.var;
      intervals.push_back(interval);
    }
    }
    else {
      // if its symbolic allocation, just arbitrarily choose one entry to fit in
      // because we don't know its actual size
      for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it) {
        StorageEntry *e = *it;
        sym_free_list_.erase(it);
        return e;
    ArenaPlan plan = LinearScanPack(std::move(intervals));
    size_t arena_size_const = plan.arena_size;
    if (verbose_) {
      LOG(DEBUG) << "ArenaPlan (constant buffers): arena_size="
                 << arena_size_const;
      for (const auto &kv : plan.offsets) {
        const VarNode *var = kv.first;
        LOG(DEBUG) << "  " << var->name_hint << " -> offset=" << kv.second;
      }
    }
    return NewAlloc(op, const_nbits);
    // Cursor tracks the running byte offset within the merged arena.
    DataType offset_dtype = buf_infos.empty()
                                ? DataType::Int(32)
                                : buf_infos.front().size_dtype;
    PrimExpr total_size = make_const(offset_dtype, 0);
    PrimExpr cursor = AlignPrimExpr(
        make_const(offset_dtype, static_cast<int64_t>(arena_size_const)),
        align_bytes_);
    auto CastToOffset = [&](PrimExpr expr) -> PrimExpr {
      if (expr.dtype() == offset_dtype) {
        return expr;
      }
      return cast(offset_dtype, expr);
    };
  /*!
   * \brief add the storage entry to the buffer var into the free list.
   * \param var the buffer var
   */
  void Free(const VarNode *var) {
    auto it = alloc_map_.find(var);
    ICHECK(it != alloc_map_.end());
    StorageEntry *e = it->second;
    ICHECK_NE(e->allocs.size(), 0U);
    // normal free.
    if (e->const_nbits != 0) {
      const_free_map_.insert({e->const_nbits, e});
    for (const BufInfo &info : buf_infos) {
      PrimExpr offset_expr;
      auto it = plan.offsets.find(info.var);
      if (it != plan.offsets.end()) {
        offset_expr =
            make_const(offset_dtype, static_cast<int64_t>(it->second));
      } else {
        sym_free_list_.push_back(e);
        // Dynamic-sized buffers are appended after the constant arena.
        cursor = AlignPrimExpr(cursor, info.alignment);
        PrimExpr size_expr = CastToOffset(info.size_expr);
        offset_expr = cursor;
        cursor = offset_expr + size_expr;
      }
      buffer_byte_offsets_[info.var] = offset_expr;
      PrimExpr buf_end = offset_expr + CastToOffset(info.size_expr);
      total_size = max(total_size, buf_end);
    }
    merged_alloc_size_ = buf_infos.empty()
                             ? make_const(offset_dtype, 0)
                             : AlignPrimExpr(total_size, align_bytes_);
    bool overlap_detected = false;
    if (verbose_) {
      LOG(DEBUG) << "Memory Allocation Plan for "
                 << (is_dynamic_ ? "Dynamic" : "Static") << " Shared Memory:";
      LOG(DEBUG) << "  Total Merged Size (aligned): " << merged_alloc_size_;
      for (const BufInfo &info : buf_infos) {
        const PrimExpr &offset = buffer_byte_offsets_.at(info.var);
        LOG(DEBUG) << "  Buffer: " << info.name << " start=" << info.start
                   << " end=" << info.end << " alignment=" << info.alignment
                   << " offset=" << offset << " size=" << info.size_expr;
      }
      // Sanity check for overlapping constant buffers.
      for (size_t i = 0; i < buf_infos.size(); ++i) {
        const BufInfo &a = buf_infos[i];
        auto a_off_imm = buffer_byte_offsets_.at(a.var).as<IntImmNode>();
        if (!a.const_size_bytes.has_value() || a_off_imm == nullptr)
          continue;
        int64_t a_off = a_off_imm->value;
        int64_t a_end = a_off + a.const_size_bytes.value();
        for (size_t j = i + 1; j < buf_infos.size(); ++j) {
          const BufInfo &b = buf_infos[j];
          auto b_off_imm = buffer_byte_offsets_.at(b.var).as<IntImmNode>();
          if (!b.const_size_bytes.has_value() || b_off_imm == nullptr)
            continue;
          bool live_overlap = !(a.end <= b.start || b.end <= a.start);
          if (!live_overlap)
            continue;
          int64_t b_off = b_off_imm->value;
          int64_t b_end = b_off + b.const_size_bytes.value();
          bool mem_overlap = !(a_end <= b_off || b_end <= a_off);
          if (mem_overlap) {
            overlap_detected = true;
            LOG(WARNING) << "Buffer overlap detected between " << a.name
                         << " and " << b.name << " (lifetime overlap with "
                         << "offset ranges [" << a_off << ", " << a_end
                         << ") and [" << b_off << ", " << b_end << ")).";
          }
        }
      }
    }
    if (overlap_detected) {
      LOG(WARNING) << "Detected overlapping constant buffers; falling back to "
                   << "sequential allocation without reuse.";
      buffer_byte_offsets_.clear();
      // In the fallback path we simply lay buffers out sequentially.
      PrimExpr new_cursor = make_const(offset_dtype, 0);
      PrimExpr new_total = make_const(offset_dtype, 0);
      for (const BufInfo &info : buf_infos) {
        new_cursor = AlignPrimExpr(new_cursor, info.alignment);
        PrimExpr size_expr = CastToOffset(info.size_expr);
        buffer_byte_offsets_[info.var] = new_cursor;
        PrimExpr buf_end = new_cursor + size_expr;
        new_total = max(new_total, buf_end);
        new_cursor = buf_end;
      }
      merged_alloc_size_ = buf_infos.empty()
                               ? make_const(offset_dtype, 0)
                               : AlignPrimExpr(new_total, align_bytes_);
    }
  }
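The sanity check above treats both lifetimes and byte ranges as half-open intervals: [a, b) and [c, d) are disjoint exactly when one ends no later than the other starts. A tiny standalone sketch of that predicate (the Overlaps name is mine):

#include <cassert>
#include <cstdint>

// Half-open ranges [s1, e1) and [s2, e2) overlap iff neither ends before the other starts.
static bool Overlaps(std::int64_t s1, std::int64_t e1, std::int64_t s2, std::int64_t e2) {
  return !(e1 <= s2 || e2 <= s1);
}

int main() {
  assert(!Overlaps(0, 4096, 4096, 8192));  // byte ranges that only touch do not collide
  assert(Overlaps(0, 4096, 2048, 6144));   // genuine byte-range collision
  assert(!Overlaps(0, 2, 2, 4));           // lifetimes [0,2) and [2,4) may share an offset
  return 0;
}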
  // Whether enable dynamic analysis.
  bool is_dynamic_{true};
...
...
@@ -1095,14 +1314,6 @@ private:
  bool allocated_{false};
  // Locations of free ops.
  std::unordered_map<const Object *, EventEntry> event_map_;
  // constant size free map.
  std::multimap<uint64_t, StorageEntry *> const_free_map_;
  // symbolic free list, for non constant items.
  std::list<StorageEntry *> sym_free_list_;
  // The allocation assign map
  std::unordered_map<const VarNode *, StorageEntry *> alloc_map_;
  /*! \brief allocator of all the StorageEntry*/
  support::Arena arena_;
  // The mapping of buffer bytes alignment
  std::unordered_map<const VarNode *, int> shmem_alignment_map_;
};
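const_free_map_ above is keyed by constant size in bits, which is what lets FindAlloc probe it with lower_bound(const_nbits) and upper_bound(const_nbits * match_range) to reuse a freed entry that is large enough but at most 16x the request. A self-contained sketch of that bounded lookup (plain int payload instead of StorageEntry*; the sizes are made up):

#include <cstdint>
#include <cstdio>
#include <map>

int main() {
  // size-in-bits -> entry id, mirroring const_free_map_.
  std::multimap<std::uint64_t, int> free_map = {{256, 1}, {4096, 2}, {65536, 3}};
  const std::uint64_t const_nbits = 2048;
  const std::uint64_t match_range = 16;  // same bound FindAlloc uses
  auto mid = free_map.lower_bound(const_nbits);                // first entry >= request
  auto end = free_map.upper_bound(const_nbits * match_range);  // skip entries > 16x request
  if (mid != end) {
    std::printf("reuse entry %d of %llu bits\n", mid->second,
                static_cast<unsigned long long>(mid->first));
    free_map.erase(mid);
  } else {
    std::printf("nothing in [request, 16 * request]; allocate a fresh entry\n");
  }
  return 0;
}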
...
...
@@ -1150,11 +1361,11 @@ Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false,
                             {});
}

TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.MergeSharedMemoryAllocations",
                        MergeSharedMemoryAllocations);
});
}

} // namespace transform
} // namespace tl
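This commit repeatedly switches the registration form from TVM_FFI_STATIC_INIT_BLOCK({ ... }); to TVM_FFI_STATIC_INIT_BLOCK() { ... }. As a rough mental model only (the real macro is defined in tvm-ffi and its expansion is not shown in this diff; every name below is invented for illustration), such init-block helpers typically reduce to code that runs once at library load time, for example a static object whose constructor performs the registration:

// Generic illustration of a load-time registration block; names are made up
// and do not reflect the actual TVM_FFI_STATIC_INIT_BLOCK expansion.
#include <functional>
#include <string>
#include <unordered_map>

using Registration = std::function<void()>;

static std::unordered_map<std::string, Registration> &Registry() {
  static std::unordered_map<std::string, Registration> reg;
  return reg;
}

namespace {
struct StaticInitBlock {
  explicit StaticInitBlock(void (*fn)()) { fn(); }
};
// The constructor runs at load time, before main(), registering the pass
// under its global name.
StaticInitBlock merge_shared_memory_init([] {
  Registry().emplace("tl.transform.MergeSharedMemoryAllocations", [] {});
});
}  // namespace

int main() {
  return Registry().count("tl.transform.MergeSharedMemoryAllocations") ? 0 : 1;
}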
...
...
src/transform/multi_version_buffer_rewriter.cc
View file @
bbbf4207
...
...
@@ -57,7 +57,7 @@ public:
    // Check reads from global
    Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
                /*name_hint=*/"",
                /*body*/ GetRef<Stmt>(op));
                /*body*/ tvm::ffi::GetRef<Stmt>(op));
    auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
    auto reads = access[0];
    Role role = Role::kProducer;
...
...
@@ -253,7 +253,8 @@ private:
  }

  static Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
    ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
    ObjectPtr<BufferNode> new_buffer =
        tvm::ffi::make_object<BufferNode>(*(buffer.get()));
    new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
    if (!new_buffer->strides.empty()) {
      ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
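RewriteAllocBuffer prepends a num_versions dimension, so a buffer shaped (m, n) becomes (num_versions, m, n) and each pipeline stage can address its own version slice. A minimal standalone sketch of the same shape manipulation (plain integers instead of PrimExpr; purely illustrative):

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> shape = {128, 64};  // original buffer shape (m, n)
  const int num_versions = 2;          // e.g. two in-flight versions
  shape.insert(shape.begin(), num_versions);
  // shape is now {2, 128, 64}; version k lives in slice [k, :, :].
  for (int extent : shape) std::printf("%d ", extent);
  std::printf("\n");
  return 0;
}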
...
...
@@ -493,10 +494,10 @@ tvm::transform::Pass MultiVersionBuffer() {
  return CreatePrimFuncPass(pass_func, 0, "tl.MultiVersionBuffer", {});
}

TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.MultiVersionBuffer", MultiVersionBuffer);
});
}

} // namespace tl
} // namespace tvm
src/transform/persist_threadblock.cc
View file @
bbbf4207
...
...
@@ -59,10 +59,10 @@ tvm::transform::Pass PersistThreadblock() {
  return CreatePrimFuncPass(pass_func, 0, "tl.PersistThreadblock", {});
}

TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.PersistThreadblock", PersistThreadblock);
});
}

} // namespace tl
} // namespace tvm
src/transform/pipeline_planning.cc
View file @
bbbf4207
...
...
@@ -103,7 +103,7 @@ private:
    ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
    auto var = call->args[1].as<VarNode>();
    ICHECK(var);
    auto it = buffer_data_to_buffer_.find(GetRef<Var>(var));
    auto it = buffer_data_to_buffer_.find(tvm::ffi::GetRef<Var>(var));
    ICHECK(it != buffer_data_to_buffer_.end());
    return (*it).second;
  };
...
...
@@ -210,7 +210,7 @@ private:
    if (const auto *load = op->args[0].as<BufferLoadNode>()) {
      buffer_region = BufferRegion::FullRegion(load->buffer);
    } else if (const auto *var_node = op->args[0].as<VarNode>()) {
      Var data_var = GetRef<Var>(var_node);
      Var data_var = tvm::ffi::GetRef<Var>(var_node);
      auto it = buffer_data_to_buffer_.find(data_var);
      if (it != buffer_data_to_buffer_.end()) {
        buffer_region = BufferRegion::FullRegion((*it).second);
...
...
@@ -223,7 +223,7 @@ private:
    } else if (op->op.same_as(builtin::tvm_access_ptr())) {
      const VarNode *buffer_var = op->args[1].as<VarNode>();
      ICHECK(buffer_var);
      auto it = buffer_data_to_buffer_.find(GetRef<Var>(buffer_var));
      auto it = buffer_data_to_buffer_.find(tvm::ffi::GetRef<Var>(buffer_var));
      if (it != buffer_data_to_buffer_.end()) {
        const Buffer &buffer = (*it).second;
        const BufferRegion buffer_region = BufferRegion::FullRegion(buffer);
...
...
@@ -402,7 +402,7 @@ private:
    if (TargetHasAsyncCopy(target_) && use_async_copy_)
      annotations.Set(tir::attr::software_pipeline_async_stages,
                      Array<Integer>{0});
    auto for_node = GetRef<For>(loop);
    auto for_node = tvm::ffi::GetRef<For>(loop);
    for_node.CopyOnWrite()->annotations = annotations;
    return for_node;
  }
...
...
@@ -728,10 +728,10 @@ tvm::transform::Pass PipelinePlanning() {
  return CreatePrimFuncPass(pass_func, 0, "tl.PipelinePlanning", {});
}

TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.PipelinePlanning", PipelinePlanning);
});
}

} // namespace tl
} // namespace tvm
src/transform/simplify.cc
View file @
bbbf4207
...
...
@@ -23,6 +23,7 @@ namespace tvm {
namespace tl {

using namespace tir;
using namespace ffi;
using namespace arith;

struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
...
...
@@ -62,8 +63,8 @@ struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
"branch"
        , refl::DefaultValue(false));
  }
  static constexpr const char *_type_key = "tl.transform.SimplifyConfig";

  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SimplifyConfigNode, BaseAttrsNode);
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.transform.SimplifyConfig",
                                    SimplifyConfigNode, BaseAttrsNode);

  RewriteSimplifier::Extension GetEnabledExtensions() const {
    RewriteSimplifier::Extension flags = RewriteSimplifier::kNone;
...
...
@@ -209,12 +210,11 @@ CollectVarsUsedInBufferDefinition(const Stmt &stmt) {
class SimplifyConfig : public Attrs {
public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs,
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SimplifyConfig, Attrs,
                                                SimplifyConfigNode);
};

TVM_FFI_STATIC_INIT_BLOCK({ SimplifyConfigNode::RegisterReflection(); });
TVM_FFI_STATIC_INIT_BLOCK() { SimplifyConfigNode::RegisterReflection(); }

TVM_REGISTER_NODE_TYPE(SimplifyConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);

class StmtSimplifier : public IRMutatorWithAnalyzer {
...
...
@@ -391,7 +391,7 @@ private:
    if (can_inline && !used_in_buffer_def) {
      return body;
    } else if (value.same_as(op->value) && body.same_as(op->body)) {
      return GetRef<Stmt>(op);
      return tvm::ffi::GetRef<Stmt>(op);
    } else {
      auto n = this->CopyOnWrite(op);
      n->value = std::move(value);
...
...
@@ -522,10 +522,10 @@ tvm::transform::Pass Simplify(bool simplify_arguments = true) {
  return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {});
}

TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.Simplify", Simplify);
});
}

} // namespace tl
} // namespace tvm