Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
176 additions
and
316 deletions
+176
-316
tilelang/carver/arch/metal.py
tilelang/carver/arch/metal.py
+2
-3
tilelang/carver/common_schedules.py
tilelang/carver/common_schedules.py
+1
-0
tilelang/carver/matmul_analysis.py
tilelang/carver/matmul_analysis.py
+29
-50
tilelang/carver/roller/bestfit.py
tilelang/carver/roller/bestfit.py
+2
-6
tilelang/carver/roller/hint.py
tilelang/carver/roller/hint.py
+3
-2
tilelang/carver/roller/node.py
tilelang/carver/roller/node.py
+23
-42
tilelang/carver/roller/policy/default.py
tilelang/carver/roller/policy/default.py
+28
-56
tilelang/carver/roller/policy/tensorcore.py
tilelang/carver/roller/policy/tensorcore.py
+16
-39
tilelang/carver/roller/rasterization.py
tilelang/carver/roller/rasterization.py
+0
-2
tilelang/carver/roller/shape_inference/common.py
tilelang/carver/roller/shape_inference/common.py
+1
-4
tilelang/carver/roller/shape_inference/tir.py
tilelang/carver/roller/shape_inference/tir.py
+9
-32
tilelang/carver/template/base.py
tilelang/carver/template/base.py
+9
-4
tilelang/carver/template/conv.py
tilelang/carver/template/conv.py
+17
-8
tilelang/carver/template/flashattention.py
tilelang/carver/template/flashattention.py
+1
-5
tilelang/carver/template/gemv.py
tilelang/carver/template/gemv.py
+3
-6
tilelang/carver/template/general_reduce.py
tilelang/carver/template/general_reduce.py
+4
-6
tilelang/carver/template/matmul.py
tilelang/carver/template/matmul.py
+3
-6
tilelang/carver/utils.py
tilelang/carver/utils.py
+10
-17
tilelang/contrib/cc.py
tilelang/contrib/cc.py
+12
-22
tilelang/contrib/dlpack.py
tilelang/contrib/dlpack.py
+3
-6
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
tilelang/carver/arch/metal.py
View file @
29051439
...
...
@@ -8,7 +8,6 @@ def is_metal_arch(arch: TileDevice) -> bool:
class
METAL
(
TileDevice
):
def
__init__
(
self
,
target
:
Target
|
str
):
if
isinstance
(
target
,
str
):
target
=
Target
(
target
)
...
...
@@ -16,6 +15,6 @@ class METAL(TileDevice):
__all__
=
[
'
is_metal_arch
'
,
'
METAL
'
,
"
is_metal_arch
"
,
"
METAL
"
,
]
tilelang/carver/common_schedules.py
View file @
29051439
...
...
@@ -19,6 +19,7 @@
# Modifications Copyright (c) Microsoft.
# The code below is mostly copied from apache/tvm common_schedules.py in dlight.
"""Common schedule strategies for TIR."""
from
typing
import
Callable
from
tvm
import
tir
...
...
tilelang/carver/matmul_analysis.py
View file @
29051439
# pylint: disable=missing-docstring, invalid-name
"""A GEMM schedule rule for GPU operators."""
from
__future__
import
annotations
from
dataclasses
import
dataclass
from
enum
import
Enum
...
...
@@ -157,8 +158,7 @@ def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> Block
return
block
def
find_arg_idx_from_buffer_chain
(
sch
:
tir
.
Schedule
,
main_block
:
tir
.
schedule
.
BlockRV
,
buffer
:
tir
.
Buffer
)
->
int
:
def
find_arg_idx_from_buffer_chain
(
sch
:
tir
.
Schedule
,
main_block
:
tir
.
schedule
.
BlockRV
,
buffer
:
tir
.
Buffer
)
->
int
:
"""traverse to find the arg index from the buffer"""
producers
=
sch
.
get_producers
(
main_block
)
...
...
@@ -226,9 +226,7 @@ def make_iter_fusion_index_map(
else
:
fused_iters
[
trait
.
kind
]
=
v_i
final_indices
:
list
[
tir
.
PrimExpr
]
=
[
fused_iters
.
get
(
kind
,
tir
.
IntImm
(
traits
[
0
].
extent
.
dtype
,
0
))
for
kind
in
kind_order
]
final_indices
:
list
[
tir
.
PrimExpr
]
=
[
fused_iters
.
get
(
kind
,
tir
.
IntImm
(
traits
[
0
].
extent
.
dtype
,
0
))
for
kind
in
kind_order
]
return
tir
.
IndexMap
(
input_iters
,
final_indices
,
None
)
...
...
@@ -307,8 +305,7 @@ def detect_iter_traits(block: tir.Block) -> tuple[list[IterTrait]] | None:
return
A_traits
,
B_traits
,
C_traits
,
block_traits
def
get_index_map
(
block
:
tir
.
Block
,
layout
:
list
[
str
]
|
None
=
None
)
->
tuple
[
tir
.
IndexMap
,
...]
|
None
:
def
get_index_map
(
block
:
tir
.
Block
,
layout
:
list
[
str
]
|
None
=
None
)
->
tuple
[
tir
.
IndexMap
,
...]
|
None
:
"""Get index maps for the block
Parameters
...
...
@@ -343,10 +340,7 @@ def get_index_map(block: tir.Block,
return
axes
def
is_common_reduce
(
var
:
Var
)
->
bool
:
for
iter_var
in
block
.
iter_vars
:
if
iter_var
.
var
==
var
and
iter_var
.
iter_type
==
IterVar
.
CommReduce
:
return
True
return
False
return
any
(
iter_var
.
var
==
var
and
iter_var
.
iter_type
==
IterVar
.
CommReduce
for
iter_var
in
block
.
iter_vars
)
def
has_common_reduce
(
var
:
Var
)
->
bool
:
vars
=
collect_vars_from_expr
(
var
)
...
...
@@ -384,17 +378,17 @@ def get_index_map(block: tir.Block,
if
kind
==
"C"
:
return
[
IterKind
.
kIter_S
,
primary_iter
,
secondary_iter
]
else
:
return
([
IterKind
.
kIter_S
,
spatial_iter
,
reduction_iter
]
if
check_last_trait
(
region
)
else
[
IterKind
.
kIter_S
,
reduction_iter
,
spatial_iter
])
return
(
[
IterKind
.
kIter_S
,
spatial_iter
,
reduction_iter
]
if
check_last_trait
(
region
)
else
[
IterKind
.
kIter_S
,
reduction_iter
,
spatial_iter
]
)
else
:
raise
ValueError
(
f
"Unknown layout
{
layout
}
"
)
A_index_map
=
make_iter_fusion_index_map
(
A_traits
,
infer_layout
(
layout
[
0
],
block
.
reads
[
0
].
region
,
kind
=
"A"
))
B_index_map
=
make_iter_fusion_index_map
(
B_traits
,
infer_layout
(
layout
[
1
],
block
.
reads
[
1
].
region
,
kind
=
"B"
))
C_index_map
=
make_iter_fusion_index_map
(
C_traits
,
infer_layout
(
layout
[
2
],
block
.
writes
[
0
].
region
,
kind
=
"C"
))
A_index_map
=
make_iter_fusion_index_map
(
A_traits
,
infer_layout
(
layout
[
0
],
block
.
reads
[
0
].
region
,
kind
=
"A"
))
B_index_map
=
make_iter_fusion_index_map
(
B_traits
,
infer_layout
(
layout
[
1
],
block
.
reads
[
1
].
region
,
kind
=
"B"
))
C_index_map
=
make_iter_fusion_index_map
(
C_traits
,
infer_layout
(
layout
[
2
],
block
.
writes
[
0
].
region
,
kind
=
"C"
))
matmul_index_map
=
make_iter_fusion_index_map
(
block_traits
,
...
...
@@ -429,8 +423,7 @@ def get_dequantize_block(sch, blocks) -> BlockRV | None:
has_uint_input
=
any
(
"uint"
in
str
(
region
.
buffer
.
dtype
)
for
region
in
block_stmt
.
reads
)
if
not
has_uint_input
:
return
False
return
not
(
len
(
block_stmt
.
writes
)
!=
1
or
"float"
not
in
str
(
block_stmt
.
writes
[
0
].
buffer
.
dtype
))
return
not
(
len
(
block_stmt
.
writes
)
!=
1
or
"float"
not
in
str
(
block_stmt
.
writes
[
0
].
buffer
.
dtype
))
dequantize_blocks
=
[
block
for
block
in
blocks
if
is_dequantize
(
block
)]
return
dequantize_blocks
[
0
]
if
len
(
dequantize_blocks
)
==
1
else
None
...
...
@@ -452,8 +445,7 @@ def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool:
return
None
axes
.
extend
(
undefined_vars
(
r
.
min
))
# remove trivial axis
trivial_vars
=
set
(
iter_var
.
var
for
iter_var
in
block_stmt
.
iter_vars
if
_is_one
(
iter_var
.
dom
.
extent
))
trivial_vars
=
set
(
iter_var
.
var
for
iter_var
in
block_stmt
.
iter_vars
if
_is_one
(
iter_var
.
dom
.
extent
))
axes
=
[
axis
for
axis
in
axes
if
axis
not
in
trivial_vars
]
# remove duplicate axis
axes
=
[
var
for
i
,
var
in
enumerate
(
axes
)
if
i
==
0
or
var
!=
axes
[
i
-
1
]]
...
...
@@ -462,8 +454,7 @@ def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool:
lhs_access_vars
=
get_access_vars
(
block_stmt
.
reads
[
0
].
region
)[
-
2
:]
rhs_access_vars
=
get_access_vars
(
block_stmt
.
writes
[
0
].
region
)[
-
2
:]
is_identity
=
list
(
lhs_access_vars
)
==
list
(
rhs_access_vars
)
is_transpose
=
list
(
lhs_access_vars
)
!=
list
(
rhs_access_vars
)
and
set
(
lhs_access_vars
)
==
set
(
rhs_access_vars
)
is_transpose
=
list
(
lhs_access_vars
)
!=
list
(
rhs_access_vars
)
and
set
(
lhs_access_vars
)
==
set
(
rhs_access_vars
)
return
is_identity
,
is_transpose
...
...
@@ -491,9 +482,7 @@ def inline_transpose_block(sch: tir.Schedule, blocks: list[tir.schedule.BlockRV]
return
result_blocks
def
normalize_to_matmul
(
sch
:
tir
.
Schedule
,
main_block
:
BlockRV
,
layout
:
list
[
str
]
|
None
=
None
)
->
tir
.
Schedule
|
None
:
def
normalize_to_matmul
(
sch
:
tir
.
Schedule
,
main_block
:
BlockRV
,
layout
:
list
[
str
]
|
None
=
None
)
->
tir
.
Schedule
|
None
:
if
layout
is
None
:
layout
=
[
"n"
,
"t"
,
"n"
]
block_stmt
=
sch
.
get
(
main_block
)
...
...
@@ -526,7 +515,7 @@ def get_tensorized_func_and_tags(
allow_gemv
:
bool
=
False
,
)
->
tuple
[
tir
.
PrimFunc
,
dict
[
str
,
list
[
int
]
|
int
]]:
"""
transform function to matmul if necessary (e.g. transform conv2d with im2col)
transform function to matmul if necessary (e.g. transform conv2d with im2col)
"""
if
layout
is
None
:
layout
=
[
"a"
,
"a"
,
"a"
]
...
...
@@ -543,10 +532,7 @@ def get_tensorized_func_and_tags(
conditions
=
[]
conditions
.
append
(
len
(
block_stmt
.
reads
)
==
2
)
conditions
.
append
(
len
(
block_stmt
.
writes
)
==
1
)
conditions
.
append
(
len
(
collect_block_iter_vars_used_in_access_region
(
block_stmt
,
block_stmt
.
writes
[
0
].
region
))
>
0
)
conditions
.
append
(
len
(
collect_block_iter_vars_used_in_access_region
(
block_stmt
,
block_stmt
.
writes
[
0
].
region
))
>
0
)
return
all
(
conditions
)
# step2. transform function to tensorcore matmul (e.g. conv2d with im2col)
...
...
@@ -592,10 +578,7 @@ def get_tensorized_func_and_tags(
return
axes
def
is_common_reduce
(
var
:
Var
)
->
bool
:
for
iter_var
in
block_stmt
.
iter_vars
:
if
iter_var
.
var
==
var
and
iter_var
.
iter_type
==
IterVar
.
CommReduce
:
return
True
return
False
return
any
(
iter_var
.
var
==
var
and
iter_var
.
iter_type
==
IterVar
.
CommReduce
for
iter_var
in
block_stmt
.
iter_vars
)
def
has_common_reduce
(
var
:
Var
)
->
bool
:
vars
=
collect_vars_from_expr
(
var
)
...
...
@@ -626,7 +609,7 @@ def get_tensorized_func_and_tags(
# When the func is a dequantize like ops, we should consider the M
require_block_reduce
=
False
# And we only support float16 for now
if
(
hasattr
(
func
.
attrs
,
"dequantize_info"
)
and
in_dtype
in
[
"bfloat16"
,
"float16"
]
)
:
if
hasattr
(
func
.
attrs
,
"dequantize_info"
)
and
in_dtype
in
[
"bfloat16"
,
"float16"
]:
for
arg
in
func
.
params
:
inp_shape
=
func
.
buffer_map
[
arg
].
shape
M
=
inp_shape
[
0
]
...
...
@@ -645,9 +628,7 @@ def get_tensorized_func_and_tags(
if
target
.
kind
.
name
==
"cuda"
and
check_sm_version
(
target
.
arch
)
>=
70
:
in_dtype
,
out_dtype
=
get_in_out_dtypes
(
block_stmt
)
if
not
is_tensorcore_supported_precision
(
in_dtype
,
out_dtype
,
arch
=
get_arch
(
target
)):
logger
.
debug
(
f
"The input and output dtype (
{
in_dtype
}
,
{
out_dtype
}
)is not supported by tensorcore"
)
logger
.
debug
(
f
"The input and output dtype (
{
in_dtype
}
,
{
out_dtype
}
)is not supported by tensorcore"
)
return
func
,
None
# reindex and transform functions
...
...
@@ -676,7 +657,7 @@ def get_tensorized_func_and_tags(
else
:
raise
ValueError
(
f
"Unknown IterVar type
{
iter_type
}
"
)
if
(
isinstance
(
extent
,
tir
.
expr
.
IntImm
)
and
extent
.
value
<
minimal_tensorize_threshold
)
:
if
isinstance
(
extent
,
tir
.
expr
.
IntImm
)
and
extent
.
value
<
minimal_tensorize_threshold
:
return
func
,
None
tags
=
analysis_tensorcore_tags
(
sch
,
main_block
,
target
)
return
sch
.
mod
[
"main"
],
tags
...
...
@@ -686,8 +667,10 @@ def get_tensorized_func_and_tags(
def
get_propagate_map
(
trans
:
bool
=
True
,
dtype
=
"float16"
,
matrix_name
=
"A"
,
index_dtype
=
"int32"
):
from
bitblas.tl.mma_layout
import
(
# pylint: disable=import-outside-toplevel
ldmatrix_32x8_to_shared_16x16_layout
,
ldmatrix_trans_32x8_to_shared_16x16_layout
,
ldmatrix_32x16_to_shared_16x32_layout_a
,
ldmatrix_32x16_to_shared_16x32_layout_b
,
ldmatrix_32x8_to_shared_16x16_layout
,
ldmatrix_trans_32x8_to_shared_16x16_layout
,
ldmatrix_32x16_to_shared_16x32_layout_a
,
ldmatrix_32x16_to_shared_16x32_layout_b
,
)
assert
dtype
in
[
...
...
@@ -727,9 +710,7 @@ def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", inde
return
ldmatrix_layout
(
thread_id
,
local_id
)
if
dtype
in
[
"bfloat16"
,
"float16"
]:
ldmatrix_index_map
=
(
ldmatrix_trans_permutation_16x16_32x8_16x16
if
trans
else
ldmatrix_permutation_16x16_32x8_16x16
)
ldmatrix_index_map
=
ldmatrix_trans_permutation_16x16_32x8_16x16
if
trans
else
ldmatrix_permutation_16x16_32x8_16x16
else
:
ldmatrix_index_map
=
ldmatrix_permutation_16x32_32x16_32x16
...
...
@@ -744,7 +725,6 @@ def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", inde
# Ladder weight propagation, which can be used to avoid the ldmatrix
# Instructions.
def
get_ladder_stage3_map
(
dtype
=
"float16"
,
index_dtype
=
"int32"
):
def
shared_32x8_to_mma_32x8_layout
(
i
,
j
):
thread_id
=
(
i
%
8
)
*
4
+
(
j
//
2
)
local_id
=
(
i
//
8
)
*
2
+
(
j
%
2
)
...
...
@@ -837,8 +817,7 @@ def layout_propagate_chain(
scaling_factor
=
1
for
i
,
j
in
zip
(
write
.
buffer
.
shape
,
read
.
buffer
.
shape
):
scaling_factor
*=
i
//
j
final_indices
=
list
(
index_map
.
map_indices
(
tmp_index_map
.
map_indices
(
write_indices
)))
final_indices
=
list
(
index_map
.
map_indices
(
tmp_index_map
.
map_indices
(
write_indices
)))
final_indices
[
-
1
]
=
final_indices
[
-
1
]
//
scaling_factor
index_map
=
IndexMap
(
write_indices
,
...
...
tilelang/carver/roller/bestfit.py
View file @
29051439
...
...
@@ -2,7 +2,6 @@
class
Block
:
def
__init__
(
self
,
start
,
end
,
is_free
):
self
.
start
=
start
self
.
end
=
end
...
...
@@ -21,7 +20,6 @@ class Block:
class
BestFit
:
def
__init__
(
self
,
align
=
32
):
self
.
limit
=
0
self
.
list
=
[]
...
...
@@ -31,16 +29,14 @@ class BestFit:
size
=
(
size
+
self
.
align
-
1
)
//
self
.
align
*
self
.
align
found
=
None
for
block
in
self
.
list
:
if
block
.
is_free
and
block
.
size
()
>=
size
and
(
not
found
or
found
.
size
()
>
block
.
size
()):
if
block
.
is_free
and
block
.
size
()
>=
size
and
(
not
found
or
found
.
size
()
>
block
.
size
()):
found
=
block
if
found
:
found
.
is_free
=
False
remain
=
found
.
size
()
-
size
if
remain
!=
0
:
found
.
end
-=
remain
self
.
list
.
insert
(
self
.
list
.
index
(
found
)
+
1
,
Block
(
found
.
end
,
found
.
end
+
remain
,
True
))
self
.
list
.
insert
(
self
.
list
.
index
(
found
)
+
1
,
Block
(
found
.
end
,
found
.
end
+
remain
,
True
))
return
found
elif
len
(
self
.
list
)
>
0
and
self
.
list
[
-
1
].
is_free
:
add
=
size
-
self
.
list
[
-
1
].
size
()
...
...
tilelang/carver/roller/hint.py
View file @
29051439
"""Hint definition for schedule"""
from
tvm
import
DataType
from
.
import
PrimFuncNode
import
numpy
as
np
...
...
@@ -60,7 +61,7 @@ class Stride:
strided_elem
=
original_shape
else
:
assert
self
.
ax
<
len
(
shape
)
strided_elem
=
np
.
prod
(
shape
[
0
:
self
.
ax
+
1
])
*
self
.
stride
strided_elem
=
np
.
prod
(
shape
[
0
:
self
.
ax
+
1
])
*
self
.
stride
assert
strided_elem
>=
original_shape
return
int
(
strided_elem
)
...
...
@@ -217,7 +218,7 @@ class Hint:
return
dic
@
classmethod
def
from_dict
(
cls
,
dic
:
dict
)
->
'
Hint
'
:
def
from_dict
(
cls
,
dic
:
dict
)
->
"
Hint
"
:
hint
=
cls
()
for
k
,
v
in
dic
.
items
():
setattr
(
hint
,
k
,
v
)
...
...
tilelang/carver/roller/node.py
View file @
29051439
"""PrimFunc Wrapper and Block information Analaysis"""
from
__future__
import
annotations
import
tvm
...
...
@@ -31,7 +32,6 @@ def pre_order_traverse(block_analyzer, blocks, func):
class
BlockAnalyzer
:
def
__init__
(
self
,
sch
)
->
None
:
self
.
sch
:
tir
.
Schedule
=
sch
self
.
block_infos
:
list
[
BlockInfo
]
=
normalize_prim_func
(
self
.
sch
)
...
...
@@ -92,7 +92,6 @@ class Edge:
class
Node
:
def
__init__
(
self
,
tags
:
dict
|
None
=
None
,
name
:
str
=
"Node"
)
->
None
:
self
.
name
=
name
if
tags
is
None
:
...
...
@@ -177,7 +176,6 @@ class Node:
class
PlaceHolderNode
(
Node
):
def
__init__
(
self
,
name
=
""
):
super
().
__init__
(
name
=
"PlaceHolder_"
+
name
)
...
...
@@ -189,11 +187,7 @@ class PlaceHolderNode(Node):
class
PrimFuncNode
(
Node
):
def
__init__
(
self
,
prim_func
:
PrimFunc
,
tags
:
dict
|
None
=
None
,
name
:
str
=
"PrimFuncNode"
)
->
None
:
def
__init__
(
self
,
prim_func
:
PrimFunc
,
tags
:
dict
|
None
=
None
,
name
:
str
=
"PrimFuncNode"
)
->
None
:
super
().
__init__
(
tags
,
name
=
name
)
self
.
prim_func
=
self
.
_specialize_func
(
prim_func
)
self
.
sch
:
tir
.
Schedule
=
tir
.
Schedule
(
self
.
prim_func
)
...
...
@@ -227,7 +221,7 @@ class PrimFuncNode(Node):
for
dst_id
,
n
in
enumerate
(
inputs
):
if
isinstance
(
n
,
Node
):
n
=
(
n
,
0
)
assert
(
len
(
n
)
==
2
)
assert
len
(
n
)
==
2
src_node
,
src_id
=
n
[
0
],
n
[
1
]
edge
=
Edge
(
src_node
,
self
,
src_id
,
dst_id
)
self
.
_in_edges
.
append
(
edge
)
...
...
@@ -338,9 +332,8 @@ class PrimFuncNode(Node):
if
rstep
is
None
:
rstep
=
{}
shape
=
{
self
.
block_analyzer
.
get_output_buffers
(
block
)[
0
].
name
:
[
tvm
.
arith
.
ConstIntBound
(
0
,
val
-
1
)
for
val
in
tile
]
for
block
in
self
.
schedule_stages
self
.
block_analyzer
.
get_output_buffers
(
block
)[
0
].
name
:
[
tvm
.
arith
.
ConstIntBound
(
0
,
val
-
1
)
for
val
in
tile
]
for
block
in
self
.
schedule_stages
}
return
self
.
ana
.
infer
(
shape
,
rstep
,
targets
)
...
...
@@ -356,10 +349,7 @@ class PrimFuncNode(Node):
results
.
append
(
shapes
[
arg
.
name
])
continue
# should not exceed original shape
trimmed_shape
=
[
self
.
extent_wrapper
(
i
)
for
i
in
list
(
map
(
min
,
zip
(
shapes
[
arg
.
name
],
self
.
input_buffers
[
i
].
shape
)))
]
trimmed_shape
=
[
self
.
extent_wrapper
(
i
)
for
i
in
list
(
map
(
min
,
zip
(
shapes
[
arg
.
name
],
self
.
input_buffers
[
i
].
shape
)))]
results
.
append
(
trimmed_shape
)
return
results
...
...
@@ -380,10 +370,8 @@ class PrimFuncNode(Node):
propagate_shape
=
shapes
[
arg
.
name
]
buffer_shape
=
args
[
i
].
shape
if
len
(
buffer_shape
)
>
len
(
propagate_shape
):
buffer_shape
=
buffer_shape
[
-
len
(
propagate_shape
):]
trimmed_shape
=
[
self
.
extent_wrapper
(
j
)
for
j
in
list
(
map
(
min
,
zip
(
propagate_shape
,
buffer_shape
)))
]
buffer_shape
=
buffer_shape
[
-
len
(
propagate_shape
)
:]
trimmed_shape
=
[
self
.
extent_wrapper
(
j
)
for
j
in
list
(
map
(
min
,
zip
(
propagate_shape
,
buffer_shape
)))]
results
.
append
(
trimmed_shape
)
return
results
...
...
@@ -412,10 +400,7 @@ class PrimFuncNode(Node):
def
get_reduce_inputs_dtype
(
self
):
if
self
.
reduction_block
is
None
:
return
{}
return
{
b
.
name
:
tvm
.
DataType
(
b
.
dtype
)
for
b
in
self
.
block_analyzer
.
get_input_buffers
(
self
.
reduction_block
)
}
return
{
b
.
name
:
tvm
.
DataType
(
b
.
dtype
)
for
b
in
self
.
block_analyzer
.
get_input_buffers
(
self
.
reduction_block
)}
@
functools
.
lru_cache
def
infer_tensorcore_axis
(
self
)
->
tuple
[
int
]:
...
...
@@ -425,8 +410,7 @@ class PrimFuncNode(Node):
C_ax_m
,
C_ax_n
=
self
.
get_tag
(
"tensorcore_config"
)
wmma_m
,
wmma_n
,
wmma_k
=
[
16
,
16
,
16
]
# just for testing, any number is ok
output_buffer_shape
=
(
self
.
block_analyzer
.
sch
.
get
(
self
.
reduction_block
).
writes
[
0
].
buffer
.
shape
)
output_buffer_shape
=
self
.
block_analyzer
.
sch
.
get
(
self
.
reduction_block
).
writes
[
0
].
buffer
.
shape
valid_region
=
[]
for
region
in
output_buffer_shape
:
if
region
.
value
==
1
:
...
...
@@ -438,8 +422,7 @@ class PrimFuncNode(Node):
def
get_cl_shapes
(
c_ax_m
,
c_ax_n
,
num_nvalid_regions
):
spatial_dim
=
self
.
get_space_dim
()
assert
len
(
valid_region
)
==
len
(
spatial_dim
),
f
"
{
valid_region
}
mismatch with
{
spatial_dim
}
"
assert
len
(
valid_region
)
==
len
(
spatial_dim
),
f
"
{
valid_region
}
mismatch with
{
spatial_dim
}
"
cl_shapes
=
[
1
]
*
len
(
spatial_dim
)
cl_shapes
[
c_ax_m
-
num_nvalid_regions
]
=
wmma_m
cl_shapes
[
c_ax_n
-
num_nvalid_regions
]
=
wmma_n
...
...
@@ -467,9 +450,11 @@ class PrimFuncNode(Node):
shapes
,
_
=
self
.
propagate
(
shape
,
rstep
)
def
is_broadcast_pattern
(
buffer
,
output_buffer
):
return
(
buffer
in
self
.
args
and
len
(
shapes
[
output_buffer
.
name
])
>
len
(
shapes
[
buffer
.
name
])
and
np
.
prod
(
shapes
[
output_buffer
.
name
])
>
np
.
prod
(
shapes
[
buffer
.
name
]))
return
(
buffer
in
self
.
args
and
len
(
shapes
[
output_buffer
.
name
])
>
len
(
shapes
[
buffer
.
name
])
and
np
.
prod
(
shapes
[
output_buffer
.
name
])
>
np
.
prod
(
shapes
[
buffer
.
name
])
)
def
is_after_reduce_stage
(
block
):
if
not
self
.
reduction_block
:
...
...
@@ -491,8 +476,8 @@ class PrimFuncNode(Node):
output_buffer
=
self
.
block_analyzer
.
get_output_buffers
(
block
)[
0
]
for
buffer
in
self
.
block_analyzer
.
get_input_buffers
(
block
):
cache
=
buffer
.
name
not
in
cached_tensor
and
(
is_broadcast_pattern
(
buffer
,
output_buffer
)
or
self
.
block_analyzer
.
get_block_info
(
block
).
is_reduction
()
)
is_broadcast_pattern
(
buffer
,
output_buffer
)
or
self
.
block_analyzer
.
get_block_info
(
block
).
is_reduction
()
)
if
not
cache
:
continue
cached_tensor
.
append
(
buffer
.
name
)
...
...
@@ -500,8 +485,7 @@ class PrimFuncNode(Node):
continue
# cache after reduce op can often reuse buffer in reduce stage
if
buffer
.
name
in
stride_map
:
num_elem
=
stride_map
[
buffer
.
name
].
compute_elements_from_shape
(
shapes
[
buffer
.
name
])
num_elem
=
stride_map
[
buffer
.
name
].
compute_elements_from_shape
(
shapes
[
buffer
.
name
])
else
:
num_elem
=
np
.
prod
(
shapes
[
buffer
.
name
])
buffer_len
=
num_elem
*
int
((
tvm
.
DataType
(
buffer
.
dtype
).
bits
+
7
)
//
8
)
...
...
@@ -514,7 +498,6 @@ class PrimFuncNode(Node):
class
OutputNode
(
Node
):
def
__init__
(
self
,
node
,
id
=
0
):
super
().
__init__
(
name
=
"OutputNode"
)
# connect node and output node
...
...
@@ -549,15 +532,16 @@ def topo_order(list_of_nodes) -> list[Node]:
input_ready_count
[
dst_node
]
=
len
(
dst_node
.
inputs
)
list_of_nodes
.
append
(
dst_node
)
input_ready_count
[
dst_node
]
-=
1
assert
(
input_ready_count
[
dst_node
]
>=
0
)
assert
input_ready_count
[
dst_node
]
>=
0
if
input_ready_count
[
dst_node
]
==
0
:
ready
.
append
(
dst_node
)
assert
(
len
(
list_of_nodes
)
==
len
(
output_list
)
)
assert
len
(
list_of_nodes
)
==
len
(
output_list
)
return
output_list
def
find_topo_sort_priority
(
output_node_list
)
->
list
[
Node
]:
import
sys
sys
.
setrecursionlimit
(
10000
)
def
topo_sort_get_layer
(
node
,
topo_layer
):
...
...
@@ -576,9 +560,7 @@ def find_topo_sort_priority(output_node_list) -> list[Node]:
if
node
in
visited
:
return
visited
.
add
(
node
)
ordered_input_nodes
=
sorted
([
edge
.
src_node
for
edge
in
node
.
inputs
],
key
=
lambda
n
:
topo_layer
[
n
],
reverse
=
True
)
ordered_input_nodes
=
sorted
([
edge
.
src_node
for
edge
in
node
.
inputs
],
key
=
lambda
n
:
topo_layer
[
n
],
reverse
=
True
)
for
n
in
ordered_input_nodes
:
topo_sort_dfs
(
n
,
visited
,
topo_order
)
topo_order
.
append
(
node
)
...
...
@@ -591,7 +573,6 @@ def find_topo_sort_priority(output_node_list) -> list[Node]:
def
find_topo_sort
(
output_node_list
)
->
list
[
Node
]:
def
topo_sort_dfs
(
node
,
visited
,
topo_order
):
if
node
in
visited
:
return
...
...
tilelang/carver/roller/policy/default.py
View file @
29051439
"""Policy for cuda core schedule"""
from
__future__
import
annotations
import
functools
import
math
...
...
@@ -36,20 +37,14 @@ class DefaultPolicy:
self
.
rasterization
=
NoRasterization
()
@
classmethod
def
from_prim_func
(
cls
,
func
:
tvm
.
tir
.
PrimFunc
,
arch
:
TileDevice
,
tags
:
dict
|
None
=
None
,
name
:
str
=
"PrimFuncNode"
):
def
from_prim_func
(
cls
,
func
:
tvm
.
tir
.
PrimFunc
,
arch
:
TileDevice
,
tags
:
dict
|
None
=
None
,
name
:
str
=
"PrimFuncNode"
):
return
cls
(
arch
,
tags
).
_init_with_prim_func
(
func
,
name
)
@
classmethod
def
from_output_nodes
(
cls
,
nodes
:
list
[
OutputNode
],
arch
:
TileDevice
,
tags
:
dict
|
None
=
None
):
return
cls
(
arch
,
tags
).
_init_with_output_nodes
(
nodes
)
def
_init_with_prim_func
(
self
,
func
:
tvm
.
tir
.
PrimFunc
,
name
:
str
=
"PrimFuncNode"
)
->
DefaultPolicy
:
def
_init_with_prim_func
(
self
,
func
:
tvm
.
tir
.
PrimFunc
,
name
:
str
=
"PrimFuncNode"
)
->
DefaultPolicy
:
if
func
is
not
None
and
isinstance
(
func
,
tvm
.
tir
.
PrimFunc
):
self
.
func
=
func
self
.
prim_func_node
=
PrimFuncNode
(
self
.
func
,
tags
=
self
.
tags
,
name
=
name
)
...
...
@@ -60,9 +55,7 @@ class DefaultPolicy:
return
self
def
_init_with_output_nodes
(
self
,
output_nodes
:
list
[
OutputNode
]):
self
.
ordered_nodes
=
list
(
filter
(
lambda
n
:
not
n
.
is_placeholder
()
and
not
n
.
is_output
(),
find_topo_sort
(
output_nodes
)))
self
.
ordered_nodes
=
list
(
filter
(
lambda
n
:
not
n
.
is_placeholder
()
and
not
n
.
is_output
(),
find_topo_sort
(
output_nodes
)))
for
node
in
self
.
ordered_nodes
:
node
.
update_tags
(
self
.
tags
)
...
...
@@ -102,13 +95,14 @@ class DefaultPolicy:
def
dfs_smem_tile
(
self
,
init_tile
,
rstep_map
)
->
Iterable
[
TileDict
]:
_steps
=
[
get_all_factors
(
n
)
for
n
in
self
.
output_nodes
[
0
].
get_space_dim
()]
steps
=
[
step
[
step
.
index
(
t
):]
for
step
,
t
in
zip
(
_steps
,
init_tile
)]
steps
=
[
step
[
step
.
index
(
t
)
:]
for
step
,
t
in
zip
(
_steps
,
init_tile
)]
for
i
in
range
(
len
(
steps
)):
added
=
list
(
filter
(
lambda
s
:
s
<
steps
[
i
][
-
1
]
and
s
>
steps
[
i
][
0
]
and
s
not
in
steps
[
i
],
[
2
,
4
,
8
,
16
,
32
],
))
)
)
steps
[
i
].
extend
(
added
)
steps
[
i
]
=
sorted
(
steps
[
i
])
visited_tiles
=
{}
...
...
@@ -190,10 +184,7 @@ class DefaultPolicy:
"""
tile_map
=
{}
for
node
in
self
.
output_nodes
:
tile_map
[
node
]
=
[
tile
[
i
]
*
node
.
get_space_dim
()[
i
]
//
self
.
output_nodes
[
0
].
get_space_dim
()[
i
]
for
i
in
range
(
len
(
tile
))
]
tile_map
[
node
]
=
[
tile
[
i
]
*
node
.
get_space_dim
()[
i
]
//
self
.
output_nodes
[
0
].
get_space_dim
()[
i
]
for
i
in
range
(
len
(
tile
))]
return
tile_map
def
compute_workload_per_item
(
self
,
output_tile
)
->
float
:
...
...
@@ -304,8 +295,7 @@ class DefaultPolicy:
score
=
0
shape
=
node
.
propagate_inputs
(
tile
,
rstep
=
rstep
)
for
i
,
input_buffer
in
enumerate
(
node
.
input_buffers
):
read_transaction_elements
=
self
.
arch
.
transaction_size
[
1
]
//
(
(
node
.
get_buffer_dtype
(
input_buffer
).
bits
+
7
)
//
8
)
read_transaction_elements
=
self
.
arch
.
transaction_size
[
1
]
//
((
node
.
get_buffer_dtype
(
input_buffer
).
bits
+
7
)
//
8
)
score
+=
sim
(
int
(
coalesced_factor
(
shape
[
i
],
input_buffer
.
shape
)),
read_transaction_elements
,
...
...
@@ -380,17 +370,13 @@ class DefaultPolicy:
return
None
return
max
(
candidates
,
key
=
lambda
x
:
x
[
1
])[
0
]
cur_rstep_id
=
{
k
.
var
.
name
:
all_steps
[
k
.
var
.
name
].
index
(
rstep
[
k
.
var
.
name
])
for
k
in
node
.
raxis
}
cur_rstep_id
=
{
k
.
var
.
name
:
all_steps
[
k
.
var
.
name
].
index
(
rstep
[
k
.
var
.
name
])
for
k
in
node
.
raxis
}
new_rstep_map
=
rstep_map
.
copy
()
while
True
:
new_rstep_id
=
_enlarge
(
cur_rstep_id
)
if
new_rstep_id
is
None
:
break
new_rstep_map
[
node
]
=
{
k
.
var
.
name
:
all_steps
[
k
.
var
.
name
][
new_rstep_id
[
k
.
var
.
name
]]
for
k
in
node
.
raxis
}
new_rstep_map
[
node
]
=
{
k
.
var
.
name
:
all_steps
[
k
.
var
.
name
][
new_rstep_id
[
k
.
var
.
name
]]
for
k
in
node
.
raxis
}
old_rstep_map
=
td
.
rstep_map
td
.
rstep_map
=
new_rstep_map
smem_usage
,
_
=
self
.
_compute_shared_memory_usage
(
td
)
...
...
@@ -434,15 +420,14 @@ class DefaultPolicy:
if
edge
.
src_node
.
is_placeholder
():
nbytes
=
(
edge
.
src_node
.
get_dtype
().
bits
+
7
)
//
8
read_transaction_elements
=
self
.
arch
.
transaction_size
[
1
]
//
nbytes
traffic
+=
coalesced_tensor_shape
(
input_shapes
[
i
],
edge
.
src_node
.
get_shape
(),
read_transaction_elements
)
*
nbytes
traffic
+=
coalesced_tensor_shape
(
input_shapes
[
i
],
edge
.
src_node
.
get_shape
(),
read_transaction_elements
)
*
nbytes
for
edge
in
node
.
outputs
:
if
edge
.
dst_node
.
is_output
():
nbytes
=
(
edge
.
src_node
.
get_dtype
().
bits
+
7
)
//
8
write_transaction_elements
=
self
.
arch
.
transaction_size
[
0
]
//
nbytes
traffic
+=
coalesced_tensor_shape
(
output_shapes
[
edge
.
src_id
],
node
.
get_shape
(
edge
.
src_id
),
write_transaction_elements
)
*
nbytes
traffic
+=
(
coalesced_tensor_shape
(
output_shapes
[
edge
.
src_id
],
node
.
get_shape
(
edge
.
src_id
),
write_transaction_elements
)
*
nbytes
)
return
traffic
,
op_tile_map
...
...
@@ -487,10 +472,7 @@ class DefaultPolicy:
cached_tensors_map
=
{}
def
can_free
(
node
,
out_id
):
for
edge
in
node
.
outputs
:
if
edge
.
src_id
==
out_id
and
edge
.
dst_node
not
in
processed
:
return
False
return
True
return
all
(
not
(
edge
.
src_id
==
out_id
and
edge
.
dst_node
not
in
processed
)
for
edge
in
node
.
outputs
)
for
node
in
self
.
ordered_nodes
:
node_internal_bytes
,
cached_tensors_map
[
node
]
=
self
.
infer_node_smem_usage
(
td
,
node
)
...
...
@@ -528,9 +510,7 @@ class DefaultPolicy:
Tuple[Dict, Dict]
A tuple of dictionaries containing the output strides and tensor strides.
"""
output_strides
=
{
int
(
i
+
len
(
node
.
input_buffers
)):
Stride
()
for
i
,
_
in
enumerate
(
node
.
output_buffers
)
}
output_strides
=
{
int
(
i
+
len
(
node
.
input_buffers
)):
Stride
()
for
i
,
_
in
enumerate
(
node
.
output_buffers
)}
tensor_strides
=
{}
return
output_strides
,
tensor_strides
...
...
@@ -551,8 +531,7 @@ class DefaultPolicy:
output_strides_map
=
{}
tensor_strides_map
=
{}
for
node
in
self
.
ordered_nodes
:
output_strides_map
[
node
],
tensor_strides_map
[
node
]
=
self
.
compute_node_stride_map
(
node
,
td
)
output_strides_map
[
node
],
tensor_strides_map
[
node
]
=
self
.
compute_node_stride_map
(
node
,
td
)
td
.
output_strides_map
,
td
.
tensor_strides_map
=
output_strides_map
,
tensor_strides_map
def
compute_tile_dict
(
self
,
output_tile
:
list
[
int
],
rstep_map
)
->
TileDict
:
...
...
@@ -582,9 +561,7 @@ class DefaultPolicy:
output_shape
=
self
.
output_nodes
[
0
].
get_space_dim
()
td
.
grid_size
=
int
(
np
.
prod
([(
y
+
x
-
1
)
//
x
for
x
,
y
in
zip
(
output_tile
,
output_shape
)]))
# estimated reg usage
reg_usage
=
int
(
2
*
max
([
np
.
prod
(
td
.
get_tile
(
node
))
*
node
.
get_dtype
().
bits
/
32
for
node
in
self
.
ordered_nodes
]))
reg_usage
=
int
(
2
*
max
([
np
.
prod
(
td
.
get_tile
(
node
))
*
node
.
get_dtype
().
bits
/
32
for
node
in
self
.
ordered_nodes
]))
if
reg_usage
>
self
.
arch
.
reg_cap
:
td
.
valid
=
False
return
td
...
...
@@ -609,13 +586,10 @@ class DefaultPolicy:
for
node
in
self
.
ordered_nodes
:
if
np
.
prod
(
td
.
get_tile
(
node
))
==
0
:
return
False
node_grid_size
=
np
.
prod
([
(
y
+
x
-
1
)
//
x
for
x
,
y
in
zip
(
td
.
get_tile
(
node
),
node
.
get_space_dim
())
])
node_grid_size
=
np
.
prod
([(
y
+
x
-
1
)
//
x
for
x
,
y
in
zip
(
td
.
get_tile
(
node
),
node
.
get_space_dim
())])
if
node_grid_size
!=
td
.
grid_size
:
return
False
if
(
hasattr
(
node
,
"reduce_op"
)
and
node
.
reduce_op
is
not
None
and
len
(
node
.
reduce_op
.
axis
)
==
len
(
td
.
output_tile
)):
if
hasattr
(
node
,
"reduce_op"
)
and
node
.
reduce_op
is
not
None
and
len
(
node
.
reduce_op
.
axis
)
==
len
(
td
.
output_tile
):
for
i
,
tile_extent
in
enumerate
(
td
.
output_tile
):
if
node
.
reduce_op
.
axis
[
i
].
dom
.
extent
%
tile_extent
:
return
False
...
...
@@ -639,23 +613,22 @@ class DefaultPolicy:
node_space_sizes
=
[
int
(
np
.
prod
(
td
.
get_tile
(
node
)))
for
node
in
self
.
ordered_nodes
]
max_block_size
=
functools
.
reduce
(
math
.
gcd
,
node_space_sizes
)
if
max_block_size
<
self
.
arch
.
warp_size
*
self
.
arch
.
sm_partition
and
max_block_size
==
min
(
node_space_sizes
):
node_reduce_sizes
=
[
int
(
np
.
prod
(
list
(
td
.
get_rstep
(
node
).
values
())))
for
node
in
self
.
ordered_nodes
]
if
max_block_size
<
self
.
arch
.
warp_size
*
self
.
arch
.
sm_partition
and
max_block_size
==
min
(
node_space_sizes
):
node_reduce_sizes
=
[
int
(
np
.
prod
(
list
(
td
.
get_rstep
(
node
).
values
())))
for
node
in
self
.
ordered_nodes
]
total_sizes
=
[
x
*
y
for
x
,
y
in
zip
(
node_space_sizes
,
node_reduce_sizes
)]
max_possible_size
=
functools
.
reduce
(
math
.
gcd
,
total_sizes
)
possible_block_sizes
=
list
(
filter
(
lambda
x
:
x
%
max_block_size
==
0
and
x
<=
1024
,
get_all_factors
(
max_possible_size
),
))
)
)
possible_block_sizes
=
list
(
filter
(
# either be a factor of space or cover fully cover the space
lambda
x
:
all
([
x
%
s
==
0
or
s
%
x
==
0
for
s
in
node_space_sizes
]),
possible_block_sizes
,
))
)
)
factor_ordered
=
sorted
(
possible_block_sizes
,
key
=
self
.
score_block_size
)
return
factor_ordered
else
:
...
...
@@ -821,8 +794,7 @@ class DefaultPolicy:
vectorize_result
=
{}
for
tensor
,
shape
in
shapes
.
items
():
for
v
in
vectorize_sizes
:
if
(
is_shape_aligned
(
shape
,
block_size
*
v
)
and
is_cont
(
shape
,
v
)
and
is_type_allowed
(
dtypes
[
tensor
],
v
)):
if
is_shape_aligned
(
shape
,
block_size
*
v
)
and
is_cont
(
shape
,
v
)
and
is_type_allowed
(
dtypes
[
tensor
],
v
):
vectorize_result
[
tensor
]
=
v
break
return
vectorize_result
...
...
tilelang/carver/roller/policy/tensorcore.py
View file @
29051439
"""Policy for tensorcore schedule"""
from
__future__
import
annotations
import
tvm
import
numpy
as
np
...
...
@@ -13,7 +14,6 @@ logger = logging.getLogger(__name__)
class
TensorCorePolicy
(
DefaultPolicy
):
# this is the trick for wmma.
# However, for int8 mma, the wmma_k should be 32.
wmma_k
:
int
=
16
...
...
@@ -70,9 +70,9 @@ class TensorCorePolicy(DefaultPolicy):
A_high_ax
=
min
(
A_ax_m
,
A_ax_k
)
B_high_ax
=
min
(
B_ax_n
,
B_ax_k
)
C_high_ax
=
min
(
C_ax_m
,
C_ax_n
)
A_stride
=
Stride
(
stride
=
np
.
prod
(
AS_shape
[
A_high_ax
+
1
:])
+
offset
,
ax
=
A_high_ax
)
B_stride
=
Stride
(
stride
=
np
.
prod
(
BS_shape
[
B_high_ax
+
1
:])
+
offset
,
ax
=
B_high_ax
)
C_stride
=
Stride
(
stride
=
np
.
prod
(
CS_shape
[
C_high_ax
+
1
:])
+
offset
,
ax
=
C_high_ax
)
A_stride
=
Stride
(
stride
=
np
.
prod
(
AS_shape
[
A_high_ax
+
1
:])
+
offset
,
ax
=
A_high_ax
)
B_stride
=
Stride
(
stride
=
np
.
prod
(
BS_shape
[
B_high_ax
+
1
:])
+
offset
,
ax
=
B_high_ax
)
C_stride
=
Stride
(
stride
=
np
.
prod
(
CS_shape
[
C_high_ax
+
1
:])
+
offset
,
ax
=
C_high_ax
)
return
A_stride
,
B_stride
,
C_stride
def
infer_node_smem_usage
(
self
,
td
:
TileDict
,
node
:
PrimFuncNode
):
...
...
@@ -86,8 +86,7 @@ class TensorCorePolicy(DefaultPolicy):
# get reduce input size
target_transaction
=
self
.
arch
.
transaction_size
[
0
]
*
2
# 512 bytes // type bits
reduce_input_dtype
=
node
.
get_buffer_dtype
(
node
.
block_analyzer
.
get_input_buffers
(
node
.
reduction_block
)[
0
])
reduce_input_dtype
=
node
.
get_buffer_dtype
(
node
.
block_analyzer
.
get_input_buffers
(
node
.
reduction_block
)[
0
])
basic
=
(
target_transaction
*
8
)
//
reduce_input_dtype
.
bits
result
=
{}
...
...
@@ -95,7 +94,7 @@ class TensorCorePolicy(DefaultPolicy):
iter_name
=
iter_info
.
var
.
name
iter_dom
=
iter_info
.
dom
.
extent
if
iter_dom
%
16
>
0
:
result
[
iter_name
]
=
(
16
if
iter_dom
<
basic
else
basic
)
# for the case of padding
result
[
iter_name
]
=
16
if
iter_dom
<
basic
else
basic
# for the case of padding
elif
iter_dom
%
basic
==
0
:
result
[
iter_name
]
=
basic
else
:
...
...
@@ -114,7 +113,6 @@ class TensorCorePolicy(DefaultPolicy):
return
False
if
_check_small_tile
(
td
):
smem_limit
=
min
(
self
.
arch
.
max_smem_usage
//
td
.
block_per_SM
,
self
.
arch
.
smem_cap
)
rstep_map
=
td
.
rstep_map
.
copy
()
...
...
@@ -127,13 +125,10 @@ class TensorCorePolicy(DefaultPolicy):
return
rstep
def
_shared_memory_usage
(
td
:
TileDict
):
return
node
.
footprint
(
td
.
output_tile
,
new_rstep_map
,
td
.
tensor_strides_map
[
node
])
return
node
.
footprint
(
td
.
output_tile
,
new_rstep_map
,
td
.
tensor_strides_map
[
node
])
def
_score
(
rstep_id
):
rstep
=
{
k
.
var
.
name
:
all_steps
[
k
.
var
.
name
][
rstep_id
[
k
.
var
.
name
]]
for
k
in
node
.
raxis
}
rstep
=
{
k
.
var
.
name
:
all_steps
[
k
.
var
.
name
][
rstep_id
[
k
.
var
.
name
]]
for
k
in
node
.
raxis
}
score
=
0
shape
=
node
.
propagate_inputs_on_reduction
(
td
.
get_tile
(
node
),
rstep
=
rstep
)
input_buffers
=
node
.
block_analyzer
.
get_input_buffers
(
node
.
reduction_block
)
...
...
@@ -153,18 +148,13 @@ class TensorCorePolicy(DefaultPolicy):
return
None
return
max
(
candidates
,
key
=
lambda
x
:
x
[
1
])[
0
]
cur_rstep_id
=
{
k
.
var
.
name
:
all_steps
[
k
.
var
.
name
].
index
(
rstep
[
k
.
var
.
name
])
for
k
in
node
.
raxis
}
cur_rstep_id
=
{
k
.
var
.
name
:
all_steps
[
k
.
var
.
name
].
index
(
rstep
[
k
.
var
.
name
])
for
k
in
node
.
raxis
}
new_rstep_map
=
rstep_map
.
copy
()
while
True
:
new_rstep_id
=
_enlarge
(
cur_rstep_id
)
if
new_rstep_id
is
None
:
break
new_rstep_map
=
{
k
.
var
.
name
:
all_steps
[
k
.
var
.
name
][
new_rstep_id
[
k
.
var
.
name
]]
for
k
in
node
.
raxis
}
new_rstep_map
=
{
k
.
var
.
name
:
all_steps
[
k
.
var
.
name
][
new_rstep_id
[
k
.
var
.
name
]]
for
k
in
node
.
raxis
}
old_rstep_map
=
td
.
rstep_map
td
.
rstep_map
=
new_rstep_map
smem_usage
,
_
=
_shared_memory_usage
(
td
)
...
...
@@ -173,9 +163,7 @@ class TensorCorePolicy(DefaultPolicy):
break
else
:
cur_rstep_id
=
new_rstep_id
rstep
=
{
k
.
var
.
name
:
all_steps
[
k
.
var
.
name
][
cur_rstep_id
[
k
.
var
.
name
]]
for
k
in
node
.
raxis
}
rstep
=
{
k
.
var
.
name
:
all_steps
[
k
.
var
.
name
][
cur_rstep_id
[
k
.
var
.
name
]]
for
k
in
node
.
raxis
}
return
rstep
for
node
in
self
.
ordered_nodes
:
...
...
@@ -206,11 +194,7 @@ class TensorCorePolicy(DefaultPolicy):
return
super
().
get_node_reduce_step_candidates
(
node
)
else
:
# must be a a multiple of wmma_k
return
{
k
.
var
.
name
:
[
x
*
self
.
wmma_k
for
x
in
get_all_factors
(
int
(
k
.
dom
.
extent
)
//
self
.
wmma_k
)
]
for
k
in
node
.
raxis
}
return
{
k
.
var
.
name
:
[
x
*
self
.
wmma_k
for
x
in
get_all_factors
(
int
(
k
.
dom
.
extent
)
//
self
.
wmma_k
)]
for
k
in
node
.
raxis
}
def
check_tile_shape_isvalid
(
self
,
td
:
TileDict
):
for
node
in
self
.
ordered_nodes
:
...
...
@@ -221,10 +205,7 @@ class TensorCorePolicy(DefaultPolicy):
td
.
tile_map
[
node
][
ax_n
],
)
# check the tile size is valid
wmma_invalid
=
[
block_m
<
wmma_m
or
block_n
<
wmma_n
for
wmma_m
,
wmma_n
in
self
.
arch
.
get_avaliable_tensorintrin_shapes
()
]
wmma_invalid
=
[
block_m
<
wmma_m
or
block_n
<
wmma_n
for
wmma_m
,
wmma_n
in
self
.
arch
.
get_avaliable_tensorintrin_shapes
()]
if
all
(
wmma_invalid
):
return
False
if
any
([
y
%
x
for
x
,
y
in
zip
(
td
.
tile_map
[
node
],
node
.
get_space_dim
())]):
...
...
@@ -242,13 +223,10 @@ class TensorCorePolicy(DefaultPolicy):
return
super
().
compute_node_stride_map
(
node
,
td
)
use_layout
=
self
.
_can_implement_layout
(
node
,
td
)
AS_stride
,
BS_stride
,
C_stride
=
self
.
_compute_tc_strides
(
node
,
td
.
get_tile
(
node
),
td
.
get_rstep
(
node
))
AS_stride
,
BS_stride
,
C_stride
=
self
.
_compute_tc_strides
(
node
,
td
.
get_tile
(
node
),
td
.
get_rstep
(
node
))
A_stride
,
B_stride
,
_
=
self
.
_compute_tc_strides
(
node
,
td
.
get_tile
(
node
))
tensor_strides
=
{}
output_strides
=
{
int
(
i
+
len
(
node
.
input_buffers
)):
Stride
()
for
i
,
_
in
enumerate
(
node
.
output_buffers
)
}
output_strides
=
{
int
(
i
+
len
(
node
.
input_buffers
)):
Stride
()
for
i
,
_
in
enumerate
(
node
.
output_buffers
)}
tensor_strides
=
{}
# when connected to shared input, should use full stride without rstep
for
i
,
(
_
,
_
)
in
enumerate
(
zip
([
AS_stride
,
BS_stride
],
[
A_stride
,
B_stride
])):
...
...
@@ -347,8 +325,7 @@ class TensorCorePolicy(DefaultPolicy):
overall_gmem_size_in_bytes
:
int
=
0
for
node
in
self
.
ordered_nodes
:
for
buffer
in
node
.
input_buffers
:
overall_gmem_size_in_bytes
+=
(
int
(
np
.
prod
(
buffer
.
shape
))
*
tvm
.
DataType
(
buffer
.
dtype
).
bits
//
8
)
overall_gmem_size_in_bytes
+=
int
(
np
.
prod
(
buffer
.
shape
))
*
tvm
.
DataType
(
buffer
.
dtype
).
bits
//
8
return
overall_gmem_size_in_bytes
<
self
.
arch
.
l2_cache_size_bytes
conditions
.
append
(
_check_memory_size
())
...
...
tilelang/carver/roller/rasterization.py
View file @
29051439
...
...
@@ -2,7 +2,6 @@
class
Rasterization
:
panel_width_
=
None
def
__init__
(
self
)
->
None
:
...
...
@@ -18,7 +17,6 @@ class Rasterization:
class
NoRasterization
(
Rasterization
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
...
...
tilelang/carver/roller/shape_inference/common.py
View file @
29051439
...
...
@@ -4,9 +4,7 @@ from tvm import arith
class
Statement
:
def
__init__
(
self
,
output
:
str
,
dependent_region
:
dict
,
var_map
:
OrderedDict
,
range_map
:
OrderedDict
):
def
__init__
(
self
,
output
:
str
,
dependent_region
:
dict
,
var_map
:
OrderedDict
,
range_map
:
OrderedDict
):
self
.
output
=
output
self
.
dependent_region
=
dependent_region
self
.
var_map
=
var_map
...
...
@@ -18,7 +16,6 @@ def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound):
class
InputShapeInference
:
def
__init__
(
self
,
deps
:
list
[
Statement
]):
self
.
deps
=
deps
...
...
tilelang/carver/roller/shape_inference/tir.py
View file @
29051439
...
...
@@ -5,7 +5,6 @@ from tvm import arith, tir
class
Statement
:
def
__init__
(
self
,
block_analyzer
,
block
:
BlockRV
):
self
.
block_analyzer
=
block_analyzer
self
.
block
=
block
...
...
@@ -21,9 +20,7 @@ class Statement:
if
len
(
self
.
dependent_region
[
input_name
])
!=
1
:
return
None
indices
=
self
.
dependent_region
[
input_name
][
0
]
iter_map_range
=
{
_iter
.
var
:
_iter
.
dom
for
_iter
in
self
.
block_analyzer
.
get_spatial_axis
(
self
.
block
)
}
iter_map_range
=
{
_iter
.
var
:
_iter
.
dom
for
_iter
in
self
.
block_analyzer
.
get_spatial_axis
(
self
.
block
)}
iter_map_result
=
arith
.
detect_iter_map
(
indices
,
iter_map_range
,
...
...
@@ -77,7 +74,6 @@ class TensorDepNode:
class
DependencyAnalysis
:
def
__init__
(
self
,
deps
):
self
.
deps
=
deps
# issue: duplicate name when we have two same ops.
...
...
@@ -112,8 +108,7 @@ class DependencyAnalysis:
def
traverse_dependencies
(
self
,
compute
):
if
isinstance
(
compute
,
Statement
):
node
=
self
.
get_or_create_node
(
compute
.
block_analyzer
.
get_output_buffers
(
compute
.
block
)[
0
].
name
)
node
=
self
.
get_or_create_node
(
compute
.
block_analyzer
.
get_output_buffers
(
compute
.
block
)[
0
].
name
)
# Loop through input tensors
for
input_buffer
in
compute
.
block_analyzer
.
get_input_buffers
(
compute
.
block
):
# Get the input node
...
...
@@ -167,7 +162,6 @@ class DependencyAnalysis:
class
InputShapeInference
:
def
__init__
(
self
,
deps
:
list
[
Statement
]):
self
.
deps
=
deps
self
.
target_mapping
=
{}
...
...
@@ -183,16 +177,11 @@ class InputShapeInference:
if
targets
in
self
.
target_mapping
:
return
self
.
target_mapping
[
targets
]
# should be buffer name instead of block name
name2dep
=
{
dep
.
block_analyzer
.
get_output_buffers
(
dep
.
block
)[
0
].
name
:
dep
for
dep
in
self
.
deps
}
name2dep
=
{
dep
.
block_analyzer
.
get_output_buffers
(
dep
.
block
)[
0
].
name
:
dep
for
dep
in
self
.
deps
}
mapping
=
{}
input_vars
=
[]
for
target
in
targets
:
vars
=
[
iter
.
var
for
iter
in
name2dep
[
target
].
block_analyzer
.
get_spatial_axis
(
name2dep
[
target
].
block
)
]
vars
=
[
iter
.
var
for
iter
in
name2dep
[
target
].
block_analyzer
.
get_spatial_axis
(
name2dep
[
target
].
block
)]
input_vars
.
append
(
vars
)
mapping
[
target
]
=
[
vars
]
ana
=
arith
.
Analyzer
()
...
...
@@ -221,13 +210,8 @@ class InputShapeInference:
mapping
[
input_name
]
=
[]
for
indices
in
indices_list
:
for
region
in
regions
:
vmap
=
{
k
:
(
tir
.
Cast
(
k
.
dtype
,
v
)
if
v
.
dtype
!=
k
.
dtype
else
v
)
for
k
,
v
in
zip
(
ax_vars
,
indices
)
}
region
=
[
ana
.
simplify
(
tir
.
stmt_functor
.
substitute
(
ax
,
vmap
))
for
ax
in
region
]
vmap
=
{
k
:
(
tir
.
Cast
(
k
.
dtype
,
v
)
if
v
.
dtype
!=
k
.
dtype
else
v
)
for
k
,
v
in
zip
(
ax_vars
,
indices
)}
region
=
[
ana
.
simplify
(
tir
.
stmt_functor
.
substitute
(
ax
,
vmap
))
for
ax
in
region
]
if
not
region_exist_in_list
(
region
,
mapping
[
input_name
]):
mapping
[
input_name
].
append
(
region
)
buffers
=
[]
...
...
@@ -241,10 +225,7 @@ class InputShapeInference:
self
.
target_mapping
[
targets
]
=
input_vars
,
mapping
return
input_vars
,
mapping
def
infer
(
self
,
shape
:
dict
[
str
,
list
[
arith
.
ConstIntBound
]],
rstep
:
dict
[
str
,
int
]
=
None
,
targets
=
None
):
def
infer
(
self
,
shape
:
dict
[
str
,
list
[
arith
.
ConstIntBound
]],
rstep
:
dict
[
str
,
int
]
=
None
,
targets
=
None
):
if
rstep
is
None
:
rstep
=
{}
compute_targets
=
tuple
(
shape
.
keys
())
...
...
@@ -258,8 +239,7 @@ class InputShapeInference:
for
ax
in
self
.
reduce_axes
:
# assume the dom.min is always 0, maybe we can extend the IterInfo to include the min value.
if
ax
.
var
.
name
in
rstep
:
bound
=
arith
.
ConstIntBound
(
int
(
ax
.
dom
.
min
),
int
(
ax
.
dom
.
min
+
min
(
ax
.
dom
.
extent
,
rstep
[
ax
.
var
.
name
])
-
1
))
bound
=
arith
.
ConstIntBound
(
int
(
ax
.
dom
.
min
),
int
(
ax
.
dom
.
min
+
min
(
ax
.
dom
.
extent
,
rstep
[
ax
.
var
.
name
])
-
1
))
else
:
bound
=
arith
.
ConstIntBound
(
int
(
ax
.
dom
.
min
),
int
(
ax
.
dom
.
min
+
ax
.
dom
.
extent
-
1
))
ana
.
update
(
ax
.
var
,
bound
,
True
)
...
...
@@ -312,14 +292,11 @@ class InputShapeInference:
for
name
,
regions
in
mapping
.
items
():
region
=
regions
[
0
]
result
[
name
]
=
[
ana
.
simplify
(
tir
.
stmt_functor
.
substitute
(
index
,
vmap
))
for
index
in
region
]
result
[
name
]
=
[
ana
.
simplify
(
tir
.
stmt_functor
.
substitute
(
index
,
vmap
))
for
index
in
region
]
return
result
def
region_exist_in_list
(
a
,
list
)
->
bool
:
def
expr_is_same
(
a
,
b
)
->
bool
:
if
isinstance
(
a
,
tir
.
IntImm
)
and
isinstance
(
b
,
tir
.
IntImm
):
return
a
.
value
==
b
.
value
...
...
tilelang/carver/template/base.py
View file @
29051439
...
...
@@ -2,7 +2,12 @@
from
abc
import
ABC
,
abstractmethod
# For defining abstract base classes
from
dataclasses
import
dataclass
,
field
# For defining data classes
from
..arch
import
(
# Import architecture-related utilities and classes
TileDevice
,
is_volta_arch
,
is_ampere_arch
,
is_cdna_arch
,
auto_infer_current_arch
)
TileDevice
,
is_volta_arch
,
is_ampere_arch
,
is_cdna_arch
,
auto_infer_current_arch
,
)
from
..roller.hint
import
Hint
# Import the Hint class
from
..roller.node
import
OutputNode
# Import the OutputNode class
from
tvm.tir
import
PrimFunc
# Import PrimFunc for handling tensor IR functions
...
...
@@ -41,7 +46,7 @@ class BaseTemplate(ABC):
"""
pass
def
with_arch
(
self
,
arch
:
TileDevice
)
->
'
BaseTemplate
'
:
def
with_arch
(
self
,
arch
:
TileDevice
)
->
"
BaseTemplate
"
:
"""
Sets the architecture for this template and returns itself.
...
...
@@ -109,7 +114,7 @@ class BaseTemplate(ABC):
"""
raise
NotImplementedError
(
"initialize_function is not implemented"
)
def
set_function
(
self
,
func
:
PrimFunc
)
->
'
BaseTemplate
'
:
def
set_function
(
self
,
func
:
PrimFunc
)
->
"
BaseTemplate
"
:
"""
Sets the function for this template and returns itself.
...
...
@@ -122,7 +127,7 @@ class BaseTemplate(ABC):
self
.
_func
=
func
return
self
def
set_output_nodes
(
self
,
output_nodes
:
list
[
OutputNode
])
->
'
BaseTemplate
'
:
def
set_output_nodes
(
self
,
output_nodes
:
list
[
OutputNode
])
->
"
BaseTemplate
"
:
"""
Sets the output nodes for this template and returns itself.
...
...
tilelang/carver/template/conv.py
View file @
29051439
...
...
@@ -28,6 +28,7 @@ class ConvTemplate(BaseTemplate):
accum_dtype (str): Data type used for accumulation.
with_bias (bool): Whether to add a bias term.
"""
# Operation-related configuration parameters
N
:
int
# The number of input samples processed simultaneously in a batch.
C
:
int
# The number of input feature maps.
...
...
@@ -69,12 +70,18 @@ class ConvTemplate(BaseTemplate):
AssertionError: If N, C, H, W, F, K, S, D, P are not positive integers.
"""
N
,
C
,
H
,
W
,
F
,
K
,
S
,
D
,
P
=
self
.
N
,
self
.
C
,
self
.
H
,
self
.
W
,
self
.
F
,
self
.
K
,
self
.
S
,
self
.
D
,
self
.
P
assert
(
isinstance
(
N
,
int
)
and
isinstance
(
C
,
int
)
and
isinstance
(
H
,
int
)
and
isinstance
(
W
,
int
)
and
isinstance
(
F
,
int
)
and
isinstance
(
K
,
int
)
and
isinstance
(
S
,
int
)
and
isinstance
(
D
,
int
)
and
isinstance
(
P
,
int
)),
"Only Support Integer Params"
assert
(
N
>
0
and
C
>
0
and
H
>
0
and
W
>
0
and
F
>
0
and
K
>
0
and
S
>
0
and
D
>
0
and
P
>
0
),
"Params should be positive"
assert
(
isinstance
(
N
,
int
)
and
isinstance
(
C
,
int
)
and
isinstance
(
H
,
int
)
and
isinstance
(
W
,
int
)
and
isinstance
(
F
,
int
)
and
isinstance
(
K
,
int
)
and
isinstance
(
S
,
int
)
and
isinstance
(
D
,
int
)
and
isinstance
(
P
,
int
)
),
"Only Support Integer Params"
assert
N
>
0
and
C
>
0
and
H
>
0
and
W
>
0
and
F
>
0
and
K
>
0
and
S
>
0
and
D
>
0
and
P
>
0
,
"Params should be positive"
# Load configuration parameters
in_dtype
,
out_dtype
,
accum_dtype
=
self
.
in_dtype
,
self
.
out_dtype
,
self
.
accum_dtype
...
...
@@ -123,8 +130,10 @@ class ConvTemplate(BaseTemplate):
te
.
if_then_else
(
te
.
all
(
h_in
>=
0
,
h_in
<
H
,
w_in
>=
0
,
w_in
<
W
),
A
[
n
,
h_in
,
w_in
,
c
].
astype
(
accum_dtype
)
*
B
[
kh
,
kw
,
c
,
f
].
astype
(
accum_dtype
),
tir
.
const
(
0
,
accum_dtype
)),
axis
=
[
kh
,
kw
,
c
])
tir
.
const
(
0
,
accum_dtype
),
),
axis
=
[
kh
,
kw
,
c
],
)
# Compute convolution result
C
=
te
.
compute
(
...
...
tilelang/carver/template/flashattention.py
View file @
29051439
...
...
@@ -9,7 +9,6 @@ from ..utils import get_roller_hints_from_output_nodes, get_tensorized_func_and_
@
dataclass
class
FlashAttentionTemplate
(
BaseTemplate
):
_output_nodes
:
list
[
OutputNode
]
=
None
# Operation-related configuration parameters
...
...
@@ -91,10 +90,7 @@ class FlashAttentionTemplate(BaseTemplate):
"""
A_indices
=
[
b
,
i
,
k
]
B_indices
=
[
b
,
j
,
k
]
return
te
.
sum
(
A
[
tuple
(
A_indices
)].
astype
(
accum_dtype
)
*
B
[
tuple
(
B_indices
)].
astype
(
accum_dtype
),
axis
=
k
)
return
te
.
sum
(
A
[
tuple
(
A_indices
)].
astype
(
accum_dtype
)
*
B
[
tuple
(
B_indices
)].
astype
(
accum_dtype
),
axis
=
k
)
# Compute matrix multiplication result
C
=
te
.
compute
(
...
...
tilelang/carver/template/gemv.py
View file @
29051439
...
...
@@ -50,9 +50,8 @@ class GEMVTemplate(BaseTemplate):
N
,
K
=
self
.
N
,
self
.
K
# Ensure M, N, K are valid positive integers
assert
(
isinstance
(
M
,
int
)
and
isinstance
(
N
,
int
)
and
isinstance
(
K
,
int
)),
"Only Support Integer M, N, K"
assert
(
M
>
0
and
N
>
0
and
K
>
0
),
"M, N, K should be positive"
assert
isinstance
(
M
,
int
)
and
isinstance
(
N
,
int
)
and
isinstance
(
K
,
int
),
"Only Support Integer M, N, K"
assert
M
>
0
and
N
>
0
and
K
>
0
,
"M, N, K should be positive"
# Load configuration parameters
trans_B
=
self
.
trans_B
...
...
@@ -86,9 +85,7 @@ class GEMVTemplate(BaseTemplate):
"""
A_indices
=
[
i
,
k
]
B_indices
=
[
k
,
j
]
if
not
trans_B
else
[
j
,
k
]
return
te
.
sum
(
A
[
tuple
(
A_indices
)].
astype
(
accum_dtype
)
*
B
[
tuple
(
B_indices
)].
astype
(
accum_dtype
),
axis
=
k
)
return
te
.
sum
(
A
[
tuple
(
A_indices
)].
astype
(
accum_dtype
)
*
B
[
tuple
(
B_indices
)].
astype
(
accum_dtype
),
axis
=
k
)
# Compute matrix multiplication result
C
=
te
.
compute
(
...
...
tilelang/carver/template/general_reduce.py
View file @
29051439
...
...
@@ -9,15 +9,13 @@ from ..utils import get_roller_hints_from_func
@
dataclass
class
GeneralReductionTemplate
(
BaseTemplate
):
# OP Related Config
structure
:
str
|
list
[
str
]
=
None
shape
:
list
[
int
]
=
None
dtype
:
str
=
"float16"
def
get_hardware_aware_configs
(
self
,
arch
:
TileDevice
=
None
,
topk
:
int
=
10
)
->
list
[
Hint
]:
roller_hints
=
get_roller_hints_from_func
(
self
.
_func
,
arch
=
arch
,
topk
=
topk
,
allow_gemv
=
False
)
roller_hints
=
get_roller_hints_from_func
(
self
.
_func
,
arch
=
arch
,
topk
=
topk
,
allow_gemv
=
False
)
return
roller_hints
def
initialize_function
(
self
)
->
None
:
...
...
@@ -38,9 +36,9 @@ class GeneralReductionTemplate(BaseTemplate):
spatial_axes
=
[]
reduce_axes
=
[]
for
i
,
axis_type
in
enumerate
(
self
.
structure
):
if
axis_type
.
upper
()
==
'S'
:
if
axis_type
.
upper
()
==
"S"
:
spatial_axes
.
append
((
i
,
self
.
shape
[
i
]))
elif
axis_type
.
upper
()
==
'R'
:
elif
axis_type
.
upper
()
==
"R"
:
reduce_axes
.
append
((
i
,
self
.
shape
[
i
]))
else
:
raise
ValueError
(
f
"Unrecognized axis type '
{
axis_type
}
', only 'S'/'R' allowed."
)
...
...
@@ -90,7 +88,7 @@ class GeneralReductionTemplate(BaseTemplate):
# Walk through the structure in order
for
axis_type
in
self
.
structure
:
if
axis_type
.
upper
()
==
'S'
:
if
axis_type
.
upper
()
==
"S"
:
# use the next spatial_indices item
full_index
.
append
(
spatial_indices
[
spatial_iter
])
spatial_iter
+=
1
...
...
tilelang/carver/template/matmul.py
View file @
29051439
...
...
@@ -65,9 +65,8 @@ class MatmulTemplate(BaseTemplate):
M
,
N
,
K
=
self
.
M
,
self
.
N
,
self
.
K
# Ensure M, N, K are valid positive integers
assert
(
isinstance
(
M
,
int
)
and
isinstance
(
N
,
int
)
and
isinstance
(
K
,
int
)),
"Only Support Integer M, N, K"
assert
(
M
>
0
and
N
>
0
and
K
>
0
),
"M, N, K should be positive"
assert
isinstance
(
M
,
int
)
and
isinstance
(
N
,
int
)
and
isinstance
(
K
,
int
),
"Only Support Integer M, N, K"
assert
M
>
0
and
N
>
0
and
K
>
0
,
"M, N, K should be positive"
# Load configuration parameters
trans_A
,
trans_B
=
self
.
trans_A
,
self
.
trans_B
...
...
@@ -101,9 +100,7 @@ class MatmulTemplate(BaseTemplate):
"""
A_indices
=
[
i
,
k
]
if
not
trans_A
else
[
k
,
i
]
# Adjust indexing if A is transposed
B_indices
=
[
k
,
j
]
if
not
trans_B
else
[
j
,
k
]
# Adjust indexing if B is transposed
return
te
.
sum
(
A
[
tuple
(
A_indices
)].
astype
(
accum_dtype
)
*
B
[
tuple
(
B_indices
)].
astype
(
accum_dtype
),
axis
=
k
)
return
te
.
sum
(
A
[
tuple
(
A_indices
)].
astype
(
accum_dtype
)
*
B
[
tuple
(
B_indices
)].
astype
(
accum_dtype
),
axis
=
k
)
# Compute matrix multiplication result
C
=
te
.
compute
(
...
...
tilelang/carver/utils.py
View file @
29051439
...
...
@@ -26,11 +26,9 @@ def get_rasterization_code(pannel_width: int = 8) -> str:
"""
def
get_roller_hints_from_func
(
func_or_module
:
tir
.
PrimFunc
|
IRModule
,
arch
:
TileDevice
,
topk
:
int
=
10
,
tensorcore_only
:
bool
=
False
,
allow_gemv
:
bool
=
False
)
->
list
[
Hint
]
|
None
:
def
get_roller_hints_from_func
(
func_or_module
:
tir
.
PrimFunc
|
IRModule
,
arch
:
TileDevice
,
topk
:
int
=
10
,
tensorcore_only
:
bool
=
False
,
allow_gemv
:
bool
=
False
)
->
list
[
Hint
]
|
None
:
func
=
None
if
isinstance
(
func_or_module
,
tir
.
PrimFunc
):
func
=
func_or_module
...
...
@@ -44,8 +42,7 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
roller_hints
=
None
if
tensorcore_only
:
try
:
tensorized_func
,
tags
=
get_tensorized_func_and_tags
(
func
,
arch
.
target
,
allow_gemv
=
allow_gemv
)
tensorized_func
,
tags
=
get_tensorized_func_and_tags
(
func
,
arch
.
target
,
allow_gemv
=
allow_gemv
)
except
Exception
as
e_msg
:
logger
.
debug
(
"Get tensorized func and tags failed: "
,
e_msg
)
tags
=
None
...
...
@@ -58,8 +55,7 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
policy
=
DefaultPolicy
.
from_prim_func
(
func
=
func
,
arch
=
arch
)
tensorized_func
=
None
try
:
tensorized_func
,
tags
=
get_tensorized_func_and_tags
(
func
,
arch
.
target
,
allow_gemv
=
allow_gemv
)
tensorized_func
,
tags
=
get_tensorized_func_and_tags
(
func
,
arch
.
target
,
allow_gemv
=
allow_gemv
)
except
Exception
as
e_msg
:
logger
.
debug
(
"Get tensorized func and tags failed: "
,
e_msg
)
tags
=
None
...
...
@@ -69,10 +65,9 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
return
roller_hints
def
get_roller_hints_from_output_nodes
(
output_nodes
:
list
[
OutputNode
],
arch
:
TileDevice
,
topk
:
int
=
10
,
extra_tags
:
list
[
str
]
|
None
=
None
)
->
list
[
Hint
]
|
None
:
def
get_roller_hints_from_output_nodes
(
output_nodes
:
list
[
OutputNode
],
arch
:
TileDevice
,
topk
:
int
=
10
,
extra_tags
:
list
[
str
]
|
None
=
None
)
->
list
[
Hint
]
|
None
:
assert
isinstance
(
output_nodes
,
list
),
"The input should be a list of functions."
lints
=
[]
...
...
@@ -80,8 +75,7 @@ def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode],
policy
=
TensorCorePolicy
.
from_output_nodes
(
output_nodes
,
arch
=
arch
,
tags
=
None
)
lints
=
policy
.
emit_config
(
topk
)
except
Exception
as
e_msg
:
logger
.
debug
(
f
"Generate hints from output nodes failed:
{
e_msg
}
"
,
"fallback to default policy"
)
logger
.
debug
(
f
"Generate hints from output nodes failed:
{
e_msg
}
"
,
"fallback to default policy"
)
if
len
(
lints
)
==
0
:
policy
=
DefaultPolicy
.
from_output_nodes
(
output_nodes
,
arch
=
arch
,
tags
=
None
)
...
...
@@ -92,7 +86,6 @@ def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode],
def
retrieve_func_from_module
(
ir_module
:
IRModule
)
->
PrimFunc
:
if
not
isinstance
(
ir_module
,
IRModule
):
raise
ValueError
(
"Not supported type: "
,
type
(
ir_module
))
assert
len
(
ir_module
.
get_global_vars
())
==
1
,
(
"The optimized module should only have one global variable for default schedule."
)
assert
len
(
ir_module
.
get_global_vars
())
==
1
,
"The optimized module should only have one global variable for default schedule."
func
=
list
(
ir_module
.
functions
.
values
())[
0
]
return
func
tilelang/contrib/cc.py
View file @
29051439
...
...
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Util to invoke C/C++ compilers in the system."""
import
functools
import
os
import
shutil
...
...
@@ -30,8 +31,7 @@ from tvm.contrib import utils as _utils
def
_is_linux_like
():
return
(
sys
.
platform
==
"darwin"
or
sys
.
platform
.
startswith
(
"linux"
)
or
sys
.
platform
.
startswith
(
"freebsd"
))
return
sys
.
platform
==
"darwin"
or
sys
.
platform
.
startswith
(
"linux"
)
or
sys
.
platform
.
startswith
(
"freebsd"
)
def
_is_windows_like
():
...
...
@@ -90,7 +90,7 @@ def get_cplus_compiler():
def
is_darwin
():
return
platform
.
system
()
==
'
Darwin
'
return
platform
.
system
()
==
"
Darwin
"
def
create_shared
(
output
,
objects
,
options
=
None
,
cc
=
None
,
cwd
=
None
,
ccache_env
=
None
):
...
...
@@ -287,11 +287,7 @@ create_shared.output_format = "so" if sys.platform != "win32" else "dll"
create_shared
.
get_target_triple
=
get_target_by_dump_machine
(
os
.
environ
.
get
(
"CXX"
,
get_cc
()))
def
cross_compiler
(
compile_func
,
options
=
None
,
output_format
=
None
,
get_target_triple
=
None
,
add_files
=
None
):
def
cross_compiler
(
compile_func
,
options
=
None
,
output_format
=
None
,
get_target_triple
=
None
,
add_files
=
None
):
"""Create a cross compiler function by specializing compile_func with options.
This function can be used to construct compile functions that
...
...
@@ -363,13 +359,7 @@ def cross_compiler(compile_func,
return
_fcompile
def
_linux_compile
(
output
,
objects
,
options
,
compile_cmd
,
cwd
=
None
,
ccache_env
=
None
,
compile_shared
=
False
):
def
_linux_compile
(
output
,
objects
,
options
,
compile_cmd
,
cwd
=
None
,
ccache_env
=
None
,
compile_shared
=
False
):
cmd
=
[
compile_cmd
]
if
compile_cmd
!=
"nvcc"
:
if
compile_shared
or
output
.
endswith
(
".so"
)
or
output
.
endswith
(
".dylib"
):
...
...
@@ -430,15 +420,15 @@ def _windows_compile(output, objects, options, cwd=None, ccache_env=None):
raise
ValueError
(
"ccache not found"
)
try
:
proc
=
subprocess
.
Popen
(
cmd
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
STDOUT
,
cwd
=
cwd
,
env
=
env
)
proc
=
subprocess
.
Popen
(
cmd
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
STDOUT
,
cwd
=
cwd
,
env
=
env
)
(
out
,
_
)
=
proc
.
communicate
()
except
FileNotFoundError
:
raise
RuntimeError
(
"Can not find the LLVM clang for Windows clang.exe)."
"Make sure it's installed"
" and the installation directory is in the %PATH% environment "
"variable. Prebuilt binaries can be found at: https://llvm.org/"
)
\
from
None
raise
RuntimeError
(
"Can not find the LLVM clang for Windows clang.exe)."
"Make sure it's installed"
" and the installation directory is in the %PATH% environment "
"variable. Prebuilt binaries can be found at: https://llvm.org/"
)
from
None
if
proc
.
returncode
!=
0
:
msg
=
"Compilation error:
\n
"
msg
+=
" "
.
join
(
cmd
)
+
"
\n
"
...
...
tilelang/contrib/dlpack.py
View file @
29051439
...
...
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Wrapping functions to bridge frameworks with DLPack support to TVM"""
from
tvm
import
runtime
...
...
@@ -45,12 +46,8 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func):
def
adapt_tensor
(
arg
):
if
isinstance
(
arg
,
tensor_type
):
if
arg
.
dtype
in
{
torch
.
float8_e4m3fn
,
torch
.
float8_e4m3fnuz
,
torch
.
float8_e5m2
,
torch
.
float8_e5m2fnuz
}:
return
runtime
.
from_dlpack
(
to_dlpack_func
(
arg
.
view
(
torch
.
int8
))).
_create_view
(
arg
.
shape
,
dtype
=
float8_dtype_map
[
arg
.
dtype
])
if
arg
.
dtype
in
{
torch
.
float8_e4m3fn
,
torch
.
float8_e4m3fnuz
,
torch
.
float8_e5m2
,
torch
.
float8_e5m2fnuz
}:
return
runtime
.
from_dlpack
(
to_dlpack_func
(
arg
.
view
(
torch
.
int8
))).
_create_view
(
arg
.
shape
,
dtype
=
float8_dtype_map
[
arg
.
dtype
])
return
runtime
.
from_dlpack
(
to_dlpack_func
(
arg
))
return
arg
...
...
Prev
1
…
14
15
16
17
18
19
20
21
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment