Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
667632cc
Unverified
Commit
667632cc
authored
Dec 22, 2025
by
guchaoyang
Committed by
GitHub
Dec 22, 2025
Browse files
Merge branch 'main' into dcu
parents
d6dd2ddf
a874e4e8
Changes
343
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
202 additions
and
19 deletions
+202
-19
src/transform/common/assume.h
src/transform/common/assume.h
+28
-0
src/transform/hoist_nonrestrict_params.cc
src/transform/hoist_nonrestrict_params.cc
+133
-0
src/transform/inject_assumes.cc
src/transform/inject_assumes.cc
+41
-19
No files found.
Too many changes to show.
To preserve performance only
343 of 343+
files are displayed.
Plain diff
Email patch
src/transform/common/assume.h
0 → 100644
View file @
667632cc
/*!
* \file assume.h
* \brief Utils on assume statements
*/
#ifndef TVM_TL_TRANSFORM_COMMON_ASSUME_H_
#define TVM_TL_TRANSFORM_COMMON_ASSUME_H_
#include "tvm/tir/stmt.h"
#include <optional>
namespace
tvm
{
namespace
tl
{
using
namespace
tir
;
// Get the expression inside an assume statement, if any. Returns nullopt if
// the statement is not an assume statement.
std
::
optional
<
PrimExpr
>
GetAssumeExprInEvaluateForm
(
Stmt
stmt
);
// Check if a statement is an assume statement.
bool
IsAssumeInEvaluateForm
(
const
Stmt
&
stmt
);
}
// namespace tl
}
// namespace tvm
#endif // TVM_TL_TRANSFORM_COMMON_ASSUME_H_
\ No newline at end of file
src/transform/hoist_nonrestrict_params.cc
0 → 100644
View file @
667632cc
/*
* Hoist tl.non_restrict_params block annotation(s) to PrimFunc attribute.
*
* Previously, we only looked at the root block. This version recursively
* scans all blocks, unions any tl.non_restrict_params entries it finds,
* merges with any existing PrimFunc-level attribute, then writes the
* deduplicated result back to the PrimFunc attrs. This makes annotation
* placement within the function body flexible for frontends.
*/
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
namespace
tvm
{
namespace
tl
{
using
namespace
tvm
::
tir
;
class
NonRestrictCollector
:
public
StmtVisitor
{
public:
void
Collect
(
const
Stmt
&
stmt
)
{
VisitStmt
(
stmt
);
}
Array
<
Var
>
Result
()
const
{
Array
<
Var
>
out
;
out
.
reserve
(
collected_
.
size
());
for
(
const
Var
&
v
:
collected_
)
out
.
push_back
(
v
);
return
out
;
}
private:
static
std
::
string
NormalizeName
(
const
std
::
string
&
s
)
{
if
(
s
.
size
()
>=
8
&&
s
.
rfind
(
"_handle"
)
==
s
.
size
()
-
7
)
{
return
s
.
substr
(
0
,
s
.
size
()
-
7
);
}
return
s
;
}
void
MaybeInsert
(
const
Var
&
v
)
{
if
(
!
v
.
defined
())
return
;
const
VarNode
*
p
=
v
.
get
();
if
(
seen_ptr_
.
count
(
p
))
return
;
// Also dedup by normalized name to be robust w.r.t recreated Vars
std
::
string
norm
=
NormalizeName
(
v
->
name_hint
);
if
(
seen_name_
.
count
(
norm
))
return
;
seen_ptr_
.
insert
(
p
);
seen_name_
.
insert
(
std
::
move
(
norm
));
collected_
.
push_back
(
v
);
}
void
VisitStmt_
(
const
BlockNode
*
op
)
final
{
auto
it
=
op
->
annotations
.
find
(
attr
::
kNonRestrictParams
);
if
(
it
!=
op
->
annotations
.
end
())
{
if
(
const
auto
*
arr
=
(
*
it
).
second
.
as
<
ffi
::
ArrayObj
>
())
{
// Downcast directly to Array<Var> for convenience
Array
<
Var
>
vars
=
tvm
::
Downcast
<
Array
<
Var
>>
((
*
it
).
second
);
for
(
const
Var
&
v
:
vars
)
{
MaybeInsert
(
v
);
}
}
}
// Recurse into child statements
StmtVisitor
::
VisitStmt_
(
op
);
}
std
::
vector
<
Var
>
collected_
;
std
::
unordered_set
<
const
VarNode
*>
seen_ptr_
;
std
::
unordered_set
<
std
::
string
>
seen_name_
;
};
static
PrimFunc
HoistNonRestrictParams
(
PrimFunc
f
)
{
if
(
!
f
.
defined
())
return
f
;
NonRestrictCollector
collector
;
collector
.
Collect
(
f
->
body
);
Array
<
Var
>
from_blocks
=
collector
.
Result
();
// Merge with any existing PrimFunc-level attribute if present
if
(
auto
opt_existing
=
f
->
GetAttr
<
Array
<
Var
>>
(
attr
::
kNonRestrictParams
))
{
for
(
const
Var
&
v
:
opt_existing
.
value
())
{
// Reuse the collector's dedup logic by temporarily constructing a new
// collector Alternatively, do a small inline dedup mirroring MaybeInsert
// Here we inline a simplified pointer-based dedup plus name-based
// fallback
bool
exists
=
false
;
for
(
const
Var
&
cur
:
from_blocks
)
{
if
(
cur
.
get
()
==
v
.
get
()
||
cur
->
name_hint
==
v
->
name_hint
)
{
exists
=
true
;
break
;
}
}
if
(
!
exists
)
from_blocks
.
push_back
(
v
);
}
}
if
(
from_blocks
.
empty
())
return
f
;
return
WithAttr
(
std
::
move
(
f
),
attr
::
kNonRestrictParams
,
std
::
move
(
from_blocks
));
}
namespace
transform
{
tvm
::
transform
::
Pass
HoistNonRestrictParams
()
{
auto
pass_func
=
[](
PrimFunc
f
,
const
IRModule
&
,
const
tvm
::
transform
::
PassContext
&
)
{
return
tvm
::
tl
::
HoistNonRestrictParams
(
std
::
move
(
f
));
};
return
tvm
::
tir
::
transform
::
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.HoistNonRestrictParams"
,
{});
}
}
// namespace transform
}
// namespace tl
}
// namespace tvm
TVM_FFI_STATIC_INIT_BLOCK
()
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.HoistNonRestrictParams"
,
tvm
::
tl
::
transform
::
HoistNonRestrictParams
);
}
src/transform/inject_assumes.cc
View file @
667632cc
/*!
* \file inject_assumes.cc
* \brief Inject assumes on buffer's shape boundary check. Also convert
* existing assumes to AttrNodes.
*/
#include "common/assume.h"
#include "tvm/arith/analyzer.h"
#include "tvm/ffi/optional.h"
#include "tvm/ir/expr.h"
...
...
@@ -6,9 +12,11 @@
#include "tvm/node/structural_hash.h"
#include "tvm/tir/builtin.h"
#include "tvm/tir/expr.h"
#include "tvm/tir/op.h"
#include "tvm/tir/stmt.h"
#include "tvm/tir/stmt_functor.h"
#include "tvm/tir/transform.h"
#include <sstream>
namespace
tvm
::
tl
{
...
...
@@ -26,11 +34,12 @@ public:
}
private:
struct
Ass
ert
Creator
{
struct
Ass
ume
Creator
{
struct
Item
{
PrimExpr
expr
;
std
::
vector
<
Buffer
>
buffers
;
};
tvm
::
StructuralHash
sh
;
tvm
::
StructuralEqual
se
;
// grouped by expr, since the amount of variadic shape symbols is usually
...
...
@@ -52,6 +61,7 @@ private:
items
[
*
it
].
buffers
.
push_back
(
buffer
);
}
}
void
addBuffer
(
Buffer
buf
)
{
for
(
auto
shape
:
buf
->
shape
)
{
if
(
shape
->
IsInstance
<
IntImmNode
>
())
...
...
@@ -59,10 +69,12 @@ private:
addExpr
(
shape
,
buf
);
}
}
Stmt
build
(
Stmt
body
)
{
auto
analyzer
=
arith
::
Analyzer
{};
for
(
const
auto
&
e
:
items
)
{
auto
simplified
=
analyzer
.
Simplify
(
GT
(
e
.
expr
,
0
));
auto
simplified
=
analyzer
.
Simplify
(
GT
(
e
.
expr
,
make_zero
(
e
.
expr
->
dtype
)));
std
::
stringstream
ss
;
ss
<<
"Buffer shape should be greater than 0: shape `"
<<
e
.
expr
<<
"` from buffer "
;
...
...
@@ -77,32 +89,37 @@ private:
return
body
;
}
};
Stmt
VisitStmt_
(
const
DeclBufferNode
*
op
)
final
{
auto
body
=
VisitStmt
(
op
->
body
);
Ass
ert
Creator
c
;
Ass
ume
Creator
c
;
c
.
addBuffer
(
op
->
buffer
);
return
DeclBuffer
(
op
->
buffer
,
c
.
build
(
body
),
op
->
span
);
}
std
::
optional
<
PrimExpr
>
getAssumeExpr
(
Stmt
stmt
)
{
auto
eval
=
stmt
.
as
<
EvaluateNode
>
();
if
(
!
eval
)
return
std
::
nullopt
;
auto
call
=
eval
->
value
.
as
<
CallNode
>
();
if
(
!
call
)
return
std
::
nullopt
;
if
(
!
call
->
op
.
same_as
(
builtin
::
assume
()))
return
std
::
nullopt
;
return
call
->
args
[
0
];
}
Stmt
VisitStmt_
(
const
SeqStmtNode
*
op
)
final
{
struct
AssumeGroup
{
std
::
optional
<
PrimExpr
>
e
;
std
::
vector
<
Stmt
>
stmts
;
};
std
::
vector
<
AssumeGroup
>
groups
=
{
AssumeGroup
{
std
::
nullopt
,
{}}};
for
(
auto
i
=
0
;
i
<
op
->
seq
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
op
->
seq
.
size
();
i
++
)
{
auto
stmt
=
VisitStmt
(
op
->
seq
[
i
]);
if
(
auto
e
=
getAssumeExpr
(
stmt
))
{
// Convert assume in evaluate form to assume attribute.
// By default, we have the following IR:
// T.assume(cond1)
// Stmt1
// Stmt2
// T.assume(cond2)
// This SeqStmt will be converted to:
// With(attr::tilelang_assume, cond1) {
// Stmt1
// Stmt2
// }
// With(attr::tilelang_assume, cond2) {
// ...
// }
if
(
auto
e
=
GetAssumeExprInEvaluateForm
(
stmt
))
{
groups
.
push_back
(
AssumeGroup
{
*
e
,
{}});
}
else
{
groups
.
back
().
stmts
.
push_back
(
stmt
);
...
...
@@ -125,10 +142,14 @@ private:
:
SeqStmt
(
groups
[
0
].
stmts
);
// return SeqStmt(groups[0].stmts);
}
Stmt
VisitStmt_
(
const
BlockNode
*
op
)
final
{
auto
body
=
VisitStmt
(
op
->
body
);
AssertCreator
c
;
if
(
root_node
)
{
AssumeCreator
c
;
// NOTE(chaofan): We only inject assumes from function arguments in the
// root block.
if
(
op
->
name_hint
==
"root"
)
{
for
(
auto
item
:
f
->
buffer_map
)
{
c
.
addBuffer
(
item
.
second
);
}
...
...
@@ -139,12 +160,13 @@ private:
for
(
auto
item
:
op
->
match_buffers
)
{
c
.
addBuffer
(
item
->
buffer
);
}
return
Block
(
op
->
iter_vars
,
op
->
reads
,
op
->
writes
,
op
->
name_hint
,
c
.
build
(
body
),
op
->
init
,
op
->
alloc_buffers
,
op
->
match_buffers
,
op
->
annotations
,
op
->
span
);
}
PrimFunc
f
;
bool
root_node
{
true
};
};
using
namespace
tir
::
transform
;
...
...
Prev
1
…
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment