OpenDAS / tilelang · Commit 29051439 (unverified)
Authored Dec 12, 2025 by Lei Wang; committed by GitHub on Dec 12, 2025
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
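For context on what the migration below actually rewrites: ruff format is a black-compatible formatter, so these diffs are dominated by a few mechanical changes visible in the hunks that follow: string literals normalized to double quotes, yapf's aligned continuation lines replaced by a hanging indent, and spaces added around ":" in slices whose bounds are expressions. The snippet below is an illustrative sketch of those rules, not code from this commit; make_hint is a hypothetical function invented for the example.

# Illustrative sketch only -- not from this commit; `make_hint` is hypothetical.
# yapf-era style that this commit retires:
#
#     __all__ = ['make_hint']
#
#     def make_hint(block,
#                   layout=1):
#         return block[0:layout + 1]
#
# The same code after `ruff format` (double quotes, hanging indent kept by the
# magic trailing comma, spaces around ':' in a non-trivial slice):
__all__ = ["make_hint"]


def make_hint(
    block,
    layout=1,
):
    return block[0 : layout + 1]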
parent e84b24bc

Showing 20 of the 426 changed files, with 176 additions and 316 deletions (+176, -316).
tilelang/carver/arch/metal.py                       +2   -3
tilelang/carver/common_schedules.py                 +1   -0
tilelang/carver/matmul_analysis.py                  +29  -50
tilelang/carver/roller/bestfit.py                   +2   -6
tilelang/carver/roller/hint.py                      +3   -2
tilelang/carver/roller/node.py                      +23  -42
tilelang/carver/roller/policy/default.py            +28  -56
tilelang/carver/roller/policy/tensorcore.py         +16  -39
tilelang/carver/roller/rasterization.py             +0   -2
tilelang/carver/roller/shape_inference/common.py    +1   -4
tilelang/carver/roller/shape_inference/tir.py       +9   -32
tilelang/carver/template/base.py                    +9   -4
tilelang/carver/template/conv.py                    +17  -8
tilelang/carver/template/flashattention.py          +1   -5
tilelang/carver/template/gemv.py                    +3   -6
tilelang/carver/template/general_reduce.py          +4   -6
tilelang/carver/template/matmul.py                  +3   -6
tilelang/carver/utils.py                            +10  -17
tilelang/contrib/cc.py                              +12  -22
tilelang/contrib/dlpack.py                          +3   -6
tilelang/carver/arch/metal.py

@@ -8,7 +8,6 @@ def is_metal_arch(arch: TileDevice) -> bool:
 class METAL(TileDevice):
-
     def __init__(self, target: Target | str):
         if isinstance(target, str):
             target = Target(target)
@@ -16,6 +15,6 @@ class METAL(TileDevice):
 __all__ = [
-    'is_metal_arch',
-    'METAL',
+    "is_metal_arch",
+    "METAL",
 ]
tilelang/carver/common_schedules.py

@@ -19,6 +19,7 @@
 # Modifications Copyright (c) Microsoft.
 # The code below is mostly copied from apache/tvm common_schedules.py in dlight.
 """Common schedule strategies for TIR."""
 from typing import Callable
+
 from tvm import tir
tilelang/carver/matmul_analysis.py

 # pylint: disable=missing-docstring, invalid-name
 """A GEMM schedule rule for GPU operators."""
 from __future__ import annotations
 from dataclasses import dataclass
 from enum import Enum
@@ -157,8 +158,7 @@ def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> Block
     return block


-def find_arg_idx_from_buffer_chain(sch: tir.Schedule, main_block: tir.schedule.BlockRV,
-                                   buffer: tir.Buffer) -> int:
+def find_arg_idx_from_buffer_chain(sch: tir.Schedule, main_block: tir.schedule.BlockRV, buffer: tir.Buffer) -> int:
     """traverse to find the arg index from the buffer"""
     producers = sch.get_producers(main_block)
@@ -226,9 +226,7 @@ def make_iter_fusion_index_map(
         else:
             fused_iters[trait.kind] = v_i
-    final_indices: list[tir.PrimExpr] = [
-        fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0))
-        for kind in kind_order
-    ]
+    final_indices: list[tir.PrimExpr] = [
+        fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order
+    ]
     return tir.IndexMap(input_iters, final_indices, None)
@@ -307,8 +305,7 @@ def detect_iter_traits(block: tir.Block) -> tuple[list[IterTrait]] | None:
     return A_traits, B_traits, C_traits, block_traits


-def get_index_map(block: tir.Block,
-                  layout: list[str] | None = None) -> tuple[tir.IndexMap, ...] | None:
+def get_index_map(block: tir.Block, layout: list[str] | None = None) -> tuple[tir.IndexMap, ...] | None:
     """Get index maps for the block

     Parameters
@@ -343,10 +340,7 @@ def get_index_map(block: tir.Block,
         return axes

     def is_common_reduce(var: Var) -> bool:
-        for iter_var in block.iter_vars:
-            if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce:
-                return True
-        return False
+        return any(iter_var.var == var and iter_var.iter_type == IterVar.CommReduce for iter_var in block.iter_vars)

     def has_common_reduce(var: Var) -> bool:
         vars = collect_vars_from_expr(var)
@@ -384,17 +378,17 @@ def get_index_map(block: tir.Block,
         if kind == "C":
             return [IterKind.kIter_S, primary_iter, secondary_iter]
         else:
-            return ([IterKind.kIter_S, spatial_iter, reduction_iter] if check_last_trait(region)
-                    else [IterKind.kIter_S, reduction_iter, spatial_iter])
+            return (
+                [IterKind.kIter_S, spatial_iter, reduction_iter]
+                if check_last_trait(region)
+                else [IterKind.kIter_S, reduction_iter, spatial_iter]
+            )
     else:
         raise ValueError(f"Unknown layout {layout}")

-    A_index_map = make_iter_fusion_index_map(
-        A_traits, infer_layout(layout[0], block.reads[0].region, kind="A"))
-    B_index_map = make_iter_fusion_index_map(
-        B_traits, infer_layout(layout[1], block.reads[1].region, kind="B"))
-    C_index_map = make_iter_fusion_index_map(
-        C_traits, infer_layout(layout[2], block.writes[0].region, kind="C"))
+    A_index_map = make_iter_fusion_index_map(A_traits, infer_layout(layout[0], block.reads[0].region, kind="A"))
+    B_index_map = make_iter_fusion_index_map(B_traits, infer_layout(layout[1], block.reads[1].region, kind="B"))
+    C_index_map = make_iter_fusion_index_map(C_traits, infer_layout(layout[2], block.writes[0].region, kind="C"))
     matmul_index_map = make_iter_fusion_index_map(
         block_traits,
@@ -429,8 +423,7 @@ def get_dequantize_block(sch, blocks) -> BlockRV | None:
         has_uint_input = any("uint" in str(region.buffer.dtype) for region in block_stmt.reads)
         if not has_uint_input:
             return False
-        return not (len(block_stmt.writes) != 1 or
-                    "float" not in str(block_stmt.writes[0].buffer.dtype))
+        return not (len(block_stmt.writes) != 1 or "float" not in str(block_stmt.writes[0].buffer.dtype))

     dequantize_blocks = [block for block in blocks if is_dequantize(block)]
     return dequantize_blocks[0] if len(dequantize_blocks) == 1 else None
@@ -452,8 +445,7 @@ def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool:
             return None
         axes.extend(undefined_vars(r.min))
     # remove trivial axis
-    trivial_vars = set(
-        iter_var.var for iter_var in block_stmt.iter_vars if _is_one(iter_var.dom.extent))
+    trivial_vars = set(iter_var.var for iter_var in block_stmt.iter_vars if _is_one(iter_var.dom.extent))
     axes = [axis for axis in axes if axis not in trivial_vars]
     # remove duplicate axis
     axes = [var for i, var in enumerate(axes) if i == 0 or var != axes[i - 1]]
@@ -462,8 +454,7 @@ def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool:
     lhs_access_vars = get_access_vars(block_stmt.reads[0].region)[-2:]
     rhs_access_vars = get_access_vars(block_stmt.writes[0].region)[-2:]
     is_identity = list(lhs_access_vars) == list(rhs_access_vars)
-    is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set(lhs_access_vars) == set(
-        rhs_access_vars)
+    is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set(lhs_access_vars) == set(rhs_access_vars)
     return is_identity, is_transpose
@@ -491,9 +482,7 @@ def inline_transpose_block(sch: tir.Schedule, blocks: list[tir.schedule.BlockRV]
     return result_blocks


-def normalize_to_matmul(sch: tir.Schedule,
-                        main_block: BlockRV,
-                        layout: list[str] | None = None) -> tir.Schedule | None:
+def normalize_to_matmul(sch: tir.Schedule, main_block: BlockRV, layout: list[str] | None = None) -> tir.Schedule | None:
     if layout is None:
         layout = ["n", "t", "n"]
     block_stmt = sch.get(main_block)
@@ -526,7 +515,7 @@ def get_tensorized_func_and_tags(
     allow_gemv: bool = False,
 ) -> tuple[tir.PrimFunc, dict[str, list[int] | int]]:
     """
     transform function to matmul if necessary (e.g. transform conv2d with im2col)
     """
     if layout is None:
         layout = ["a", "a", "a"]
@@ -543,10 +532,7 @@ def get_tensorized_func_and_tags(
         conditions = []
         conditions.append(len(block_stmt.reads) == 2)
         conditions.append(len(block_stmt.writes) == 1)
-        conditions.append(
-            len(collect_block_iter_vars_used_in_access_region(block_stmt,
-                                                              block_stmt.writes[0].region)) > 0)
+        conditions.append(
+            len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region)) > 0
+        )
         return all(conditions)

     # step2. transform function to tensorcore matmul (e.g. conv2d with im2col)
@@ -592,10 +578,7 @@ def get_tensorized_func_and_tags(
         return axes

     def is_common_reduce(var: Var) -> bool:
-        for iter_var in block_stmt.iter_vars:
-            if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce:
-                return True
-        return False
+        return any(iter_var.var == var and iter_var.iter_type == IterVar.CommReduce for iter_var in block_stmt.iter_vars)

     def has_common_reduce(var: Var) -> bool:
         vars = collect_vars_from_expr(var)
@@ -626,7 +609,7 @@ def get_tensorized_func_and_tags(
     # When the func is a dequantize like ops, we should consider the M
     require_block_reduce = False
     # And we only support float16 for now
-    if (hasattr(func.attrs, "dequantize_info") and in_dtype in ["bfloat16", "float16"]):
+    if hasattr(func.attrs, "dequantize_info") and in_dtype in ["bfloat16", "float16"]:
         for arg in func.params:
             inp_shape = func.buffer_map[arg].shape
             M = inp_shape[0]
@@ -645,9 +628,7 @@ def get_tensorized_func_and_tags(
     if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70:
         in_dtype, out_dtype = get_in_out_dtypes(block_stmt)
         if not is_tensorcore_supported_precision(in_dtype, out_dtype, arch=get_arch(target)):
-            logger.debug(
-                f"The input and output dtype ({in_dtype}, {out_dtype})is not supported by tensorcore"
-            )
+            logger.debug(f"The input and output dtype ({in_dtype}, {out_dtype})is not supported by tensorcore")
             return func, None

     # reindex and transform functions
@@ -676,7 +657,7 @@ def get_tensorized_func_and_tags(
         else:
             raise ValueError(f"Unknown IterVar type {iter_type}")
-        if (isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold):
+        if isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold:
             return func, None
     tags = analysis_tensorcore_tags(sch, main_block, target)
     return sch.mod["main"], tags
@@ -686,8 +667,10 @@ def get_tensorized_func_and_tags(
 def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", index_dtype="int32"):
-    from bitblas.tl.mma_layout import (  # pylint: disable=import-outside-toplevel
-        ldmatrix_32x8_to_shared_16x16_layout, ldmatrix_trans_32x8_to_shared_16x16_layout,
-        ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b)
+    from bitblas.tl.mma_layout import (  # pylint: disable=import-outside-toplevel
+        ldmatrix_32x8_to_shared_16x16_layout,
+        ldmatrix_trans_32x8_to_shared_16x16_layout,
+        ldmatrix_32x16_to_shared_16x32_layout_a,
+        ldmatrix_32x16_to_shared_16x32_layout_b,
+    )

     assert dtype in [
@@ -727,9 +710,7 @@ def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", inde
         return ldmatrix_layout(thread_id, local_id)

     if dtype in ["bfloat16", "float16"]:
-        ldmatrix_index_map = (
-            ldmatrix_trans_permutation_16x16_32x8_16x16 if trans else ldmatrix_permutation_16x16_32x8_16x16)
+        ldmatrix_index_map = ldmatrix_trans_permutation_16x16_32x8_16x16 if trans else ldmatrix_permutation_16x16_32x8_16x16
     else:
         ldmatrix_index_map = ldmatrix_permutation_16x32_32x16_32x16
@@ -744,7 +725,6 @@ def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", inde
 # Ladder weight propagation, which can be used to avoid the ldmatrix
 # Instructions.
 def get_ladder_stage3_map(dtype="float16", index_dtype="int32"):
-
     def shared_32x8_to_mma_32x8_layout(i, j):
         thread_id = (i % 8) * 4 + (j // 2)
         local_id = (i // 8) * 2 + (j % 2)
@@ -837,8 +817,7 @@ def layout_propagate_chain(
             scaling_factor = 1
             for i, j in zip(write.buffer.shape, read.buffer.shape):
                 scaling_factor *= i // j
-            final_indices = list(
-                index_map.map_indices(tmp_index_map.map_indices(write_indices)))
+            final_indices = list(index_map.map_indices(tmp_index_map.map_indices(write_indices)))
             final_indices[-1] = final_indices[-1] // scaling_factor
             index_map = IndexMap(
                 write_indices,
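One substantive rewrite in the matmul_analysis.py hunks above (it appears twice, once over block.iter_vars and once over block_stmt.iter_vars) replaces an explicit search loop with a single any(...) over a generator. A self-contained sketch of that equivalence, using a minimal stand-in for TVM's IterVar rather than the real class:

from dataclasses import dataclass

COMM_REDUCE = "CommReduce"  # stand-in for tvm.tir.IterVar.CommReduce


@dataclass
class StubIterVar:  # minimal stand-in, not TVM's real IterVar
    var: str
    iter_type: str


iter_vars = [StubIterVar("i", "DataPar"), StubIterVar("k", COMM_REDUCE)]


def is_common_reduce_old(var: str) -> bool:
    # pre-commit form: explicit loop with early return
    for iter_var in iter_vars:
        if iter_var.var == var and iter_var.iter_type == COMM_REDUCE:
            return True
    return False


def is_common_reduce_new(var: str) -> bool:
    # post-commit form: one any(...) over a generator expression
    return any(iter_var.var == var and iter_var.iter_type == COMM_REDUCE for iter_var in iter_vars)


assert is_common_reduce_old("k") and is_common_reduce_new("k")
assert not is_common_reduce_old("i") and not is_common_reduce_new("i")

Both forms short-circuit on the first matching iteration variable, so the rewrite is purely stylistic.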
tilelang/carver/roller/bestfit.py

@@ -2,7 +2,6 @@
 class Block:
-
     def __init__(self, start, end, is_free):
         self.start = start
         self.end = end
@@ -21,7 +20,6 @@ class Block:
 class BestFit:
-
     def __init__(self, align=32):
         self.limit = 0
         self.list = []
@@ -31,16 +29,14 @@ class BestFit:
         size = (size + self.align - 1) // self.align * self.align
         found = None
         for block in self.list:
-            if block.is_free and block.size() >= size and (not found or
-                                                           found.size() > block.size()):
+            if block.is_free and block.size() >= size and (not found or found.size() > block.size()):
                 found = block
         if found:
             found.is_free = False
             remain = found.size() - size
             if remain != 0:
                 found.end -= remain
-                self.list.insert(
-                    self.list.index(found) + 1,
-                    Block(found.end, found.end + remain, True))
+                self.list.insert(
+                    self.list.index(found) + 1, Block(found.end, found.end + remain, True))
             return found
         elif len(self.list) > 0 and self.list[-1].is_free:
             add = size - self.list[-1].size()
tilelang/carver/roller/hint.py

 """Hint definition for schedule"""
 from tvm import DataType
 from . import PrimFuncNode
 import numpy as np
@@ -60,7 +61,7 @@ class Stride:
             strided_elem = original_shape
         else:
             assert self.ax < len(shape)
-            strided_elem = np.prod(shape[0:self.ax + 1]) * self.stride
+            strided_elem = np.prod(shape[0 : self.ax + 1]) * self.stride
         assert strided_elem >= original_shape
         return int(strided_elem)
@@ -217,7 +218,7 @@ class Hint:
         return dic

     @classmethod
-    def from_dict(cls, dic: dict) -> 'Hint':
+    def from_dict(cls, dic: dict) -> "Hint":
         hint = cls()
         for k, v in dic.items():
             setattr(hint, k, v)
tilelang/carver/roller/node.py

 """PrimFunc Wrapper and Block information Analaysis"""
 from __future__ import annotations
 import tvm
@@ -31,7 +32,6 @@ def pre_order_traverse(block_analyzer, blocks, func):
 class BlockAnalyzer:
-
     def __init__(self, sch) -> None:
         self.sch: tir.Schedule = sch
         self.block_infos: list[BlockInfo] = normalize_prim_func(self.sch)
@@ -92,7 +92,6 @@ class Edge:
 class Node:
-
     def __init__(self, tags: dict | None = None, name: str = "Node") -> None:
         self.name = name
         if tags is None:
@@ -177,7 +176,6 @@ class Node:
 class PlaceHolderNode(Node):
-
     def __init__(self, name=""):
         super().__init__(name="PlaceHolder_" + name)
@@ -189,11 +187,7 @@ class PlaceHolderNode(Node):
 class PrimFuncNode(Node):
-
-    def __init__(self,
-                 prim_func: PrimFunc,
-                 tags: dict | None = None,
-                 name: str = "PrimFuncNode") -> None:
+    def __init__(self, prim_func: PrimFunc, tags: dict | None = None, name: str = "PrimFuncNode") -> None:
         super().__init__(tags, name=name)
         self.prim_func = self._specialize_func(prim_func)
         self.sch: tir.Schedule = tir.Schedule(self.prim_func)
@@ -227,7 +221,7 @@ class PrimFuncNode(Node):
         for dst_id, n in enumerate(inputs):
             if isinstance(n, Node):
                 n = (n, 0)
-            assert (len(n) == 2)
+            assert len(n) == 2
             src_node, src_id = n[0], n[1]
             edge = Edge(src_node, self, src_id, dst_id)
             self._in_edges.append(edge)
@@ -338,9 +332,8 @@ class PrimFuncNode(Node):
         if rstep is None:
             rstep = {}
         shape = {
-            self.block_analyzer.get_output_buffers(block)[0].name: [
-                tvm.arith.ConstIntBound(0, val - 1) for val in tile
-            ] for block in self.schedule_stages
+            self.block_analyzer.get_output_buffers(block)[0].name: [tvm.arith.ConstIntBound(0, val - 1) for val in tile]
+            for block in self.schedule_stages
         }
         return self.ana.infer(shape, rstep, targets)
@@ -356,10 +349,7 @@ class PrimFuncNode(Node):
                 results.append(shapes[arg.name])
                 continue
             # should not exceed original shape
-            trimmed_shape = [
-                self.extent_wrapper(i)
-                for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape)))]
+            trimmed_shape = [
+                self.extent_wrapper(i) for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape)))
+            ]
             results.append(trimmed_shape)
         return results
@@ -380,10 +370,8 @@ class PrimFuncNode(Node):
             propagate_shape = shapes[arg.name]
             buffer_shape = args[i].shape
             if len(buffer_shape) > len(propagate_shape):
-                buffer_shape = buffer_shape[-len(propagate_shape):]
+                buffer_shape = buffer_shape[-len(propagate_shape) :]
-            trimmed_shape = [
-                self.extent_wrapper(j)
-                for j in list(map(min, zip(propagate_shape, buffer_shape)))]
+            trimmed_shape = [
+                self.extent_wrapper(j) for j in list(map(min, zip(propagate_shape, buffer_shape)))
+            ]
             results.append(trimmed_shape)
         return results
@@ -412,10 +400,7 @@ class PrimFuncNode(Node):
     def get_reduce_inputs_dtype(self):
         if self.reduction_block is None:
             return {}
-        return {
-            b.name: tvm.DataType(b.dtype)
-            for b in self.block_analyzer.get_input_buffers(self.reduction_block)}
+        return {
+            b.name: tvm.DataType(b.dtype) for b in self.block_analyzer.get_input_buffers(self.reduction_block)
+        }

     @functools.lru_cache
     def infer_tensorcore_axis(self) -> tuple[int]:
@@ -425,8 +410,7 @@ class PrimFuncNode(Node):
         C_ax_m, C_ax_n = self.get_tag("tensorcore_config")
         wmma_m, wmma_n, wmma_k = [16, 16, 16]  # just for testing, any number is ok
-        output_buffer_shape = (
-            self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape)
+        output_buffer_shape = self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape
         valid_region = []
         for region in output_buffer_shape:
             if region.value == 1:
@@ -438,8 +422,7 @@ class PrimFuncNode(Node):
         def get_cl_shapes(c_ax_m, c_ax_n, num_nvalid_regions):
             spatial_dim = self.get_space_dim()
-            assert len(valid_region) == len(
-                spatial_dim), f"{valid_region} mismatch with {spatial_dim}"
+            assert len(valid_region) == len(spatial_dim), f"{valid_region} mismatch with {spatial_dim}"
             cl_shapes = [1] * len(spatial_dim)
             cl_shapes[c_ax_m - num_nvalid_regions] = wmma_m
             cl_shapes[c_ax_n - num_nvalid_regions] = wmma_n
@@ -467,9 +450,11 @@ class PrimFuncNode(Node):
         shapes, _ = self.propagate(shape, rstep)

         def is_broadcast_pattern(buffer, output_buffer):
-            return (buffer in self.args and
-                    len(shapes[output_buffer.name]) > len(shapes[buffer.name]) and
-                    np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name]))
+            return (
+                buffer in self.args
+                and len(shapes[output_buffer.name]) > len(shapes[buffer.name])
+                and np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name])
+            )

         def is_after_reduce_stage(block):
             if not self.reduction_block:
@@ -491,8 +476,8 @@ class PrimFuncNode(Node):
             output_buffer = self.block_analyzer.get_output_buffers(block)[0]
             for buffer in self.block_analyzer.get_input_buffers(block):
-                cache = buffer.name not in cached_tensor and (
-                    is_broadcast_pattern(buffer, output_buffer) or
-                    self.block_analyzer.get_block_info(block).is_reduction())
+                cache = buffer.name not in cached_tensor and (
+                    is_broadcast_pattern(buffer, output_buffer)
+                    or self.block_analyzer.get_block_info(block).is_reduction()
+                )
                 if not cache:
                     continue
                 cached_tensor.append(buffer.name)
@@ -500,8 +485,7 @@ class PrimFuncNode(Node):
                     continue  # cache after reduce op can often reuse buffer in reduce stage
                 if buffer.name in stride_map:
-                    num_elem = stride_map[buffer.name].compute_elements_from_shape(
-                        shapes[buffer.name])
+                    num_elem = stride_map[buffer.name].compute_elements_from_shape(shapes[buffer.name])
                 else:
                     num_elem = np.prod(shapes[buffer.name])
                 buffer_len = num_elem * int((tvm.DataType(buffer.dtype).bits + 7) // 8)
@@ -514,7 +498,6 @@ class PrimFuncNode(Node):
 class OutputNode(Node):
-
     def __init__(self, node, id=0):
         super().__init__(name="OutputNode")
         # connect node and output node
@@ -549,15 +532,16 @@ def topo_order(list_of_nodes) -> list[Node]:
                 input_ready_count[dst_node] = len(dst_node.inputs)
                 list_of_nodes.append(dst_node)
             input_ready_count[dst_node] -= 1
-            assert (input_ready_count[dst_node] >= 0)
+            assert input_ready_count[dst_node] >= 0
             if input_ready_count[dst_node] == 0:
                 ready.append(dst_node)
-    assert (len(list_of_nodes) == len(output_list))
+    assert len(list_of_nodes) == len(output_list)
     return output_list


 def find_topo_sort_priority(output_node_list) -> list[Node]:
     import sys

     sys.setrecursionlimit(10000)

     def topo_sort_get_layer(node, topo_layer):
@@ -576,9 +560,7 @@ def find_topo_sort_priority(output_node_list) -> list[Node]:
         if node in visited:
             return
         visited.add(node)
-        ordered_input_nodes = sorted([edge.src_node for edge in node.inputs],
-                                     key=lambda n: topo_layer[n],
-                                     reverse=True)
+        ordered_input_nodes = sorted([edge.src_node for edge in node.inputs], key=lambda n: topo_layer[n], reverse=True)
         for n in ordered_input_nodes:
             topo_sort_dfs(n, visited, topo_order)
         topo_order.append(node)
@@ -591,7 +573,6 @@ def find_topo_sort_priority(output_node_list) -> list[Node]:
 def find_topo_sort(output_node_list) -> list[Node]:
-
     def topo_sort_dfs(node, visited, topo_order):
         if node in visited:
             return
tilelang/carver/roller/policy/default.py

 """Policy for cuda core schedule"""
 from __future__ import annotations
 import functools
 import math
@@ -36,20 +37,14 @@ class DefaultPolicy:
         self.rasterization = NoRasterization()

     @classmethod
-    def from_prim_func(cls,
-                       func: tvm.tir.PrimFunc,
-                       arch: TileDevice,
-                       tags: dict | None = None,
-                       name: str = "PrimFuncNode"):
+    def from_prim_func(cls, func: tvm.tir.PrimFunc, arch: TileDevice, tags: dict | None = None, name: str = "PrimFuncNode"):
         return cls(arch, tags)._init_with_prim_func(func, name)

     @classmethod
     def from_output_nodes(cls, nodes: list[OutputNode], arch: TileDevice, tags: dict | None = None):
         return cls(arch, tags)._init_with_output_nodes(nodes)

-    def _init_with_prim_func(self,
-                             func: tvm.tir.PrimFunc,
-                             name: str = "PrimFuncNode") -> DefaultPolicy:
+    def _init_with_prim_func(self, func: tvm.tir.PrimFunc, name: str = "PrimFuncNode") -> DefaultPolicy:
         if func is not None and isinstance(func, tvm.tir.PrimFunc):
             self.func = func
             self.prim_func_node = PrimFuncNode(self.func, tags=self.tags, name=name)
@@ -60,9 +55,7 @@ class DefaultPolicy:
         return self

     def _init_with_output_nodes(self, output_nodes: list[OutputNode]):
-        self.ordered_nodes = list(
-            filter(lambda n: not n.is_placeholder() and not n.is_output(),
-                   find_topo_sort(output_nodes)))
+        self.ordered_nodes = list(
+            filter(lambda n: not n.is_placeholder() and not n.is_output(), find_topo_sort(output_nodes))
+        )
         for node in self.ordered_nodes:
             node.update_tags(self.tags)
@@ -102,13 +95,14 @@ class DefaultPolicy:
     def dfs_smem_tile(self, init_tile, rstep_map) -> Iterable[TileDict]:
         _steps = [get_all_factors(n) for n in self.output_nodes[0].get_space_dim()]
-        steps = [step[step.index(t):] for step, t in zip(_steps, init_tile)]
+        steps = [step[step.index(t) :] for step, t in zip(_steps, init_tile)]
         for i in range(len(steps)):
             added = list(
                 filter(
                     lambda s: s < steps[i][-1] and s > steps[i][0] and s not in steps[i],
                     [2, 4, 8, 16, 32],
-                ))
+                )
+            )
             steps[i].extend(added)
             steps[i] = sorted(steps[i])
         visited_tiles = {}
@@ -190,10 +184,7 @@ class DefaultPolicy:
         """
         tile_map = {}
         for node in self.output_nodes:
-            tile_map[node] = [
-                tile[i] * node.get_space_dim()[i] // self.output_nodes[0].get_space_dim()[i]
-                for i in range(len(tile))]
+            tile_map[node] = [
+                tile[i] * node.get_space_dim()[i] // self.output_nodes[0].get_space_dim()[i] for i in range(len(tile))
+            ]
         return tile_map

     def compute_workload_per_item(self, output_tile) -> float:
@@ -304,8 +295,7 @@ class DefaultPolicy:
             score = 0
             shape = node.propagate_inputs(tile, rstep=rstep)
             for i, input_buffer in enumerate(node.input_buffers):
-                read_transaction_elements = self.arch.transaction_size[1] // (
-                    (node.get_buffer_dtype(input_buffer).bits + 7) // 8)
+                read_transaction_elements = self.arch.transaction_size[1] // ((node.get_buffer_dtype(input_buffer).bits + 7) // 8)
                 score += sim(
                     int(coalesced_factor(shape[i], input_buffer.shape)),
                     read_transaction_elements,
@@ -380,17 +370,13 @@ class DefaultPolicy:
                 return None
             return max(candidates, key=lambda x: x[1])[0]

-        cur_rstep_id = {
-            k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis}
+        cur_rstep_id = {k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis}

         new_rstep_map = rstep_map.copy()
         while True:
             new_rstep_id = _enlarge(cur_rstep_id)
             if new_rstep_id is None:
                 break
-            new_rstep_map[node] = {
-                k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis}
+            new_rstep_map[node] = {k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis}
             old_rstep_map = td.rstep_map
             td.rstep_map = new_rstep_map
             smem_usage, _ = self._compute_shared_memory_usage(td)
@@ -434,15 +420,14 @@ class DefaultPolicy:
             if edge.src_node.is_placeholder():
                 nbytes = (edge.src_node.get_dtype().bits + 7) // 8
                 read_transaction_elements = self.arch.transaction_size[1] // nbytes
-                traffic += coalesced_tensor_shape(input_shapes[i], edge.src_node.get_shape(),
-                                                  read_transaction_elements) * nbytes
+                traffic += coalesced_tensor_shape(input_shapes[i], edge.src_node.get_shape(), read_transaction_elements) * nbytes
         for edge in node.outputs:
             if edge.dst_node.is_output():
                 nbytes = (edge.src_node.get_dtype().bits + 7) // 8
                 write_transaction_elements = self.arch.transaction_size[0] // nbytes
-                traffic += coalesced_tensor_shape(output_shapes[edge.src_id],
-                                                  node.get_shape(edge.src_id),
-                                                  write_transaction_elements) * nbytes
+                traffic += (
+                    coalesced_tensor_shape(output_shapes[edge.src_id], node.get_shape(edge.src_id), write_transaction_elements)
+                    * nbytes
+                )
         return traffic, op_tile_map
@@ -487,10 +472,7 @@ class DefaultPolicy:
         cached_tensors_map = {}

         def can_free(node, out_id):
-            for edge in node.outputs:
-                if edge.src_id == out_id and edge.dst_node not in processed:
-                    return False
-            return True
+            return all(not (edge.src_id == out_id and edge.dst_node not in processed) for edge in node.outputs)

         for node in self.ordered_nodes:
             node_internal_bytes, cached_tensors_map[node] = self.infer_node_smem_usage(td, node)
@@ -528,9 +510,7 @@ class DefaultPolicy:
         Tuple[Dict, Dict]
             A tuple of dictionaries containing the output strides and tensor strides.
         """
-        output_strides = {
-            int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)}
+        output_strides = {
+            int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)
+        }
         tensor_strides = {}
         return output_strides, tensor_strides
@@ -551,8 +531,7 @@ class DefaultPolicy:
         output_strides_map = {}
         tensor_strides_map = {}
         for node in self.ordered_nodes:
-            output_strides_map[node], tensor_strides_map[node] = self.compute_node_stride_map(
-                node, td)
+            output_strides_map[node], tensor_strides_map[node] = self.compute_node_stride_map(node, td)
         td.output_strides_map, td.tensor_strides_map = output_strides_map, tensor_strides_map

     def compute_tile_dict(self, output_tile: list[int], rstep_map) -> TileDict:
@@ -582,9 +561,7 @@ class DefaultPolicy:
         output_shape = self.output_nodes[0].get_space_dim()
         td.grid_size = int(np.prod([(y + x - 1) // x for x, y in zip(output_tile, output_shape)]))
         # estimated reg usage
-        reg_usage = int(2 * max([
-            np.prod(td.get_tile(node)) * node.get_dtype().bits / 32 for node in self.ordered_nodes]))
+        reg_usage = int(2 * max([np.prod(td.get_tile(node)) * node.get_dtype().bits / 32 for node in self.ordered_nodes]))
         if reg_usage > self.arch.reg_cap:
             td.valid = False
             return td
@@ -609,13 +586,10 @@ class DefaultPolicy:
         for node in self.ordered_nodes:
             if np.prod(td.get_tile(node)) == 0:
                 return False
-            node_grid_size = np.prod([
-                (y + x - 1) // x for x, y in zip(td.get_tile(node), node.get_space_dim())])
+            node_grid_size = np.prod([(y + x - 1) // x for x, y in zip(td.get_tile(node), node.get_space_dim())])
             if node_grid_size != td.grid_size:
                 return False
-            if (hasattr(node, "reduce_op") and node.reduce_op is not None and
-                    len(node.reduce_op.axis) == len(td.output_tile)):
+            if hasattr(node, "reduce_op") and node.reduce_op is not None and len(node.reduce_op.axis) == len(td.output_tile):
                 for i, tile_extent in enumerate(td.output_tile):
                     if node.reduce_op.axis[i].dom.extent % tile_extent:
                         return False
@@ -639,23 +613,22 @@ class DefaultPolicy:
         node_space_sizes = [int(np.prod(td.get_tile(node))) for node in self.ordered_nodes]
         max_block_size = functools.reduce(math.gcd, node_space_sizes)

         if max_block_size < self.arch.warp_size * self.arch.sm_partition and max_block_size == min(
                 node_space_sizes):
             node_reduce_sizes = [int(np.prod(list(td.get_rstep(node).values()))) for node in self.ordered_nodes]
             total_sizes = [x * y for x, y in zip(node_space_sizes, node_reduce_sizes)]
             max_possible_size = functools.reduce(math.gcd, total_sizes)
             possible_block_sizes = list(
                 filter(
                     lambda x: x % max_block_size == 0 and x <= 1024,
                     get_all_factors(max_possible_size),
-                ))
+                )
+            )
             possible_block_sizes = list(
                 filter(  # either be a factor of space or cover fully cover the space
                     lambda x: all([x % s == 0 or s % x == 0 for s in node_space_sizes]),
                     possible_block_sizes,
-                ))
+                )
+            )
             factor_ordered = sorted(possible_block_sizes, key=self.score_block_size)
             return factor_ordered
         else:
@@ -821,8 +794,7 @@ class DefaultPolicy:
         vectorize_result = {}
         for tensor, shape in shapes.items():
             for v in vectorize_sizes:
-                if (is_shape_aligned(shape, block_size * v) and is_cont(shape, v) and
-                        is_type_allowed(dtypes[tensor], v)):
+                if is_shape_aligned(shape, block_size * v) and is_cont(shape, v) and is_type_allowed(dtypes[tensor], v):
                     vectorize_result[tensor] = v
                     break
         return vectorize_result
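The can_free hunk in default.py above folds a reject-loop into all(not ...). By De Morgan's law, all(not P(e) for e in edges) is the same as not any(P(e) for e in edges), so the two forms below agree on every input; the NamedTuple Edge stand-in is invented for this sketch, not the roller graph's real type:

from typing import NamedTuple


class Edge(NamedTuple):  # stand-in for the roller graph edge, invented for this sketch
    src_id: int
    dst_node: str


outputs = [Edge(0, "consumer_a"), Edge(1, "consumer_b")]
processed = {"consumer_a"}


def can_free_old(out_id: int) -> bool:
    # pre-commit form: loop that rejects on the first unprocessed consumer
    for edge in outputs:
        if edge.src_id == out_id and edge.dst_node not in processed:
            return False
    return True


def can_free_new(out_id: int) -> bool:
    # post-commit form: all(not ...) over a generator, i.e. not any(...)
    return all(not (edge.src_id == out_id and edge.dst_node not in processed) for edge in outputs)


for out_id in (0, 1):
    assert can_free_old(out_id) == can_free_new(out_id)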
tilelang/carver/roller/policy/tensorcore.py
View file @
29051439
"""Policy for tensorcore schedule"""
"""Policy for tensorcore schedule"""
from
__future__
import
annotations
from
__future__
import
annotations
import
tvm
import
tvm
import
numpy
as
np
import
numpy
as
np
...
@@ -13,7 +14,6 @@ logger = logging.getLogger(__name__)
...
@@ -13,7 +14,6 @@ logger = logging.getLogger(__name__)
class
TensorCorePolicy
(
DefaultPolicy
):
class
TensorCorePolicy
(
DefaultPolicy
):
# this is the trick for wmma.
# this is the trick for wmma.
# However, for int8 mma, the wmma_k should be 32.
# However, for int8 mma, the wmma_k should be 32.
wmma_k
:
int
=
16
wmma_k
:
int
=
16
...
@@ -70,9 +70,9 @@ class TensorCorePolicy(DefaultPolicy):
...
@@ -70,9 +70,9 @@ class TensorCorePolicy(DefaultPolicy):
A_high_ax
=
min
(
A_ax_m
,
A_ax_k
)
A_high_ax
=
min
(
A_ax_m
,
A_ax_k
)
B_high_ax
=
min
(
B_ax_n
,
B_ax_k
)
B_high_ax
=
min
(
B_ax_n
,
B_ax_k
)
C_high_ax
=
min
(
C_ax_m
,
C_ax_n
)
C_high_ax
=
min
(
C_ax_m
,
C_ax_n
)
A_stride
=
Stride
(
stride
=
np
.
prod
(
AS_shape
[
A_high_ax
+
1
:])
+
offset
,
ax
=
A_high_ax
)
A_stride
=
Stride
(
stride
=
np
.
prod
(
AS_shape
[
A_high_ax
+
1
:])
+
offset
,
ax
=
A_high_ax
)
B_stride
=
Stride
(
stride
=
np
.
prod
(
BS_shape
[
B_high_ax
+
1
:])
+
offset
,
ax
=
B_high_ax
)
B_stride
=
Stride
(
stride
=
np
.
prod
(
BS_shape
[
B_high_ax
+
1
:])
+
offset
,
ax
=
B_high_ax
)
C_stride
=
Stride
(
stride
=
np
.
prod
(
CS_shape
[
C_high_ax
+
1
:])
+
offset
,
ax
=
C_high_ax
)
C_stride
=
Stride
(
stride
=
np
.
prod
(
CS_shape
[
C_high_ax
+
1
:])
+
offset
,
ax
=
C_high_ax
)
return
A_stride
,
B_stride
,
C_stride
return
A_stride
,
B_stride
,
C_stride
def
infer_node_smem_usage
(
self
,
td
:
TileDict
,
node
:
PrimFuncNode
):
def
infer_node_smem_usage
(
self
,
td
:
TileDict
,
node
:
PrimFuncNode
):
...
@@ -86,8 +86,7 @@ class TensorCorePolicy(DefaultPolicy):
...
@@ -86,8 +86,7 @@ class TensorCorePolicy(DefaultPolicy):
# get reduce input size
# get reduce input size
target_transaction
=
self
.
arch
.
transaction_size
[
0
]
*
2
target_transaction
=
self
.
arch
.
transaction_size
[
0
]
*
2
# 512 bytes // type bits
# 512 bytes // type bits
reduce_input_dtype
=
node
.
get_buffer_dtype
(
reduce_input_dtype
=
node
.
get_buffer_dtype
(
node
.
block_analyzer
.
get_input_buffers
(
node
.
reduction_block
)[
0
])
node
.
block_analyzer
.
get_input_buffers
(
node
.
reduction_block
)[
0
])
basic
=
(
target_transaction
*
8
)
//
reduce_input_dtype
.
bits
basic
=
(
target_transaction
*
8
)
//
reduce_input_dtype
.
bits
result
=
{}
result
=
{}
...
@@ -95,7 +94,7 @@ class TensorCorePolicy(DefaultPolicy):
...
@@ -95,7 +94,7 @@ class TensorCorePolicy(DefaultPolicy):
iter_name
=
iter_info
.
var
.
name
iter_name
=
iter_info
.
var
.
name
iter_dom
=
iter_info
.
dom
.
extent
iter_dom
=
iter_info
.
dom
.
extent
if
iter_dom
%
16
>
0
:
if
iter_dom
%
16
>
0
:
result
[
iter_name
]
=
(
16
if
iter_dom
<
basic
else
basic
)
# for the case of padding
result
[
iter_name
]
=
16
if
iter_dom
<
basic
else
basic
# for the case of padding
elif
iter_dom
%
basic
==
0
:
elif
iter_dom
%
basic
==
0
:
result
[
iter_name
]
=
basic
result
[
iter_name
]
=
basic
else
:
else
:
...
@@ -114,7 +113,6 @@ class TensorCorePolicy(DefaultPolicy):
...
@@ -114,7 +113,6 @@ class TensorCorePolicy(DefaultPolicy):
return
False
return
False
if
_check_small_tile
(
td
):
if
_check_small_tile
(
td
):
smem_limit
=
min
(
self
.
arch
.
max_smem_usage
//
td
.
block_per_SM
,
self
.
arch
.
smem_cap
)
smem_limit
=
min
(
self
.
arch
.
max_smem_usage
//
td
.
block_per_SM
,
self
.
arch
.
smem_cap
)
rstep_map
=
td
.
rstep_map
.
copy
()
rstep_map
=
td
.
rstep_map
.
copy
()
...
@@ -127,13 +125,10 @@ class TensorCorePolicy(DefaultPolicy):
...
@@ -127,13 +125,10 @@ class TensorCorePolicy(DefaultPolicy):
return
rstep
return
rstep
def
_shared_memory_usage
(
td
:
TileDict
):
def
_shared_memory_usage
(
td
:
TileDict
):
return
node
.
footprint
(
td
.
output_tile
,
new_rstep_map
,
return
node
.
footprint
(
td
.
output_tile
,
new_rstep_map
,
td
.
tensor_strides_map
[
node
])
td
.
tensor_strides_map
[
node
])
def
_score
(
rstep_id
):
def
_score
(
rstep_id
):
rstep
=
{
rstep
=
{
k
.
var
.
name
:
all_steps
[
k
.
var
.
name
][
rstep_id
[
k
.
var
.
name
]]
for
k
in
node
.
raxis
}
k
.
var
.
name
:
all_steps
[
k
.
var
.
name
][
rstep_id
[
k
.
var
.
name
]]
for
k
in
node
.
raxis
}
score
=
0
score
=
0
shape
=
node
.
propagate_inputs_on_reduction
(
td
.
get_tile
(
node
),
rstep
=
rstep
)
shape
=
node
.
propagate_inputs_on_reduction
(
td
.
get_tile
(
node
),
rstep
=
rstep
)
input_buffers
=
node
.
block_analyzer
.
        input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block)
...
@@ -153,18 +148,13 @@ class TensorCorePolicy(DefaultPolicy):
                return None
            return max(candidates, key=lambda x: x[1])[0]

        cur_rstep_id = {k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis}
        new_rstep_map = rstep_map.copy()
        while True:
            new_rstep_id = _enlarge(cur_rstep_id)
            if new_rstep_id is None:
                break
            new_rstep_map = {k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis}
            old_rstep_map = td.rstep_map
            td.rstep_map = new_rstep_map
            smem_usage, _ = _shared_memory_usage(td)
...
@@ -173,9 +163,7 @@ class TensorCorePolicy(DefaultPolicy):
                break
            else:
                cur_rstep_id = new_rstep_id
        rstep = {k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis}
        return rstep

        for node in self.ordered_nodes:
...
@@ -206,11 +194,7 @@ class TensorCorePolicy(DefaultPolicy):
            return super().get_node_reduce_step_candidates(node)
        else:
            # must be a a multiple of wmma_k
            return {
                k.var.name: [x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k)]
                for k in node.raxis
            }

    def check_tile_shape_isvalid(self, td: TileDict):
        for node in self.ordered_nodes:
...
@@ -221,10 +205,7 @@ class TensorCorePolicy(DefaultPolicy):
                    td.tile_map[node][ax_n],
                )
            )
            # check the tile size is valid
-            wmma_invalid = [
-                block_m < wmma_m or block_n < wmma_n
-                for wmma_m, wmma_n in self.arch.get_avaliable_tensorintrin_shapes()]
+            wmma_invalid = [
+                block_m < wmma_m or block_n < wmma_n
+                for wmma_m, wmma_n in self.arch.get_avaliable_tensorintrin_shapes()
+            ]
            if all(wmma_invalid):
                return False
            if any([y % x for x, y in zip(td.tile_map[node], node.get_space_dim())]):
...
@@ -242,13 +223,10 @@ class TensorCorePolicy(DefaultPolicy):
            return super().compute_node_stride_map(node, td)
        use_layout = self._can_implement_layout(node, td)

        AS_stride, BS_stride, C_stride = self._compute_tc_strides(node, td.get_tile(node),
                                                                  td.get_rstep(node))
        A_stride, B_stride, _ = self._compute_tc_strides(node, td.get_tile(node))
        tensor_strides = {}
-        output_strides = {
-            int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)}
+        output_strides = {
+            int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)
+        }
        tensor_strides = {}
        # when connected to shared input, should use full stride without rstep
        for i, (_, _) in enumerate(zip([AS_stride, BS_stride], [A_stride, B_stride])):
...
@@ -347,8 +325,7 @@ class TensorCorePolicy(DefaultPolicy):
            overall_gmem_size_in_bytes: int = 0
            for node in self.ordered_nodes:
                for buffer in node.input_buffers:
-                    overall_gmem_size_in_bytes += (
-                        int(np.prod(buffer.shape)) * tvm.DataType(buffer.dtype).bits // 8
-                    )
+                    overall_gmem_size_in_bytes += int(np.prod(buffer.shape)) * tvm.DataType(buffer.dtype).bits // 8
            return overall_gmem_size_in_bytes < self.arch.l2_cache_size_bytes

        conditions.append(_check_memory_size())
...
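The pattern repeated through these hunks is the formatter switch itself: yapf leaves a closing bracket on the last content line, while ruff format (black-style) drops it onto its own line, or joins a wrapped expression when it fits the line limit. A minimal side-by-side sketch using the output_strides dict from the hunk above:

    # yapf (old style): the closing brace hugs the last line of the comprehension
    output_strides = {
        int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)}

    # ruff format (new style): the closing brace dedents onto its own line
    output_strides = {
        int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)
    }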
tilelang/carver/roller/rasterization.py
View file @ 29051439
...
@@ -2,7 +2,6 @@

class Rasterization:
    panel_width_ = None

    def __init__(self) -> None:
...
@@ -18,7 +17,6 @@ class Rasterization:

class NoRasterization(Rasterization):
    def __init__(self) -> None:
        super().__init__()
...
tilelang/carver/roller/shape_inference/common.py
View file @ 29051439
...
@@ -4,9 +4,7 @@ from tvm import arith

class Statement:
    def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict, range_map: OrderedDict):
        self.output = output
        self.dependent_region = dependent_region
        self.var_map = var_map
...
@@ -18,7 +16,6 @@ def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound):

class InputShapeInference:
    def __init__(self, deps: list[Statement]):
        self.deps = deps
...
tilelang/carver/roller/shape_inference/tir.py
View file @ 29051439
...
@@ -5,7 +5,6 @@ from tvm import arith, tir

class Statement:
    def __init__(self, block_analyzer, block: BlockRV):
        self.block_analyzer = block_analyzer
        self.block = block
...
@@ -21,9 +20,7 @@ class Statement:
        if len(self.dependent_region[input_name]) != 1:
            return None
        indices = self.dependent_region[input_name][0]
-        iter_map_range = {
-            _iter.var: _iter.dom for _iter in self.block_analyzer.get_spatial_axis(self.block)}
+        iter_map_range = {
+            _iter.var: _iter.dom for _iter in self.block_analyzer.get_spatial_axis(self.block)
+        }
        iter_map_result = arith.detect_iter_map(
            indices,
            iter_map_range,
...
@@ -77,7 +74,6 @@ class TensorDepNode:

class DependencyAnalysis:
    def __init__(self, deps):
        self.deps = deps
        # issue: duplicate name when we have two same ops.
...
@@ -112,8 +108,7 @@ class DependencyAnalysis:
    def traverse_dependencies(self, compute):
        if isinstance(compute, Statement):
            node = self.get_or_create_node(
                compute.block_analyzer.get_output_buffers(compute.block)[0].name)
            # Loop through input tensors
            for input_buffer in compute.block_analyzer.get_input_buffers(compute.block):
                # Get the input node
...
@@ -167,7 +162,6 @@ class DependencyAnalysis:

class InputShapeInference:
    def __init__(self, deps: list[Statement]):
        self.deps = deps
        self.target_mapping = {}
...
@@ -183,16 +177,11 @@ class InputShapeInference:
        if targets in self.target_mapping:
            return self.target_mapping[targets]
        # should be buffer name instead of block name
        name2dep = {
            dep.block_analyzer.get_output_buffers(dep.block)[0].name: dep for dep in self.deps
        }
        mapping = {}
        input_vars = []
        for target in targets:
-            vars = [
-                iter.var for iter in name2dep[target].block_analyzer.get_spatial_axis(name2dep[target].block)]
+            vars = [
+                iter.var for iter in name2dep[target].block_analyzer.get_spatial_axis(name2dep[target].block)
+            ]
            input_vars.append(vars)
            mapping[target] = [vars]
        ana = arith.Analyzer()
...
@@ -221,13 +210,8 @@ class InputShapeInference:
            mapping[input_name] = []
            for indices in indices_list:
                for region in regions:
-                    vmap = {
-                        k: (tir.Cast(k.dtype, v) if v.dtype != k.dtype else v) for k, v in zip(ax_vars, indices)}
+                    vmap = {
+                        k: (tir.Cast(k.dtype, v) if v.dtype != k.dtype else v)
+                        for k, v in zip(ax_vars, indices)
+                    }
                    region = [ana.simplify(tir.stmt_functor.substitute(ax, vmap)) for ax in region]
                    if not region_exist_in_list(region, mapping[input_name]):
                        mapping[input_name].append(region)
        buffers = []
...
@@ -241,10 +225,7 @@ class InputShapeInference:
        self.target_mapping[targets] = input_vars, mapping
        return input_vars, mapping

    def infer(self,
              shape: dict[str, list[arith.ConstIntBound]], rstep: dict[str, int] = None, targets=None):
        if rstep is None:
            rstep = {}
        compute_targets = tuple(shape.keys())
...
@@ -258,8 +239,7 @@ class InputShapeInference:
        for ax in self.reduce_axes:
            # assume the dom.min is always 0, maybe we can extend the IterInfo to include the min value.
            if ax.var.name in rstep:
                bound = arith.ConstIntBound(
                    int(ax.dom.min), int(ax.dom.min + min(ax.dom.extent, rstep[ax.var.name]) - 1))
            else:
                bound = arith.ConstIntBound(int(ax.dom.min), int(ax.dom.min + ax.dom.extent - 1))
            ana.update(ax.var, bound, True)
...
@@ -312,14 +292,11 @@ class InputShapeInference:
        for name, regions in mapping.items():
            region = regions[0]
            result[name] = [
                ana.simplify(tir.stmt_functor.substitute(index, vmap)) for index in region
            ]
        return result

def region_exist_in_list(a, list) -> bool:
    def expr_is_same(a, b) -> bool:
        if isinstance(a, tir.IntImm) and isinstance(b, tir.IntImm):
            return a.value == b.value
...
tilelang/carver/template/base.py
View file @ 29051439
...
@@ -2,7 +2,12 @@
from abc import ABC, abstractmethod  # For defining abstract base classes
from dataclasses import dataclass, field  # For defining data classes
-from ..arch import (  # Import architecture-related utilities and classes
-    TileDevice, is_volta_arch, is_ampere_arch, is_cdna_arch, auto_infer_current_arch)
+from ..arch import (  # Import architecture-related utilities and classes
+    TileDevice,
+    is_volta_arch,
+    is_ampere_arch,
+    is_cdna_arch,
+    auto_infer_current_arch,
+)
from ..roller.hint import Hint  # Import the Hint class
from ..roller.node import OutputNode  # Import the OutputNode class
from tvm.tir import PrimFunc  # Import PrimFunc for handling tensor IR functions
...
@@ -41,7 +46,7 @@ class BaseTemplate(ABC):
        """
        pass

-    def with_arch(self, arch: TileDevice) -> 'BaseTemplate':
+    def with_arch(self, arch: TileDevice) -> "BaseTemplate":
        """
        Sets the architecture for this template and returns itself.
...
@@ -109,7 +114,7 @@ class BaseTemplate(ABC):
        """
        raise NotImplementedError("initialize_function is not implemented")

-    def set_function(self, func: PrimFunc) -> 'BaseTemplate':
+    def set_function(self, func: PrimFunc) -> "BaseTemplate":
        """
        Sets the function for this template and returns itself.
...
@@ -122,7 +127,7 @@ class BaseTemplate(ABC):
        self._func = func
        return self

-    def set_output_nodes(self, output_nodes: list[OutputNode]) -> 'BaseTemplate':
+    def set_output_nodes(self, output_nodes: list[OutputNode]) -> "BaseTemplate":
        """
        Sets the output nodes for this template and returns itself.
...
...
tilelang/carver/template/conv.py
View file @
29051439
...
@@ -28,6 +28,7 @@ class ConvTemplate(BaseTemplate):
...
@@ -28,6 +28,7 @@ class ConvTemplate(BaseTemplate):
accum_dtype (str): Data type used for accumulation.
accum_dtype (str): Data type used for accumulation.
with_bias (bool): Whether to add a bias term.
with_bias (bool): Whether to add a bias term.
"""
"""
# Operation-related configuration parameters
# Operation-related configuration parameters
N
:
int
# The number of input samples processed simultaneously in a batch.
N
:
int
# The number of input samples processed simultaneously in a batch.
C
:
int
# The number of input feature maps.
C
:
int
# The number of input feature maps.
...
@@ -69,12 +70,18 @@ class ConvTemplate(BaseTemplate):
...
@@ -69,12 +70,18 @@ class ConvTemplate(BaseTemplate):
AssertionError: If N, C, H, W, F, K, S, D, P are not positive integers.
AssertionError: If N, C, H, W, F, K, S, D, P are not positive integers.
"""
"""
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
=
self
.
N
,
self
.
C
,
self
.
H
,
self
.
W
,
self
.
F
,
self
.
K
,
self
.
S
,
self
.
D
,
self
.
P
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
=
self
.
N
,
self
.
C
,
self
.
H
,
self
.
W
,
self
.
F
,
self
.
K
,
self
.
S
,
self
.
D
,
self
.
P
assert
(
isinstance
(
N
,
int
)
and
isinstance
(
C
,
int
)
and
isinstance
(
H
,
int
)
and
assert
(
isinstance
(
W
,
int
)
and
isinstance
(
F
,
int
)
and
isinstance
(
K
,
int
)
and
isinstance
(
N
,
int
)
isinstance
(
S
,
int
)
and
isinstance
(
D
,
int
)
and
and
isinstance
(
C
,
int
)
isinstance
(
P
,
int
)),
"Only Support Integer Params"
and
isinstance
(
H
,
int
)
assert
(
N
>
0
and
C
>
0
and
H
>
0
and
W
>
0
and
F
>
0
and
K
>
0
and
S
>
0
and
D
>
0
and
and
isinstance
(
W
,
int
)
P
>
0
),
"Params should be positive"
and
isinstance
(
F
,
int
)
and
isinstance
(
K
,
int
)
and
isinstance
(
S
,
int
)
and
isinstance
(
D
,
int
)
and
isinstance
(
P
,
int
)
),
"Only Support Integer Params"
assert
N
>
0
and
C
>
0
and
H
>
0
and
W
>
0
and
F
>
0
and
K
>
0
and
S
>
0
and
D
>
0
and
P
>
0
,
"Params should be positive"
# Load configuration parameters
# Load configuration parameters
in_dtype
,
out_dtype
,
accum_dtype
=
self
.
in_dtype
,
self
.
out_dtype
,
self
.
accum_dtype
in_dtype
,
out_dtype
,
accum_dtype
=
self
.
in_dtype
,
self
.
out_dtype
,
self
.
accum_dtype
...
@@ -123,8 +130,10 @@ class ConvTemplate(BaseTemplate):
...
@@ -123,8 +130,10 @@ class ConvTemplate(BaseTemplate):
te
.
if_then_else
(
te
.
if_then_else
(
te
.
all
(
h_in
>=
0
,
h_in
<
H
,
w_in
>=
0
,
w_in
<
W
),
te
.
all
(
h_in
>=
0
,
h_in
<
H
,
w_in
>=
0
,
w_in
<
W
),
A
[
n
,
h_in
,
w_in
,
c
].
astype
(
accum_dtype
)
*
B
[
kh
,
kw
,
c
,
f
].
astype
(
accum_dtype
),
A
[
n
,
h_in
,
w_in
,
c
].
astype
(
accum_dtype
)
*
B
[
kh
,
kw
,
c
,
f
].
astype
(
accum_dtype
),
tir
.
const
(
0
,
accum_dtype
)),
tir
.
const
(
0
,
accum_dtype
),
axis
=
[
kh
,
kw
,
c
])
),
axis
=
[
kh
,
kw
,
c
],
)
# Compute convolution result
# Compute convolution result
C
=
te
.
compute
(
C
=
te
.
compute
(
...
...
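For context, the expression being re-wrapped above is the padded-convolution reduction body. A standalone sketch of how such a te.compute is assembled; the sizes and the h_in/w_in stride/dilation/padding arithmetic are illustrative assumptions, not copied from the template:

    from tvm import te, tir

    # Assumed example sizes; ConvTemplate derives these from its N, C, H, W, F, K, S, D, P fields.
    N, C, H, W, F, K, S, D, P = 1, 16, 32, 32, 32, 3, 1, 1, 1
    KH = KW = K
    OH = (H + 2 * P - D * (K - 1) - 1) // S + 1  # standard output-size arithmetic (assumption)
    OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
    accum_dtype = "float32"

    A = te.placeholder((N, H, W, C), name="A", dtype="float16")    # NHWC input
    B = te.placeholder((KH, KW, C, F), name="B", dtype="float16")  # HWIO weight

    kh = te.reduce_axis((0, KH), name="kh")
    kw = te.reduce_axis((0, KW), name="kw")
    c = te.reduce_axis((0, C), name="c")

    def fcompute(n, h, w, f):
        # Map an output coordinate back to a (possibly padded) input coordinate.
        h_in = h * S + kh * D - P
        w_in = w * S + kw * D - P
        # Out-of-bounds taps contribute zero, exactly as in the guarded expression above.
        return te.sum(
            te.if_then_else(
                te.all(h_in >= 0, h_in < H, w_in >= 0, w_in < W),
                A[n, h_in, w_in, c].astype(accum_dtype) * B[kh, kw, c, f].astype(accum_dtype),
                tir.const(0, accum_dtype),
            ),
            axis=[kh, kw, c],
        )

    out = te.compute((N, OH, OW, F), fcompute, name="C")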
tilelang/carver/template/flashattention.py
View file @ 29051439
...
@@ -9,7 +9,6 @@ from ..utils import get_roller_hints_from_output_nodes, get_tensorized_func_and_

@dataclass
class FlashAttentionTemplate(BaseTemplate):
    _output_nodes: list[OutputNode] = None

    # Operation-related configuration parameters
...
@@ -91,10 +90,7 @@ class FlashAttentionTemplate(BaseTemplate):
            """
            A_indices = [b, i, k]
            B_indices = [b, j, k]
            return te.sum(
                A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k)

        # Compute matrix multiplication result
        C = te.compute(
...
tilelang/carver/template/gemv.py
View file @ 29051439
...
@@ -50,9 +50,8 @@ class GEMVTemplate(BaseTemplate):
        N, K = self.N, self.K
        # Ensure M, N, K are valid positive integers
-        assert (isinstance(M, int) and isinstance(N, int) and
-                isinstance(K, int)), "Only Support Integer M, N, K"
-        assert (M > 0 and N > 0 and K > 0), "M, N, K should be positive"
+        assert isinstance(M, int) and isinstance(N, int) and isinstance(K, int), "Only Support Integer M, N, K"
+        assert M > 0 and N > 0 and K > 0, "M, N, K should be positive"
        # Load configuration parameters
        trans_B = self.trans_B
...
@@ -86,9 +85,7 @@ class GEMVTemplate(BaseTemplate):
            """
            A_indices = [i, k]
            B_indices = [k, j] if not trans_B else [j, k]
            return te.sum(
                A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k)

        # Compute matrix multiplication result
        C = te.compute(
...
tilelang/carver/template/general_reduce.py
View file @ 29051439
...
@@ -9,15 +9,13 @@ from ..utils import get_roller_hints_from_func

@dataclass
class GeneralReductionTemplate(BaseTemplate):
    # OP Related Config
    structure: str | list[str] = None
    shape: list[int] = None
    dtype: str = "float16"

    def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
        roller_hints = get_roller_hints_from_func(
            self._func, arch=arch, topk=topk, allow_gemv=False)
        return roller_hints

    def initialize_function(self) -> None:
...
@@ -38,9 +36,9 @@ class GeneralReductionTemplate(BaseTemplate):
        spatial_axes = []
        reduce_axes = []
        for i, axis_type in enumerate(self.structure):
-            if axis_type.upper() == 'S':
+            if axis_type.upper() == "S":
                spatial_axes.append((i, self.shape[i]))
-            elif axis_type.upper() == 'R':
+            elif axis_type.upper() == "R":
                reduce_axes.append((i, self.shape[i]))
            else:
                raise ValueError(f"Unrecognized axis type '{axis_type}', only 'S'/'R' allowed.")
...
@@ -90,7 +88,7 @@ class GeneralReductionTemplate(BaseTemplate):
            # Walk through the structure in order
            for axis_type in self.structure:
-                if axis_type.upper() == 'S':
+                if axis_type.upper() == "S":
                    # use the next spatial_indices item
                    full_index.append(spatial_indices[spatial_iter])
                    spatial_iter += 1
...
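The structure string parsed above maps each character to an axis kind, "S" for spatial and "R" for reduction. A minimal standalone re-statement of that loop, with example values assumed:

    # "SSR" marks the first two axes spatial and the last one a reduction axis.
    structure, shape = "SSR", [128, 64, 32]

    spatial_axes, reduce_axes = [], []
    for i, axis_type in enumerate(structure):
        if axis_type.upper() == "S":
            spatial_axes.append((i, shape[i]))
        elif axis_type.upper() == "R":
            reduce_axes.append((i, shape[i]))
        else:
            raise ValueError(f"Unrecognized axis type '{axis_type}', only 'S'/'R' allowed.")

    print(spatial_axes)  # [(0, 128), (1, 64)]
    print(reduce_axes)   # [(2, 32)]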
tilelang/carver/template/matmul.py
View file @ 29051439
...
@@ -65,9 +65,8 @@ class MatmulTemplate(BaseTemplate):
        M, N, K = self.M, self.N, self.K
        # Ensure M, N, K are valid positive integers
-        assert (isinstance(M, int) and isinstance(N, int) and
-                isinstance(K, int)), "Only Support Integer M, N, K"
-        assert (M > 0 and N > 0 and K > 0), "M, N, K should be positive"
+        assert isinstance(M, int) and isinstance(N, int) and isinstance(K, int), "Only Support Integer M, N, K"
+        assert M > 0 and N > 0 and K > 0, "M, N, K should be positive"
        # Load configuration parameters
        trans_A, trans_B = self.trans_A, self.trans_B
...
@@ -101,9 +100,7 @@ class MatmulTemplate(BaseTemplate):
            """
            A_indices = [i, k] if not trans_A else [k, i]  # Adjust indexing if A is transposed
            B_indices = [k, j] if not trans_B else [j, k]  # Adjust indexing if B is transposed
            return te.sum(
                A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k)

        # Compute matrix multiplication result
        C = te.compute(
...
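The reduction body above is a transpose-aware GEMM inner product: the index tuples flip when an operand is stored transposed. A runnable sketch with assumed shapes and layout flags:

    from tvm import te

    # Assumed example values; MatmulTemplate reads these from M, N, K, trans_A, trans_B.
    M, N, K = 1024, 1024, 512
    trans_A, trans_B = False, True
    accum_dtype = "float32"

    A = te.placeholder((K, M) if trans_A else (M, K), name="A", dtype="float16")
    B = te.placeholder((N, K) if trans_B else (K, N), name="B", dtype="float16")
    k = te.reduce_axis((0, K), name="k")

    def fcompute(i, j):
        A_indices = [i, k] if not trans_A else [k, i]  # adjust indexing if A is transposed
        B_indices = [k, j] if not trans_B else [j, k]  # adjust indexing if B is transposed
        return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k)

    C = te.compute((M, N), fcompute, name="C")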
tilelang/carver/utils.py
View file @ 29051439
...
@@ -26,11 +26,9 @@ def get_rasterization_code(pannel_width: int = 8) -> str:
    """

-def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
-                               arch: TileDevice,
-                               topk: int = 10,
-                               tensorcore_only: bool = False,
-                               allow_gemv: bool = False) -> list[Hint] | None:
+def get_roller_hints_from_func(
+    func_or_module: tir.PrimFunc | IRModule,
+    arch: TileDevice,
+    topk: int = 10,
+    tensorcore_only: bool = False,
+    allow_gemv: bool = False,
+) -> list[Hint] | None:
    func = None
    if isinstance(func_or_module, tir.PrimFunc):
        func = func_or_module
...
@@ -44,8 +42,7 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
    roller_hints = None
    if tensorcore_only:
        try:
            tensorized_func, tags = get_tensorized_func_and_tags(
                func, arch.target, allow_gemv=allow_gemv)
        except Exception as e_msg:
            logger.debug("Get tensorized func and tags failed: ", e_msg)
            tags = None
...
@@ -58,8 +55,7 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
        policy = DefaultPolicy.from_prim_func(func=func, arch=arch)
        tensorized_func = None
        try:
            tensorized_func, tags = get_tensorized_func_and_tags(
                func, arch.target, allow_gemv=allow_gemv)
        except Exception as e_msg:
            logger.debug("Get tensorized func and tags failed: ", e_msg)
            tags = None
...
@@ -69,10 +65,9 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
    return roller_hints

-def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode],
-                                       arch: TileDevice,
-                                       topk: int = 10,
-                                       extra_tags: list[str] | None = None) -> list[Hint] | None:
+def get_roller_hints_from_output_nodes(
+    output_nodes: list[OutputNode],
+    arch: TileDevice,
+    topk: int = 10,
+    extra_tags: list[str] | None = None,
+) -> list[Hint] | None:
    assert isinstance(output_nodes, list), "The input should be a list of functions."
    lints = []
...
@@ -80,8 +75,7 @@ def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode],
        policy = TensorCorePolicy.from_output_nodes(output_nodes, arch=arch, tags=None)
        lints = policy.emit_config(topk)
    except Exception as e_msg:
        logger.debug(f"Generate hints from output nodes failed: {e_msg}",
                     "fallback to default policy")
    if len(lints) == 0:
        policy = DefaultPolicy.from_output_nodes(output_nodes, arch=arch, tags=None)
...
@@ -92,7 +86,6 @@ def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode],

def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc:
    if not isinstance(ir_module, IRModule):
        raise ValueError("Not supported type: ", type(ir_module))
    assert len(ir_module.get_global_vars()) == 1, (
        "The optimized module should only have one global variable for default schedule."
    )
    func = list(ir_module.functions.values())[0]
    return func
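The ruff-formatted signature keeps one parameter per line, which reads naturally with keyword arguments at the call site. A hedged usage sketch; import paths are inferred from the file paths above, auto_infer_current_arch is assumed to be a zero-argument helper, and my_prim_func stands in for any tir.PrimFunc you already have:

    from tilelang.carver.utils import get_roller_hints_from_func
    from tilelang.carver.arch import auto_infer_current_arch  # assumed location

    hints = get_roller_hints_from_func(
        my_prim_func,                    # tir.PrimFunc or IRModule (placeholder)
        arch=auto_infer_current_arch(),  # assumed zero-argument helper
        topk=10,
        tensorcore_only=False,
        allow_gemv=False,
    )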
tilelang/contrib/cc.py
View file @ 29051439
...
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Util to invoke C/C++ compilers in the system."""

import functools
import os
import shutil
...
@@ -30,8 +31,7 @@ from tvm.contrib import utils as _utils

def _is_linux_like():
-    return (sys.platform == "darwin" or sys.platform.startswith("linux") or
-            sys.platform.startswith("freebsd"))
+    return sys.platform == "darwin" or sys.platform.startswith("linux") or sys.platform.startswith("freebsd")

def _is_windows_like():
...
@@ -90,7 +90,7 @@ def get_cplus_compiler():

def is_darwin():
-    return platform.system() == 'Darwin'
+    return platform.system() == "Darwin"

def create_shared(output, objects, options=None, cc=None, cwd=None, ccache_env=None):
...
@@ -287,11 +287,7 @@ create_shared.output_format = "so" if sys.platform != "win32" else "dll"
create_shared.get_target_triple = get_target_by_dump_machine(os.environ.get("CXX", get_cc()))

-def cross_compiler(compile_func,
-                   options=None,
-                   output_format=None,
-                   get_target_triple=None,
-                   add_files=None):
+def cross_compiler(compile_func, options=None, output_format=None, get_target_triple=None, add_files=None):
    """Create a cross compiler function by specializing compile_func with options.

    This function can be used to construct compile functions that
...
@@ -363,13 +359,7 @@ def cross_compiler(compile_func,
    return _fcompile

-def _linux_compile(output,
-                   objects,
-                   options,
-                   compile_cmd,
-                   cwd=None,
-                   ccache_env=None,
-                   compile_shared=False):
+def _linux_compile(output, objects, options, compile_cmd, cwd=None, ccache_env=None, compile_shared=False):
    cmd = [compile_cmd]
    if compile_cmd != "nvcc":
        if compile_shared or output.endswith(".so") or output.endswith(".dylib"):
...
@@ -430,15 +420,15 @@ def _windows_compile(output, objects, options, cwd=None, ccache_env=None):
        raise ValueError("ccache not found")
    try:
        proc = subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=cwd, env=env)
        (out, _) = proc.communicate()
    except FileNotFoundError:
-        raise RuntimeError("Can not find the LLVM clang for Windows clang.exe)."
-                           "Make sure it's installed"
-                           " and the installation directory is in the %PATH% environment "
-                           "variable. Prebuilt binaries can be found at: https://llvm.org/") \
-            from None
+        raise RuntimeError(
+            "Can not find the LLVM clang for Windows clang.exe)."
+            "Make sure it's installed"
+            " and the installation directory is in the %PATH% environment "
+            "variable. Prebuilt binaries can be found at: https://llvm.org/"
+        ) from None
    if proc.returncode != 0:
        msg = "Compilation error:\n"
        msg += " ".join(cmd) + "\n"
...
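The signatures touched above are the module's main entry points. A hedged sketch of driving them directly; the file names are hypothetical and the object files are assumed to exist already:

    from tilelang.contrib import cc

    cc.create_shared(
        "libkernel.so",              # output shared library (hypothetical)
        ["kernel0.o", "kernel1.o"],  # object files to link (hypothetical)
        options=["-O2"],             # extra flags passed through to the compiler
    )

    # cross_compiler specializes a compile function with fixed options and
    # returns an fcompile-style callable (the _fcompile returned above).
    fcompile = cc.cross_compiler(cc.create_shared, options=["-std=c++17"])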
tilelang/contrib/dlpack.py
View file @ 29051439
...
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Wrapping functions to bridge frameworks with DLPack support to TVM"""

from tvm import runtime
...
@@ -45,12 +46,8 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func):
    def adapt_tensor(arg):
        if isinstance(arg, tensor_type):
-            if arg.dtype in {
-                    torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}:
-                return runtime.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(arg.shape, dtype=float8_dtype_map[arg.dtype])
+            if arg.dtype in {
+                torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2,
+                torch.float8_e5m2fnuz,
+            }:
+                return runtime.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(arg.shape, dtype=float8_dtype_map[arg.dtype])
            return runtime.from_dlpack(to_dlpack_func(arg))
        return arg
...
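A hedged illustration of the float8 branch above: a torch float8 tensor is bit-viewed as int8 for DLPack export, then re-viewed as float8 on the TVM side. torch and torch.utils.dlpack.to_dlpack are standard APIs; tvm_func is a placeholder for a compiled TVM function:

    import torch
    from torch.utils.dlpack import to_dlpack
    from tilelang.contrib.dlpack import convert_func

    wrapped = convert_func(tvm_func, torch.Tensor, to_dlpack)  # tvm_func assumed
    x = torch.randn(128, dtype=torch.float16).to(torch.float8_e4m3fn)
    wrapped(x)  # adapt_tensor handles the int8 view/re-view round trip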