OpenDAS / tilelang · Commits

Commit bbbf4207 (unverified)
Merge branch 'main' into dcu

Authored by guchaoyang on Nov 14, 2025; committed via GitHub on Nov 14, 2025.
Parents: 8f4628e0, 5eb30a4f

Changes: 286 files in this commit. This page shows 20 changed files, with 559 additions and 134 deletions (+559, −134):
  src/transform/split_host_device.cc                         +3   -3
  src/transform/storage_access.cc                            +42  -10
  src/transform/storage_access.h                             +7   -0
  src/transform/storage_rewrite.cc                           +34  -13
  src/transform/thread_storage_sync.cc                       +82  -14
  src/transform/vectorize_loop.cc                            +83  -47
  src/transform/warp_specialized_rewriter.cc                 +6   -6
  src/transform/wgmma_sync_rewriter.cc                       +2   -2
  testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py    +3   -9
  testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py   +21  -3
  testing/python/amd/test_tilelang_test_amd.py               +20  -23
  testing/python/dynamic/test_tilelang_dynamic_symbolic.py   +2   -1
  testing/python/issue/test_tilelang_issue_1008.py           +53  -0
  testing/python/issue/test_tilelang_issue_1115.py           +49  -0
  testing/python/issue/test_tilelang_issue_1198.py           +15  -0
  testing/python/issue/test_tilelang_issue_1210.py           +36  -0
  testing/python/issue/test_tilelang_issue_1237.py           +23  -0
  testing/python/jit/test_tilelang_jit_gemm_ctypes.py        +3   -2
  testing/python/jit/test_tilelang_jit_gemm_cython.py        +1   -1
  testing/python/jit/test_tilelang_jit_parcompile.py         +74  -0
src/transform/split_host_device.cc

@@ -37,7 +37,7 @@
 namespace tvm {
 namespace tl {

 using namespace ffi;
 namespace tir = tvm::tir;

 class HostDeviceSplitter : public tir::StmtMutator {

@@ -200,10 +200,10 @@ tvm::transform::Pass SplitHostDevice() {
                              {});
 }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.SplitHostDevice", SplitHostDevice);
-});
+}

 }  // namespace transform
 }  // namespace tl
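The TVM_FFI_STATIC_INIT_BLOCK({...}); → TVM_FFI_STATIC_INIT_BLOCK() {...} rewrite recurs in storage_rewrite.cc, thread_storage_sync.cc, vectorize_loop.cc, warp_specialized_rewriter.cc, and wgmma_sync_rewriter.cc below. Either spelling runs the block body during static initialization to register the pass under its global name; only the macro's surface syntax changes. A minimal sketch of the general static-registration pattern such macros typically expand to — the Registry type and STATIC_INIT_BLOCK macro here are illustrative stand-ins, not tilelang's actual definitions:

    #include <functional>
    #include <iostream>
    #include <map>
    #include <string>

    // Toy registry mapping pass names to callbacks.
    std::map<std::string, std::function<void()>> &Registry() {
      static std::map<std::string, std::function<void()>> inst;
      return inst;
    }

    // The macro declares a function plus a static token whose dynamic
    // initializer calls it, so the block body runs before main(). (A real
    // macro would generate unique names so it can appear more than once
    // per translation unit.)
    #define STATIC_INIT_BLOCK()                                   \
      static void StaticInitBody();                               \
      static const int static_init_token = (StaticInitBody(), 0); \
      static void StaticInitBody()

    STATIC_INIT_BLOCK() {
      Registry()["tl.transform.SplitHostDevice"] = [] {
        std::cout << "SplitHostDevice pass requested\n";
      };
    }

    int main() {
      Registry().at("tl.transform.SplitHostDevice")();  // invokes the callback
    }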
src/transform/storage_access.cc

@@ -29,6 +29,7 @@
 #include <string>
 #include <utility>

 #include "../op/builtin.h"
+#include "tir/transforms/ir_utils.h"

 namespace tvm {

@@ -38,10 +39,11 @@ using namespace tir;
 void TileLangStorageAccessVisitor::VisitExpr_(const BufferLoadNode *op) {
   Var buf = op->buffer->data;
-  buffer_data_to_buffer_.Set(GetRef<Var>(buf.get()), op->buffer);
+  buffer_data_to_buffer_.Set(tvm::ffi::GetRef<Var>(buf.get()), op->buffer);
   StorageScope scope = GetScope(buf);
   if (Enabled(buf.get(), scope)) {
-    ICHECK(allow_append_) << GetRef<BufferLoad>(op) << " " << scope.to_string();
+    ICHECK(allow_append_) << tvm::ffi::GetRef<BufferLoad>(op) << " "
+                          << scope.to_string();
     AccessEntry e;
     e.threads = env_threads();
     e.thread_range = this->ComputeThreadRange(e.threads);

@@ -65,7 +67,7 @@ void TileLangStorageAccessVisitor::VisitStmt_(const BufferStoreNode *op) {
   curr_stmt_.stmt = op;
   Var buf = op->buffer->data;
-  buffer_data_to_buffer_.Set(GetRef<Var>(buf.get()), op->buffer);
+  buffer_data_to_buffer_.Set(tvm::ffi::GetRef<Var>(buf.get()), op->buffer);
   StorageScope scope = GetScope(buf);
   if (Enabled(buf.get(), scope)) {
     AccessEntry e;

@@ -252,7 +254,11 @@ void TileLangStorageAccessVisitor::VisitStmt_(const IfThenElseNode *op) {
   this->VisitExpr(op->condition);
   PrimExpr real_condition = ExtractRealCondition(op->condition);
-  curr_stmt_.access.clear();
+  // Preserve accesses collected from the condition expression so they
+  // participate in dependency analysis. Otherwise, a write to shared memory
+  // immediately followed by an if-condition reading that memory would not
+  // trigger a sync before the if-statement.
+  std::vector<AccessEntry> cond_access = std::move(curr_stmt_.access);
   allow_append_ = false;
   scope_.push_back(std::vector<StmtEntry>());

@@ -265,6 +271,11 @@
   s.stmt = op;
   s.access = Summarize(std::move(scope_.back()), nullptr);
   scope_.pop_back();
+  // Merge the condition's access summary into the if-statement's access list
+  // so the planner can insert a sync before the if when necessary.
+  if (!cond_access.empty()) {
+    s.access.insert(s.access.begin(), cond_access.begin(), cond_access.end());
+  }
   if (op->else_case) {
     scope_.push_back(std::vector<StmtEntry>());
     {

@@ -301,14 +312,32 @@ void TileLangStorageAccessVisitor::VisitStmt_(const WhileNode *op) {
 }

 void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
+  // Mark async TMA load context so that tvm_access_ptr within the call
+  // can be tagged accordingly.
+  auto is_tma_load = [&]() {
+    if (auto opt = op->op.as<Op>()) {
+      const Op &call_op = opt.value();
+      return call_op.same_as(tl::tma_load()) ||
+             call_op.same_as(tl::tma_load_im2col());
+    }
+    return false;
+  }();
+  if (is_tma_load) {
+    tma_depth_++;
+    for (const auto &a : op->args) {
+      this->VisitExpr(a);
+    }
+    tma_depth_--;
+    return;
+  }
   if (op->op.same_as(builtin::address_of())) {
     ICHECK_EQ(op->args.size(), 1U);
     if (auto load = op->args[0].as<BufferLoadNode>()) {
       Buffer buffer = load->buffer;
       DataType dtype = buffer->dtype;
       const VarNode *buffer_var = buffer->data.as<VarNode>();
-      buffer_data_to_buffer_.Set(GetRef<Var>(buffer_var), buffer);
-      StorageScope scope = GetScope(GetRef<Var>(buffer_var));
+      buffer_data_to_buffer_.Set(tvm::ffi::GetRef<Var>(buffer_var), buffer);
+      StorageScope scope = GetScope(tvm::ffi::GetRef<Var>(buffer_var));
       Array<Range> buffer_ranges;
       // from indices to buffer indices
       ICHECK(buffer->shape.size() == load->indices.size());

@@ -346,17 +375,18 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
     PrimExpr offset = op->args[2];
     PrimExpr extent = op->args[3];
     const IntImmNode *flag = op->args[4].as<IntImmNode>();
-    StorageScope scope = GetScope(GetRef<Var>(buffer_var));
+    StorageScope scope = GetScope(tvm::ffi::GetRef<Var>(buffer_var));
     // The buffer scope.
     if (Enabled(buffer_var, scope)) {
       ICHECK(allow_append_);
       Array<Range> buffer_ranges;
-      if (buffer_data_to_buffer_.find(GetRef<Var>(buffer_var)) ==
+      if (buffer_data_to_buffer_.find(tvm::ffi::GetRef<Var>(buffer_var)) ==
           buffer_data_to_buffer_.end()) {
         // cannot find buffer map, use the default buffer
         buffer_ranges = {Range::FromMinExtent(offset, extent)};
       } else {
-        Buffer buffer = buffer_data_to_buffer_.at(GetRef<Var>(buffer_var));
+        Buffer buffer =
+            buffer_data_to_buffer_.at(tvm::ffi::GetRef<Var>(buffer_var));
         auto buffer_shape = buffer->shape;
         // convert 1d offset to multi-dimensional index
         auto linear_to_indices = [this](PrimExpr offset,

@@ -387,7 +417,7 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
       e.threads = env_threads();
       e.thread_range = this->ComputeThreadRange(e.threads);
       e.dtype = dtype;
-      e.buffer = GetRef<Var>(buffer_var);
+      e.buffer = tvm::ffi::GetRef<Var>(buffer_var);
       e.buffer_ranges = buffer_ranges;
       e.is_pointer_access = true;
       e.touched = {

@@ -395,10 +425,12 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
       e.scope = scope;
       if (flag->value & 1) {
         e.type = kRead;
+        e.is_async_copy = (tma_depth_ > 0);
         curr_stmt_.access.emplace_back(e);
       }
       if (flag->value & 2) {
         e.type = kWrite;
+        e.is_async_copy = (tma_depth_ > 0);
         curr_stmt_.access.emplace_back(e);
       }
     }
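Two behavioral fixes sit among the mechanical GetRef → tvm::ffi::GetRef changes above: accesses read by an if-condition now survive into the if-statement's access list instead of being cleared, and tvm_access_ptr accesses nested inside a tma_load / tma_load_im2col call are tagged is_async_copy via the tma_depth_ counter. A standalone toy model of the first fix, with simplified types standing in for the real TIR classes (all names here are illustrative):

    #include <iostream>
    #include <string>
    #include <vector>

    enum class AccessType { kRead, kWrite };

    struct AccessEntry {
      std::string buffer;
      AccessType type;
    };

    // A statement together with the accesses attributed to it.
    struct StmtEntry {
      std::string name;
      std::vector<AccessEntry> access;
    };

    // Read-after-write check: does `stmt` read a buffer with a pending write?
    bool NeedsSyncBefore(const std::vector<AccessEntry> &pending_writes,
                         const StmtEntry &stmt) {
      for (const AccessEntry &acc : stmt.access) {
        if (acc.type != AccessType::kRead) continue;
        for (const AccessEntry &w : pending_writes) {
          if (w.buffer == acc.buffer) return true;
        }
      }
      return false;
    }

    int main() {
      // shared[...] = ...;  followed by  if (shared[...] > 0) { ... }
      std::vector<AccessEntry> pending_writes = {{"shared", AccessType::kWrite}};

      // Old behavior: the condition's read was cleared, the if-statement
      // carried no accesses, and no conflict was detected.
      StmtEntry if_without_cond{"if", {}};
      std::cout << NeedsSyncBefore(pending_writes, if_without_cond) << "\n";  // 0

      // New behavior: the condition's read is merged into the if-statement's
      // access list, so the planner inserts a barrier before the if.
      StmtEntry if_with_cond{"if", {{"shared", AccessType::kRead}}};
      std::cout << NeedsSyncBefore(pending_writes, if_with_cond) << "\n";  // 1
    }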
src/transform/storage_access.h

@@ -39,6 +39,7 @@ namespace tvm {
 namespace tl {

 using namespace tir;
+using namespace ffi;
 using arith::IRVisitorWithAnalyzer;
 using runtime::StorageRank;
 using runtime::StorageScope;

@@ -83,6 +84,10 @@ public:
     bool double_buffer_write = false;
     /*! \brief Whether the access is pointer access */
     bool is_pointer_access = false;
+    /*! \brief Whether this access originates from an async copy context
+     *  (e.g., inside a TMA load) and therefore multiple writes
+     *  among themselves should not force barriers between them. */
+    bool is_async_copy = false;
   };
   /*! \brief Access pattern about a single statement */

@@ -159,6 +164,8 @@ private:
   bool allow_append_{false};
   // Whether we are in device environment
   bool in_device_env_{false};
+  // Nesting depth of tma_load/tma_load_im2col calls
+  int tma_depth_{0};
   // Whether we are inside condition.
   int condition_counter_{0};
   // The current double buffer write scope.
src/transform/storage_rewrite.cc

@@ -544,7 +544,7 @@ public:
     }
     return it->second->alloc_var;
   } else {
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
 }
 PrimExpr VisitExpr_(const CallNode *op) final {

@@ -679,7 +679,7 @@ private:
   return !scope.tag.empty() && scope.tag != ".dyn" && scope.tag != ".barrier" &&
          scope.tag != ".workspace" && scope.tag != ".vtcm" &&
-         scope.tag != ".var" && scope.tag != ".descriptor";
+         scope.tag != ".var" && scope.tag.find(".descriptor") != 0;
 }
 // Allocate entry of node.

@@ -865,7 +865,7 @@ private:
   ICHECK_NE(e->const_nbits, 0U);
   MemoryInfo info;
   if (e->scope.tag != ".barrier" && e->scope.tag != ".var" &&
-      e->scope.tag != ".descriptor") {
+      e->scope.tag.find(".descriptor") != 0) {
     info = GetMemoryInfo(e->scope.to_string());
   }
   uint64_t total_bits = e->const_nbits;

@@ -978,8 +978,8 @@ private:
   ICHECK(alloc_info.count(var));
   const AllocEntry &entry = alloc_info.at(var);
   const AllocateNode *alloc = entry.alloc;
-  auto storage_scope =
-      StorageScope::Create(GetPtrStorageScope(GetRef<Var>(var)));
+  auto storage_scope =
+      StorageScope::Create(GetPtrStorageScope(tvm::ffi::GetRef<Var>(var)));
   StorageEntry *dst_entry = nullptr;
   // inplace detection
   if (detect_inplace) {

@@ -1425,9 +1425,30 @@ public:
 void OnArrayDeclaration(const Var &buffer, DataType element_dtype,
                         PrimExpr extent,
                         BufferVarInfo::DeclarationLocation declaration_location) {
-  ICHECK(info_map_.find(buffer.get()) == info_map_.end())
-      << "Array declaration of " << buffer->name_hint
-      << " occurred multiple times.";
+  auto it = info_map_.find(buffer.get());
+  if (it != info_map_.end()) {
+    // The same buffer var may appear in more than one Allocate due to
+    // upstream transforms (e.g., storage planning/merging). Treat repeated
+    // declarations as benign and merge metadata instead of erroring.
+    BufferVarInfo &existing = it->second;
+    // Prefer a concrete element dtype if the previous one was a handle.
+    if (existing.element_dtype.is_handle() && !element_dtype.is_handle()) {
+      existing.element_dtype =
+          element_dtype == DataType::Bool()
+              ? DataType::Int(8).with_lanes(element_dtype.lanes())
+              : element_dtype;
+    }
+    // If extent was previously unknown (0) and a concrete extent is
+    // provided now, record it.
+    if (!existing.extent.defined() || is_zero(existing.extent)) {
+      existing.extent = extent;
+    }
+    // Merge declaration locations (bitwise OR of flags).
+    existing.declaration_location =
+        static_cast<BufferVarInfo::DeclarationLocation>(
+            existing.declaration_location | declaration_location);
+    return;
+  }

   if (element_dtype == DataType::Bool()) {
     element_dtype = DataType::Int(8).with_lanes(element_dtype.lanes());

@@ -1732,7 +1753,7 @@ public:
   Var var = (it == rewrite_map_.end()) ? op->var : it->second.new_buffer_var;
   if (var.same_as(op->var) && value.same_as(op->value) &&
       body.same_as(op->body)) {
-    return GetRef<Stmt>(op);
+    return tvm::ffi::GetRef<Stmt>(op);
   }
   return LetStmt(var, value, body);
 }

@@ -1985,10 +2006,10 @@ Pass StorageRewrite() {
   return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {});
 }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.StorageRewrite", StorageRewrite);
-});
+}

 Pass PointerValueTypeRewrite() {
   auto pass_func = [](PrimFunc f, const IRModule &m, const PassContext &ctx) {

@@ -1997,11 +2018,11 @@ Pass PointerValueTypeRewrite() {
   return CreatePrimFuncPass(pass_func, 0, "tl.PointerValueTypeRewrite", {});
 }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.PointerValueTypeRewrite",
                         PointerValueTypeRewrite);
-});
+}

 }  // namespace transform
 }  // namespace tl
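The OnArrayDeclaration change above replaces a hard ICHECK failure on repeated buffer-var declarations with a metadata merge. A minimal standalone model of that merge policy, with field types simplified and DeclarationLocation treated as plain bit flags (an assumption about the real enum; all names illustrative):

    #include <cassert>
    #include <map>
    #include <string>

    // Simplified stand-ins for the pass's bookkeeping types.
    enum DeclLocation : unsigned { kPrimFuncParam = 1, kAllocateNode = 2 };

    struct BufferVarInfo {
      std::string element_dtype;  // "handle" means "not yet known"
      long extent = 0;            // 0 means "unknown extent"
      unsigned declaration_location = 0;
    };

    std::map<std::string, BufferVarInfo> info_map;

    void OnArrayDeclaration(const std::string &buffer, const std::string &dtype,
                            long extent, unsigned location) {
      auto it = info_map.find(buffer);
      if (it != info_map.end()) {
        BufferVarInfo &existing = it->second;
        // Prefer a concrete element dtype if the previous one was a handle.
        if (existing.element_dtype == "handle" && dtype != "handle") {
          existing.element_dtype = dtype;
        }
        // Fill in the extent if it was previously unknown.
        if (existing.extent == 0) existing.extent = extent;
        // Merge declaration locations (bitwise OR of flags).
        existing.declaration_location |= location;
        return;  // repeated declaration is benign, not an error
      }
      info_map[buffer] = {dtype, extent, location};
    }

    int main() {
      OnArrayDeclaration("A", "handle", 0, kPrimFuncParam);
      OnArrayDeclaration("A", "float16", 256, kAllocateNode);  // merges, no abort
      assert(info_map["A"].element_dtype == "float16");
      assert(info_map["A"].extent == 256);
      assert(info_map["A"].declaration_location == (kPrimFuncParam | kAllocateNode));
    }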
src/transform/thread_storage_sync.cc

@@ -86,6 +86,7 @@ protected:
       // check if sync before statement is needed.
       bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
+      // Apply the syncs added already.
       if (sync_before_stmt) {
         reads.clear();
         writes.clear();

@@ -98,7 +99,8 @@ protected:
             break;
           }
         } else if (acc.type == kWrite) {
-          if (FindConflict(reads, acc, false)) {
+          if (FindConflict(reads, acc, false) ||
+              FindConflict(writes, acc, false)) {
             sync_before_stmt = true;
             break;
           }

@@ -123,27 +125,51 @@ protected:
           writes.clear();
         }
       }
       if (sync_before_stmt) {
         insert_syncs(s.stmt);
       }
     }
     if (loop != nullptr) {
+      // Check if the loop body contains any reads in the same sync scope.
+      // If there are reads, we conservatively keep the sync within the loop
+      // body to preserve per-iteration ordering when needed. If there are no
+      // reads (e.g., only writes to shared.dyn), we can safely hoist the sync
+      // to before the loop to avoid redundant barriers.
+      bool has_read_in_scope = false;
+      for (const StmtEntry &s : seq) {
+        for (const AccessEntry &acc : s.access) {
+          if (acc.type == kRead && acc.scope == sync_scope_) {
+            has_read_in_scope = true;
+            break;
+          }
+        }
+        if (has_read_in_scope) break;
+      }
+      // If there is a loop-carried dependency, insert a single sync
+      // before the loop rather than hoisting a sync into the loop body.
+      // This reduces redundant per-iteration synchronizations for cases
+      // where each iteration touches disjoint regions (e.g., stmatrix
+      // writes to shared.dyn) and only a global ordering before/after the
+      // loop is required.
       for (size_t i = 0; i < seq.size(); ++i) {
         const StmtEntry &s = seq[i];
         if (syncs_inserted_.count(s.stmt) != 0) break;
         if (reads.empty() && writes.empty()) break;
-        bool sync_before_stmt = false;
+        bool need_loop_sync = false;
         for (const AccessEntry &acc : s.access) {
           if (acc.type == kRead) {
             if (FindConflict(writes, acc, true)) {
-              sync_before_stmt = true;
+              need_loop_sync = true;
               break;
             }
           } else if (acc.type == kWrite) {
-            if (FindConflict(reads, acc, true)) {
-              sync_before_stmt = true;
+            if (FindConflict(reads, acc, true) ||
+                FindConflict(writes, acc, true)) {
+              need_loop_sync = true;
               break;
             }
           } else if (acc.type == kSync) {

@@ -151,8 +177,17 @@ protected:
             writes.clear();
           }
         }
-        if (sync_before_stmt) {
+        if (need_loop_sync) {
+          if (!has_read_in_scope) {
+            // Mark the loop itself to receive a sync before it, instead of
+            // inserting inside the loop body. This ensures a single sync is
+            // emitted outside the loop and avoids per-iteration overhead.
+            insert_syncs(loop);
+          } else {
+            // Fall back to inserting before the first conflicting statement
+            // inside the loop to maintain correctness when reads are present.
+            insert_syncs(s.stmt);
+          }
           break;
         }
       }

@@ -217,6 +252,14 @@ private:
   bool FindConflict(const AccessEntry &prev, const AccessEntry &curr,
                     bool loop_carry) {
+    // Special case: ignore conflicts between async-copy writes (e.g., TMA
+    // loads into shared memory). Multiple async writes do not require
+    // interspersed barriers among themselves. We still respect conflicts with
+    // reads to ensure visibility before consumption.
+    if (prev.type == kWrite && curr.type == kWrite && prev.is_async_copy &&
+        curr.is_async_copy) {
+      return false;
+    }
     // Access to different buffers does not conflict.
     if (!prev.buffer.same_as(curr.buffer)) {
       return false;

@@ -241,10 +284,15 @@ private:
       return true;
     }
     if (prev.is_pointer_access || curr.is_pointer_access) {
       // If either access is a pointer access, conservatively assume a
       // conflict. For example, address_of(A[0, 0]) may refer to an unknown
       // memory region, so we cannot safely determine if it overlaps with
       // previous accesses.
+      // For accesses created via tvm_access_ptr we may still be able to prove
+      // disjointness using their byte ranges. If both sides expose a touched
+      // interval and we can show they don't overlap, skip the conflict.
+      if (prev.is_pointer_access && curr.is_pointer_access &&
+          PointerAccessIsDisjoint(prev, curr)) {
+        return false;
+      }
+      // Otherwise fall back to the conservative answer: treat them as
+      // overlapping.
       return true;
     }

@@ -327,7 +375,7 @@ private:
       }
     }
-    if (!(has_same_index)) {
+    if (!has_same_index) {
       break;
     }
   }

@@ -350,6 +398,26 @@ private:
     return range_is_overlap;
   }

+  bool PointerAccessIsDisjoint(const AccessEntry &lhs, const AccessEntry &rhs) {
+    if (lhs.touched.size() != 1 || rhs.touched.size() != 1) {
+      return false;
+    }
+    PrimExpr lhs_min = analyzer_.Simplify(lhs.touched[0].min());
+    PrimExpr lhs_max = analyzer_.Simplify(lhs.touched[0].max());
+    PrimExpr rhs_min = analyzer_.Simplify(rhs.touched[0].min());
+    PrimExpr rhs_max = analyzer_.Simplify(rhs.touched[0].max());
+    if (analyzer_.CanProve(lhs_max < rhs_min,
+                           arith::ProofStrength::kSymbolicBound)) {
+      return true;
+    }
+    if (analyzer_.CanProve(rhs_max < lhs_min,
+                           arith::ProofStrength::kSymbolicBound)) {
+      return true;
+    }
+    return false;
+  }

   void VisitStmt_(const AttrStmtNode *op) final {
     if (op->attr_key == tvm::tir::attr::thread_extent) {
       IterVar iv = Downcast<IterVar>(op->node);

@@ -782,10 +850,10 @@ tvm::transform::Pass ThreadSync(const String &storage_scope) {
   return CreatePrimFuncPass(pass_func, 0, "tl.ThreadSync", {});
 }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.ThreadSync", ThreadSync);
-});
+}

 }  // namespace transform
 }  // namespace tl
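PointerAccessIsDisjoint lets two tvm_access_ptr accesses escape the conservative "pointer accesses always conflict" rule when their touched byte intervals provably do not overlap. A standalone sketch of the same interval test, with concrete integers in place of PrimExpr and the arithmetic analyzer (illustrative only):

    #include <iostream>
    #include <optional>

    // Closed interval [min, max] of touched bytes, when a single interval is known.
    struct Touched {
      long min;
      long max;
    };

    struct AccessEntry {
      bool is_pointer_access = false;
      std::optional<Touched> touched;  // empty when no single interval is known
    };

    // Mirrors the pass's logic: only claim disjointness when both sides expose
    // exactly one touched interval and one interval ends before the other starts.
    bool PointerAccessIsDisjoint(const AccessEntry &lhs, const AccessEntry &rhs) {
      if (!lhs.touched || !rhs.touched) return false;
      if (lhs.touched->max < rhs.touched->min) return true;
      if (rhs.touched->max < lhs.touched->min) return true;
      return false;  // overlap cannot be ruled out
    }

    bool FindConflict(const AccessEntry &prev, const AccessEntry &curr) {
      if (prev.is_pointer_access || curr.is_pointer_access) {
        if (prev.is_pointer_access && curr.is_pointer_access &&
            PointerAccessIsDisjoint(prev, curr)) {
          return false;  // provably disjoint byte ranges: no barrier needed
        }
        return true;  // conservative fallback: assume overlap
      }
      return true;  // (index-based analysis elided in this sketch)
    }

    int main() {
      AccessEntry a{true, Touched{0, 127}};    // e.g., first half of a shared tile
      AccessEntry b{true, Touched{128, 255}};  // second half: disjoint
      AccessEntry c{true, Touched{64, 191}};   // overlaps both
      std::cout << FindConflict(a, b) << "\n";  // 0: no sync required
      std::cout << FindConflict(a, c) << "\n";  // 1: conservative sync
    }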
src/transform/vectorize_loop.cc

@@ -33,6 +33,7 @@
 #include <tvm/tir/transform.h>

 #include <unordered_map>
 #include <unordered_set>
 #include <utility>
 #include <vector>

@@ -43,6 +44,7 @@ namespace tvm {
 namespace tl {

 using namespace tir;
+using namespace ffi;

 /*!
  * \brief Perform data type legalization on the given BufferLoadNode pointer.

@@ -208,6 +210,14 @@ public:
   using ExprFunctor::VisitExpr;
   using StmtMutator::operator();

+  // Convenience entry to vectorize a loop body without exposing
+  // the mutator invocation pattern at call sites.
+  static Stmt Vectorize(const Var &var, const PrimExpr &var_lanes, Stmt body) {
+    TLVectorizer vec{var, var_lanes};
+    auto vec_stmt = vec(std::move(body));
+    return vec_stmt;
+  }

   TLVectorizer(const Var &var, const PrimExpr &var_lanes)
       : var_(var), var_lanes_(var_lanes) {
     ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);

@@ -217,8 +227,9 @@ public:
     ICHECK(!need_scalarize_);
     Stmt ret = StmtMutator::VisitStmt(stmt);
     if (need_scalarize_) {
+      auto scalarized_stmt = Scalarize(stmt);
       need_scalarize_ = false;
-      return Scalarize(stmt);
+      return scalarized_stmt;
     } else {
       return ret;
     }

@@ -242,7 +253,7 @@ public:
     PrimExpr a = this->VisitExpr(op->a);
     PrimExpr b = this->VisitExpr(op->b);
     if (a.same_as(op->a) && b.same_as(op->b)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
       bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();

@@ -296,7 +307,7 @@ public:
   PrimExpr VisitExpr_(const NotNode *op) final {
     PrimExpr a = this->VisitExpr(op->a);
     if (a.same_as(op->a)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       return !(a);
     }

@@ -337,10 +348,10 @@ public:
     PrimExpr value = this->VisitExpr(op->value);
     if (value.dtype().is_scalable_or_fixed_length_vector()) {
       need_scalarize_ = true;
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     }
     if (value.same_as(op->value)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       return Broadcast(op->value, op->lanes);
     }

@@ -352,7 +363,7 @@ public:
     PrimExpr f = this->VisitExpr(op->false_value);
     if (cond.same_as(op->condition) && t.same_as(op->true_value) &&
         f.same_as(op->false_value)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
       int t_lanes = t.dtype().get_lanes_or_vscale_factor();

@@ -370,7 +381,7 @@ public:
   PrimExpr VisitExpr_(const CastNode *op) final {
     PrimExpr value = this->VisitExpr(op->value);
     if (value.same_as(op->value)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       if (value.dtype().is_scalable_vector()) {
         return Cast(op->dtype.with_scalable_vscale_factor(

@@ -383,26 +394,26 @@ public:
   }
   PrimExpr VisitExpr_(const FloatImmNode *op) final {
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
   PrimExpr VisitExpr_(const IntImmNode *op) final {
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
   PrimExpr VisitExpr_(const StringImmNode *op) final {
-    return GetRef<PrimExpr>(op);
+    return tvm::ffi::GetRef<PrimExpr>(op);
   }
   // Variable
   PrimExpr VisitExpr_(const VarNode *op) final {
-    Var var = GetRef<Var>(op);
+    Var var = tvm::ffi::GetRef<Var>(op);
     if (var.same_as(var_)) {
       return ramp_;
     }
-    auto it = let_binding_.find(var);
-    if (it != let_binding_.end()) {
+    auto it = let_var_map_.find(var);
+    if (it != let_var_map_.end()) {
       return it->second;
     } else {
       return std::move(var);

@@ -413,13 +424,13 @@ public:
     PrimExpr cond = this->VisitExpr(op->args[0]);
     if (cond.dtype().is_scalable_or_fixed_length_vector()) {
       need_scalarize_ = true;
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     }
     PrimExpr t = this->VisitExpr(op->args[1]);
     PrimExpr f = this->VisitExpr(op->args[2]);
     if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) &&
         f.same_as(op->args[2])) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       int t_lanes = t.dtype().get_lanes_or_vscale_factor();
       int f_lanes = f.dtype().get_lanes_or_vscale_factor();

@@ -441,7 +452,7 @@ public:
     ICHECK(op->op.same_as(builtin::reinterpret()));
     PrimExpr value = this->VisitExpr(op->args[0]);
     if (value.same_as(op->args[0])) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       int lanes = value.dtype().get_lanes_or_vscale_factor();
       if (value.dtype().is_scalable_vector()) {

@@ -478,7 +489,6 @@ public:
     bool vectorizable = optional_op &&
                         op_vectorizable_.get(optional_op.value(), false) &&
                         !op->dtype.is_scalable_vector();
     if (!vectorizable) {
       // Cannot vectorize this op
       Array<PrimExpr> new_args;

@@ -486,12 +496,12 @@ public:
       auto new_arg = this->VisitExpr(arg);
       if (new_arg.dtype().is_scalable_or_fixed_length_vector()) {
         need_scalarize_ = true;
-        return GetRef<PrimExpr>(op);
+        return tvm::ffi::GetRef<PrimExpr>(op);
       }
       new_args.push_back(new_arg);
     }
     if (op->args.same_as(new_args)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       return Call(op->dtype, op->op, new_args);
     }

@@ -500,7 +510,7 @@ public:
     Array<PrimExpr> new_args = MutateArray(op->args, &lane);
     // normal code path.
     if (op->args.same_as(new_args)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       return Call(op->dtype.with_lanes(lane), op->op, new_args);
     }

@@ -508,7 +518,7 @@ public:
   }
   // BufferLoad
   PrimExpr VisitExpr_(const BufferLoadNode *op) final {
-    auto load = GetRef<BufferLoad>(op);
+    auto load = tvm::ffi::GetRef<BufferLoad>(op);
     auto fmutate = [this](const PrimExpr &index) {
       return this->VisitExpr(index);

@@ -518,7 +528,6 @@ public:
     if (!indices.same_as(op->indices)) {
       BufferLoadNode *writer = load.CopyOnWrite();
       writer->indices = indices;
-      // writer->LegalizeDType();
       LegalizeBufferLoadDType(writer);
     }

@@ -533,21 +542,23 @@ public:
     // This is used to allow cases when we reuse a single let
     // expression to construct a nested expr.
    // (let x = 1 in x + 1) * (let x = 1 in x + 1)
-    auto it = let_binding_.find(op->var);
-    if (it != let_binding_.end()) {
+    auto it = let_var_map_.find(op->var);
+    if (it != let_var_map_.end()) {
      ICHECK(deep_equal_(it->second, value))
           << "Let cannot bind the same var to two different values";
     }
     if (value.dtype().get_lanes_or_vscale_factor() !=
         op->value.dtype().get_lanes_or_vscale_factor()) {
       Var new_var(op->var->name_hint, value.dtype());
-      let_binding_[op->var] = new_var;
+      let_var_map_[op->var] = new_var;
+      // Record mapping from the new var to its bound value
+      let_value_binding_[new_var] = value;
       return Let(new_var, value, this->VisitExpr(op->body));
     } else {
-      let_binding_[op->var] = op->var;
+      let_var_map_[op->var] = op->var;
       PrimExpr body = this->VisitExpr(op->body);
       if (value.same_as(op->value) && body.same_as(op->body)) {
-        return GetRef<PrimExpr>(op);
+        return tvm::ffi::GetRef<PrimExpr>(op);
       } else {
         return Let(op->var, value, body);
       }

@@ -555,7 +566,7 @@ public:
   }
   // BufferStore
   Stmt VisitStmt_(const BufferStoreNode *op) final {
-    auto store = GetRef<BufferStore>(op);
+    auto store = tvm::ffi::GetRef<BufferStore>(op);
     auto fmutate = [this](const PrimExpr &index) {
       return this->VisitExpr(index);

@@ -618,11 +629,11 @@ public:
     ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector());
     PrimExpr extent = this->VisitExpr(op->extent);
     if (extent.dtype().is_scalable_or_fixed_length_vector()) {
-      return Scalarize(GetRef<Stmt>(op));
+      return Scalarize(tvm::ffi::GetRef<Stmt>(op));
     }
     Stmt body = this->VisitStmt(op->body);
     if (extent.same_as(op->extent) && body.same_as(op->body)) {
-      return GetRef<Stmt>(op);
+      return tvm::ffi::GetRef<Stmt>(op);
     } else {
       return For(op->loop_var, op->min, extent, op->kind, body,
                  op->thread_binding, op->annotations);

@@ -633,7 +644,7 @@ public:
     ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
     PrimExpr condition = this->VisitExpr(op->condition);
     if (condition.dtype().is_scalable_or_fixed_length_vector()) {
-      return Scalarize(GetRef<Stmt>(op));
+      return Scalarize(tvm::ffi::GetRef<Stmt>(op));
     }
     Stmt then_case = this->VisitStmt(op->then_case);
     Optional<Stmt> else_case = std::nullopt;

@@ -642,7 +653,7 @@ public:
     }
     if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
         else_case.same_as(op->else_case)) {
-      return GetRef<Stmt>(op);
+      return tvm::ffi::GetRef<Stmt>(op);
     } else {
       return IfThenElse(condition, then_case, else_case);
     }

@@ -654,20 +665,23 @@ public:
   // LetStmt
   Stmt VisitStmt_(const LetStmtNode *op) final {
     PrimExpr value = this->VisitExpr(op->value);
-    ICHECK(!let_binding_.count(op->var))
+    ICHECK(!let_var_map_.count(op->var))
         << "SSA violation, a single var is binded twice";
-    let_binding_[op->var] = value;
     if (value.dtype().get_lanes_or_vscale_factor() !=
         op->value.dtype().get_lanes_or_vscale_factor()) {
       Var new_var(op->var->name_hint, value.dtype());
-      let_binding_[op->var] = new_var;
+      let_var_map_[op->var] = new_var;
       // Record mapping from the new var to its bound value
-      let_value_binding_[op->var] = op->value;
+      let_value_binding_[new_var] = value;
       return LetStmt(new_var, value, this->VisitStmt(op->body));
     } else {
-      let_binding_[op->var] = op->var;
+      let_var_map_[op->var] = op->var;
+      let_value_binding_[op->var] = value;
       Stmt body = this->VisitStmt(op->body);
       if (value.same_as(op->value) && body.same_as(op->body)) {
-        return GetRef<Stmt>(op);
+        return tvm::ffi::GetRef<Stmt>(op);
       } else {
         return LetStmt(op->var, value, body);
       }

@@ -681,7 +695,7 @@ public:
     if (condition.dtype().is_scalable_or_fixed_length_vector()) {
       LOG(WARNING) << "Cannot handle vector extent in alloc of "
                    << op->buffer_var->name_hint;
-      return Scalarize(GetRef<Stmt>(op));
+      return Scalarize(tvm::ffi::GetRef<Stmt>(op));
     }
     return StmtMutator::VisitStmt_(op);

@@ -689,8 +703,27 @@ public:
   // scalarize the statement
   Stmt Scalarize(Stmt stmt) {
-    Var idx(var_->name_hint + ".s", var_->dtype);
+    Var idx(var_->name_hint + "_s", var_->dtype);
+    // Find all Vars in stmt that are keys in let_value_binding_
+    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_let_bound_vars;
+    PostOrderVisit(stmt, [this, &used_let_bound_vars](const ObjectRef &node) {
+      if (const auto *v = node.as<VarNode>()) {
+        Var var = GetRef<Var>(v);
+        if (let_value_binding_.count(var)) {
+          used_let_bound_vars.insert(var);
+        }
+      }
+    });
     stmt = Substitute(stmt, {{var_, idx}});
+    if (!used_let_bound_vars.empty()) {
+      for (const auto &v : used_let_bound_vars) {
+        // Bind the existing var v to its value around the stmt scope
+        auto new_value = Substitute(let_value_binding_.at(v), {{var_, idx}});
+        stmt = LetStmt(v, new_value, stmt);
+      }
+    }
     return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt);
   }

@@ -707,8 +740,11 @@ private:
   PrimExpr ramp_;
   // flag to mark requirement of scalarization.
   bool need_scalarize_{false};
-  // Let binding
-  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;
+  // Let var mapping
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_var_map_;
+  // Let value binding: map new_var -> value
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>
+      let_value_binding_;
   // vectorizable property
   OpAttrMap<TVectorizable> op_vectorizable_ =
       Op::GetAttrMap<TVectorizable>("TVectorizable");

@@ -746,7 +782,7 @@ private:
     PrimExpr a = this->VisitExpr(op->a);
     PrimExpr b = this->VisitExpr(op->b);
     if (a.same_as(op->a) && b.same_as(op->b)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       int a_lanes = a.dtype().get_lanes_or_vscale_factor();
       int b_lanes = b.dtype().get_lanes_or_vscale_factor();

@@ -762,7 +798,7 @@ private:
     PrimExpr a = this->VisitExpr(op->a);
     PrimExpr b = this->VisitExpr(op->b);
     if (a.same_as(op->a) && b.same_as(op->b)) {
-      return GetRef<PrimExpr>(op);
+      return tvm::ffi::GetRef<PrimExpr>(op);
     } else {
       int a_lanes = a.dtype().get_lanes_or_vscale_factor();
      int b_lanes = b.dtype().get_lanes_or_vscale_factor();

@@ -806,7 +842,7 @@ public:
                  << " for target " << Target::Current();
     }
     ICHECK(is_zero(op->min));
-    return TLVectorizer(op->loop_var, op->extent)(op->body);
+    return TLVectorizer::Vectorize(op->loop_var, op->extent, op->body);
   } else {
     return StmtMutator::VisitStmt_(op);
   }

@@ -842,10 +878,10 @@ tvm::transform::Pass VectorizeLoop(bool enable_vectorize = true) {
   return CreatePrimFuncPass(pass_func, 0, "tl.VectorizeLoop", {});
 }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.VectorizeLoop", VectorizeLoop);
-});
+}

 }  // namespace tl
 }  // namespace tvm
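The Scalarize changes above target loops that cannot stay vectorized: the lane variable is substituted with a fresh scalar index, and any let-bound values referenced by the body are re-bound around the scalarized statement so they are recomputed per lane rather than left as free variables. A toy model of that rebinding, with a plain function standing in for a TIR expression that mentions the lane variable (all names illustrative):

    #include <functional>
    #include <iostream>
    #include <map>
    #include <string>

    // A "value" is modeled as a function of the lane index; in TIR it would be
    // a PrimExpr that may reference the vectorized loop variable.
    using LaneExpr = std::function<int(int lane)>;

    int main() {
      const int lanes = 4;

      // let t = lane * 2 + 1  (depends on the lane variable)
      std::map<std::string, LaneExpr> let_value_binding = {
          {"t", [](int lane) { return lane * 2 + 1; }}};

      // Body uses the let-bound var: out[lane] = t + 10.
      auto body = [](int t) { return t + 10; };

      // Scalarize: wrap the body in a serial loop and re-evaluate each
      // let-bound value per lane, mirroring LetStmt(v, Substitute(value), stmt).
      for (int idx = 0; idx < lanes; ++idx) {
        int t = let_value_binding.at("t")(idx);  // re-bound for this lane
        std::cout << "lane " << idx << ": " << body(t) << "\n";
      }
      // Without the rebinding, `t` would dangle once the vector ramp is
      // replaced by the scalar index, producing a free-variable error.
    }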
src/transform/warp_specialized_rewriter.cc

@@ -159,7 +159,7 @@ public:
   // Check reads from global
   Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
-              /*name_hint=*/"", /*body*/ GetRef<Stmt>(op));
+              /*name_hint=*/"", /*body*/ tvm::ffi::GetRef<Stmt>(op));
   auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
   auto reads = access[0];
   Role role = Role::kProducer;

@@ -511,7 +511,7 @@ private:
   annotations.Set(String("stmt_group"), Integer(1));
   auto original_node = (op->body).as<SeqStmtNode>();
   if (!original_node) {
-    return GetRef<For>(op);
+    return tvm::ffi::GetRef<For>(op);
   }
   Array<Stmt> new_body;
   int cur_id = 0;

@@ -646,7 +646,7 @@ private:
   if (role == Role::kBoth) {
     return StmtMutator::VisitStmt_(op);
   } else if ((role == Role::kProducer) == is_emitting_producer_) {
-    return GetRef<Stmt>(op);
+    return tvm::ffi::GetRef<Stmt>(op);
   } else {
     return Evaluate(0);
   }

@@ -1284,7 +1284,7 @@ tvm::transform::Pass WarpSpecialized() {
     return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized,
                                                disable_shuffle_elect);
   } else {
-    ObjectRef node = String("default");
+    auto node = ffi::String("default");
     f.CopyOnWrite()->body =
         AttrStmt(node, attr::kCustomWarpSpecialization, 1, f->body);
     return f;

@@ -1293,10 +1293,10 @@ tvm::transform::Pass WarpSpecialized() {
   return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
 }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.WarpSpecialized", WarpSpecialized);
-});
+}

 }  // namespace tl
 }  // namespace tvm
src/transform/wgmma_sync_rewriter.cc

@@ -266,10 +266,10 @@ tvm::transform::Pass RewriteWgmmaSync() {
   return CreatePrimFuncPass(pass_func, 0, "tl.RewriteWgmmaSync", {});
 }

-TVM_FFI_STATIC_INIT_BLOCK({
+TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tl.transform.RewriteWgmmaSync", RewriteWgmmaSync);
-});
+}

 }  // namespace tl
 }  // namespace tvm
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py

@@ -22,15 +22,6 @@ def tl_matmul(
     b_transposed=True,
     k_pack=1,
 ):
-    assert in_dtype in [
-        "float16",
-        "int8",
-    ], "Currently only float16 and int8 are supported"
-    assert out_dtype in [
-        "float16",
-        "float32",
-        "int32",
-    ], "Currently only float16, float32 and int32 are supported"
     micro_size_x = micro_size_y = micro_size_k = 16

@@ -190,6 +181,9 @@ def assert_tl_matmul_correctness(M,
     if in_dtype == "int8":
         A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
         B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
+    elif in_dtype == "float8_e4m3fnuz":
+        A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
+        B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
     else:
         A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
         B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py

@@ -217,6 +217,9 @@ def assert_tl_matmul_correctness(M,
     if in_dtype == "int8":
         A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
         B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
+    elif in_dtype == "float8_e4m3fnuz":
+        A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
+        B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
     else:
         A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
         B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))

@@ -264,11 +267,11 @@ def assert_tl_matmul_correctness(M,
 @tilelang.testing.requires_rocm
 def test_assert_tl_matmul():
     assert_tl_matmul_correctness(
-        256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
+        256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
     assert_tl_matmul_correctness(
-        256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
+        256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
     assert_tl_matmul_correctness(
-        256, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
+        256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
     assert_tl_matmul_correctness(
         256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)

@@ -283,6 +286,21 @@ def test_assert_tl_matmul():
         k_pack=2,
         b_preshuffle=True)
+    assert_tl_matmul_correctness(
+        256, 256, 512, "float8_e4m3fnuz", "float32", b_preshuffle=True)
+    assert_tl_matmul_correctness(
+        256, 256, 512, "float8_e4m3fnuz", "float32", b_transposed=False, b_preshuffle=True)
+    assert_tl_matmul_correctness(
+        256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_preshuffle=True)
+    assert_tl_matmul_correctness(
+        256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_transposed=False, b_preshuffle=True)


 if __name__ == "__main__":
     tilelang.testing.main()
testing/python/amd/test_tilelang_test_amd.py

@@ -223,29 +223,26 @@ def run_gemm_rs(
     profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)

-@tilelang.testing.requires_rocm
-def test_gemm_rs_f16f32f32_nt():
-    run_gemm_rs(1024, 1024, 1024, False, False, "float16", "float32", "float32", 128, 128, 32)
-    run_gemm_rs(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32)
-    run_gemm_rs(1024, 1024, 1024, True, True, "float16", "float32", "float32", 128, 128, 32)
-    run_gemm_rs(1024, 1024, 1024, True, False, "float16", "float32", "float32", 128, 128, 32)
-
-@tilelang.testing.requires_rocm
-def test_gemm_rs_bf16f32f32_nt():
-    run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "float32", "float32", 128, 128, 32)
-    run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
-    run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
-    run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)
-
-@tilelang.testing.requires_rocm
-def test_gemm_rs_bf16bf16f32_nt():
-    run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
-    run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
-    run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
-    run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
+# @tilelang.testing.requires_rocm
+# def test_gemm_rs_f16f32f32_nt():
+#     run_gemm_rs(1024, 1024, 1024, False, False, "float16", "float32", "float32", 128, 128, 32)
+#     run_gemm_rs(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32)
+#     run_gemm_rs(1024, 1024, 1024, True, True, "float16", "float32", "float32", 128, 128, 32)
+#     run_gemm_rs(1024, 1024, 1024, True, False, "float16", "float32", "float32", 128, 128, 32)
+
+# @tilelang.testing.requires_rocm
+# def test_gemm_rs_bf16f32f32_nt():
+#     run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "float32", "float32", 128, 128, 32)
+#     run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
+#     run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
+#     run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)
+
+# @tilelang.testing.requires_rocm
+# def test_gemm_rs_bf16bf16f32_nt():
+#     run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
+#     run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
+#     run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
+#     run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)

 if __name__ == "__main__":
     tilelang.testing.main()
testing/python/dynamic/test_tilelang_dynamic_symbolic.py

@@ -514,4 +514,5 @@ def test_assert_tl_matmul_block_all_dynamic_with_pass_config():
 if __name__ == "__main__":
-    tilelang.testing.main()
+    # tilelang.testing.main()
+    assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
testing/python/issue/test_tilelang_issue_1008.py (new file, mode 100644)

import torch

import tilelang
import tilelang.testing
from tilelang import language as T


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    })
def _fill_with_static_region_kernel():
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']):  # noqa: F821
        with T.Kernel(num_tokens, threads=128) as _:
            T.fill(x[0:128], 0)

    return buggy_kernel


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    })
def _fill_with_dynamic_region_kernel():
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']):  # noqa: F821
        with T.Kernel(num_tokens, threads=128) as _:
            a, b = T.alloc_var('int'), T.alloc_var('int')
            T.fill(x[a:b], 0)

    return buggy_kernel


def test_fill_with_static_region_kernel():
    kernel = _fill_with_static_region_kernel()
    x = torch.zeros((256,), dtype=torch.int64, device='cuda')
    kernel(x)


def test_fill_with_dynamic_region_kernel():
    kernel = _fill_with_dynamic_region_kernel()
    x = torch.zeros((256,), dtype=torch.int64, device='cuda')
    kernel(x)


if __name__ == '__main__':
    tilelang.testing.main()
testing/python/issue/test_tilelang_issue_1115.py (new file, mode 100644)

import torch

import tilelang
import tilelang.language as T


def test_int64_address():

    @tilelang.jit
    def set_cache_kernel(
        S,
        D,
        pos_ty='int64',
        dtype="float32",
    ):

        @T.prim_func
        def main(
                pos: T.Tensor([S,], pos_ty),  # type: ignore `TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32`
                value: T.Tensor([S, D], dtype),  # type: ignore
                cache: T.Tensor([S, D], dtype),  # type: ignore
        ):
            with T.Kernel(S, threads=128) as bx:
                slot = pos[bx]
                for i in T.Parallel(D):
                    cache[slot, i] = value[bx, i]

        return main

    D = 2
    S = 10
    cache = torch.rand((S, D), device="cuda", dtype=torch.float32)
    value = torch.rand((S, D), device='cuda', dtype=torch.float32)
    pos_int64 = torch.arange(S, device='cuda', dtype=torch.int64)
    pos_int32 = torch.arange(S, device='cuda', dtype=torch.int32)

    kernel_int64 = set_cache_kernel(S, D, 'int64')
    kernel_int32 = set_cache_kernel(S, D, 'int32')

    kernel_int64(pos_int64, value, cache)
    torch.testing.assert_close(cache, value)
    kernel_int32(pos_int32, value, cache)
    torch.testing.assert_close(cache, value)


if __name__ == "__main__":
    tilelang.testing.main()
testing/python/issue/test_tilelang_issue_1198.py (new file, mode 100644)

import tilelang.testing
import tilelang.language as T


def test_issue_1198():

    @T.prim_func
    def foo(x: T.Buffer([32,], "int32")):
        pass


if __name__ == '__main__':
    tilelang.testing.main()
testing/python/issue/test_tilelang_issue_1210.py (new file, mode 100644)

import tilelang
import tilelang.language as T
import tilelang.testing


def _make_kernel(M, N):
    dtype = "bfloat16"

    @T.prim_func
    def fwd_main(KV: T.Tensor((M, N), dtype), ids: T.Tensor((4,), "int32")):
        with T.Kernel(4, threads=1):
            A = T.alloc_shared([N], dtype)
            B = T.alloc_shared([N], dtype)
            # Regression for a bug where InjectSoftwarePipeline left the loop
            # variable as a free var, causing MakePackedAPI to fail
            for i in T.Pipelined(4, num_stages=1):
                _id = ids[i]
                T.copy(KV[_id, :], A)
                T.clear(B)

    return fwd_main


def test_make_packed_api_no_free_loop_var():
    func = _make_kernel(4, 4)
    # Keep warp-specialization/TMA disabled to match the original repro
    cfg = {
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    }
    tilelang.compile(func, pass_configs=cfg)


if __name__ == "__main__":
    tilelang.testing.main()
testing/python/issue/test_tilelang_issue_1237.py (new file, mode 100644)

import tilelang.testing
from tilelang import language as T


def test_issue_1237_dynamic_copy_extent_builds():
    # Repro from debug/1113_issues/copy_dyn.py, adapted as a unit test.
    # The goal is to ensure T.copy correctly handles dynamic extents
    # (e.g., src slice length vs. static dst buffer size) during prim_func building.
    length = T.symbolic("len", dtype="int32")

    @T.prim_func
    def sample_kernel(global_tensor: T.Tensor[(length,), "int32"]):  # noqa: F821
        with T.Kernel(1, threads=32):
            buffer_shared = T.alloc_shared((1024,), dtype="int32")
            T.copy(global_tensor[0:length], buffer_shared)

    # Building the prim_func is sufficient to exercise the bug path; no need to JIT/execute.
    _ = sample_kernel


if __name__ == "__main__":
    tilelang.testing.main()
testing/python/jit/test_tilelang_jit_gemm_ctypes.py

@@ -85,7 +85,7 @@ def run_gemm(
     stramp = "&*(XS)"

-    @tvm.register_func("tilelang_callback_cuda_postproc", override=True)
+    @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True)
     def tilelang_callback_cuda_postproc(code, _):
         code = f"// {stramp}\n" + code
         return code

@@ -407,4 +407,5 @@ def test_ctypes_dynamic_shape():
 if __name__ == "__main__":
-    tilelang.testing.main()
+    # tilelang.testing.main()
+    test_gemm_f16f16f16_nn()
testing/python/jit/test_tilelang_jit_gemm_cython.py

@@ -85,7 +85,7 @@ def run_gemm(
     stramp = "&*(XS)"

-    @tvm.register_func("tilelang_callback_cuda_postproc", override=True)
+    @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True)
     def tilelang_callback_cuda_postproc(code, _):
         code = f"// {stramp}\n" + code
         return code
testing/python/jit/test_tilelang_jit_parcompile.py (new file, mode 100644)

import tilelang.testing
import tilelang
import torch


@tilelang.jit(
    out_idx=-1,  # create the output tensor during runtime
    verbose=True,
)
def matmul_kernel_jit(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A=False,
    trans_B=True,
    in_dtype='float16',
    out_dtype='float32',
    accum_dtype='float32',
    num_stages=2,
    threads=128,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)

            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def test_par_compile():
    configs = [
        (1024, 1024, 1024, 128, 128, 32),
        (2048, 2048, 2048, 256, 256, 64),
        (4096, 4096, 4096, 64, 64, 128),
    ]
    kernels = matmul_kernel_jit.par_compile(configs)
    for (M, N, K, _, _, _), kernel in zip(configs, kernels):
        A = torch.randn(M, K, dtype=torch.float16).cuda()
        B = torch.randn(N, K, dtype=torch.float16).cuda()
        ref = (A @ B.T).float()
        C = kernel(A, B)
        tilelang.testing.torch_assert_close(C, ref, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
    tilelang.testing.main()