OpenDAS / tilelang · Commits · 81b8c1b7

Commit 81b8c1b7 (unverified), authored Dec 16, 2025 by Kuris, committed by GitHub on Dec 16, 2025

[Fix] Fix analyzer bind conflicting (#1446)

Parent: 869f021b

Changes: 1 changed file with 97 additions and 95 deletions

src/transform/layout_inference.cc (+97, −95) — view file @ 81b8c1b7
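The conflict named in the commit title arises because both PartitionLoop and VectorizeLoop bind loop variables on the shared analyzer_ owned by IRMutatorWithAnalyzer; the fix snapshots the analyzer up front and hands the snapshot to VectorizeLoop. The following is a minimal, illustrative sketch (not part of the commit) of why re-binding on one shared analyzer is problematic. It assumes only TVM's public arith::Analyzer API; the function name BindConflictSketch is hypothetical, and the Clone() helper mentioned in the comments is taken at face value from the diff itself.

// Illustrative sketch only (not from the commit): why sharing one analyzer
// between loop-mutating passes can produce conflicting bindings, and why the
// diff clones it first.
#include <tvm/arith/analyzer.h>
#include <tvm/tir/var.h>

void BindConflictSketch() {
  tvm::arith::Analyzer analyzer;
  tvm::tir::Var i("i", tvm::DataType::Int(32));

  // A first pass (e.g. loop partitioning) binds the loop variable.
  analyzer.Bind(i, tvm::Range::FromMinExtent(0, 1024));

  // A later pass re-binding the same variable to a different range on the
  // same analyzer (allow_override defaults to false) is typically rejected
  // by the analyzer's consistency checks -- the "bind conflicting" the
  // commit title refers to:
  //   analyzer.Bind(i, tvm::Range::FromMinExtent(0, 128));
  //
  // Saving the analyzer state before the mutating passes, as the diff does
  // with analyzer_->Clone(), lets the later VectorizeLoop see a view that is
  // unaffected by bindings added inside PartitionLoop.
}

In the hunk below, saved_analyzer is exactly such a pre-pass snapshot: PartitionLoop still runs against analyzer_, while VectorizeLoop now receives saved_analyzer.get().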
...
...
@@ -1090,112 +1090,114 @@ private:
       reducer_info = op->annotations.Get(attr::kReducerInfo)
                          ->as<Map<Var, ReducerInfo>>()
                          .value();
   if (!result_.for_map.count(tvm::ffi::GetRef<For>(op))) {
     return IRMutatorWithAnalyzer::VisitStmt_(op);
   }
+  // the analyzer will be modified in PartitionLoop and VectorizeLoop
+  // we need to save its state to prevent conflicted bindings
+  auto saved_analyzer = analyzer_->Clone();
   For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
   if (result_.for_map.count(tvm::ffi::GetRef<For>(op))) {
     auto root = tvm::ffi::GetRef<For>(op);
     // This check is a workaround to support T.Parallel for local buffers.
     // For example:
     //   for i in T.Parallel(1024):
     //     A_local[i] = A_global[i]
     // Here, A_local is a register-local buffer held independently by each
     // thread, so explicit thread binding is not required.
     bool store_into_local = false;
     PostOrderVisit(root, [&](const ObjectRef &obj) {
       if (const auto *store = obj.as<BufferStoreNode>()) {
         if (store->buffer.scope() == "local") {
           store_into_local = true;
         }
         // if the case is like:
         //   for i in T.Parallel(1024):
         //     A_local[i] = B_global[i]
         //     A_frag[i] = A_global[i]
         // exception will be raise in Parallel::LayoutInference
       }
     });
     // This check if for the loop that only manuplates "local" buffers,
     //   for i in T.Parallel(1024):
     //     A_local[i] = B_local[i]
     // Though this might be illegal
     // We use PostOrderVisit to detect whether the loop only manuplates
     // "local" buffers, which indicates register usage and justifies skipping
     // thread binding.
     bool local_register_only = true;
     PostOrderVisit(root, [&](const ObjectRef &obj) {
       if (const auto *store = obj.as<BufferStoreNode>()) {
         if (store->buffer.scope() != "local") {
           local_register_only = false;
         }
       } else if (const auto *load = obj.as<BufferLoadNode>()) {
         if (load->buffer.scope() != "local") {
           local_register_only = false;
         }
         // if the case is like:
         //   for i in T.Parallel(1024):
         //     A_local[i] = B_global[i]
         //     A_frag[i] = A_global[i]
         // exception will be raise in Parallel::LayoutInference
       }
     });
     auto loop_layout = result_.for_map[root];
     // FIXME: tell in-Parallel and out-of-Parallel `local`s apart
     // NOTE(lei): a bit ugly, we should rethink about this part in future.
     bool parallel_loop =
         !skip_thread_partition_ && !local_register_only && !store_into_local;
     if (parallel_loop) {
       for_node =
           PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
     }
     // If none thread bindings are provided, partition the loop
     bool has_non_local = false;
     PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
       if (const auto *load = obj.as<BufferLoadNode>()) {
         String scope = load->buffer.scope();
         if (scope != "local" && scope != "local.fragment") {
           has_non_local = true;
         }
       } else if (const auto *store = obj.as<BufferStoreNode>()) {
         String scope = store->buffer.scope();
         if (scope != "local" && scope != "local.fragment") {
           has_non_local = true;
         }
       }
     });
     // Workaround: if reducer is presented, don't vectorize loop
     // Best solution should be isolate reduction axis out of vectorization
     bool has_reducer = false;
     PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
       if (!has_reducer)
         if (const auto *store = obj.as<BufferStoreNode>()) {
           has_reducer = reducer_info.count(store->buffer->data) != 0;
         }
     });
     // If a cast operation exists, vectorization may still be required
     bool has_cast_operations = false;
     PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
       if (const auto *cast = obj.as<CastNode>()) {
         // Check if this is a non-reducer store with Cast operation
         DataType src_type = cast->value.dtype();
         DataType dst_type = cast->dtype;
         bool src_ok = src_type.is_float() || src_type.is_bfloat() ||
                       src_type.is_float8_e4m3() || src_type.is_float8_e5m2();
         bool dst_ok = dst_type.is_float() || dst_type.is_bfloat() ||
                       dst_type.is_float8_e4m3() || dst_type.is_float8_e5m2();
         if (src_ok && dst_ok && TargetIsCuda(Target::Current())) {
           has_cast_operations = true;
         }
       }
     });
     if ((has_non_local || has_cast_operations) && !has_reducer) {
-      for_node = VectorizeLoop(for_node, analyzer_);
+      for_node = VectorizeLoop(for_node, saved_analyzer.get());
     }
     if (result_.predicate_map.count(root) && parallel_loop) {
       return IfThenElse(result_.predicate_map[root], for_node);
     } else {
       return for_node;
     }
   }
   return for_node;
 }

 Stmt VisitStmt_(const AttrStmtNode *op) final {
...
...
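The three scans in the hunk above (store_into_local, local_register_only, has_non_local) all follow the same pattern: a tir::PostOrderVisit over the loop body that classifies BufferLoad/BufferStore nodes by buffer scope. Below is a self-contained sketch of that pattern, using only the public TVM TIR API the file already relies on; the helper name TouchesNonLocalBuffer is hypothetical and not part of the commit.

// Sketch of the scope-scanning pattern used in the hunk: walk a statement and
// report whether any load or store touches a buffer outside the register-level
// "local" / "local.fragment" scopes. Helper name is hypothetical.
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

bool TouchesNonLocalBuffer(const tvm::tir::Stmt &body) {
  using namespace tvm;
  using namespace tvm::tir;
  bool has_non_local = false;
  PostOrderVisit(body, [&](const ObjectRef &obj) {
    if (const auto *load = obj.as<BufferLoadNode>()) {
      String scope = load->buffer.scope();
      if (scope != "local" && scope != "local.fragment") {
        has_non_local = true;
      }
    } else if (const auto *store = obj.as<BufferStoreNode>()) {
      String scope = store->buffer.scope();
      if (scope != "local" && scope != "local.fragment") {
        has_non_local = true;
      }
    }
  });
  return has_non_local;
}

In the commit, the outcome of these scans (together with the reducer and cast checks) decides whether VectorizeLoop runs at all, and the analyzer handed to it is the pre-pass clone rather than the analyzer_ that PartitionLoop has already mutated.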