OpenDAS / tilelang · Commits

Commit 29051439 (Unverified)
Authored Dec 12, 2025 by Lei Wang; committed by GitHub, Dec 12, 2025

[Lint] Phaseout Yapf format and embrace ruff format (#1417)
Parent: e84b24bc
Changes: 467 · Showing 20 changed files with 176 additions and 316 deletions (+176 −316)
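The hunks on this page are dominated by three mechanical rewrites: yapf-wrapped signatures and calls collapsed or re-wrapped in black style, redundant parentheses around assert/if/return conditions removed, and single-quoted strings normalized to double quotes. As a rough, hypothetical illustration only (the snippet below is not taken from the repository), this is the kind of change ruff format produces:

    # before (yapf-era formatting, hypothetical example)
    def normalize(values,
                  scale=1.0) -> 'list[float]':
        assert (len(values) > 0), "values must not be empty"
        return [v * scale for v in values]

    # after (ruff format / black style)
    def normalize(values, scale=1.0) -> "list[float]":
        assert len(values) > 0, "values must not be empty"
        return [v * scale for v in values]

Applying the formatter repo-wide is typically a single `ruff format .` invocation (assuming ruff is already configured for the project); the pyproject/CI changes that would actually drop yapf are not among the 20 files shown on this page.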
Files changed on this page:

  tilelang/carver/arch/metal.py                      +2   −3
  tilelang/carver/common_schedules.py                +1   −0
  tilelang/carver/matmul_analysis.py                 +29  −50
  tilelang/carver/roller/bestfit.py                  +2   −6
  tilelang/carver/roller/hint.py                     +3   −2
  tilelang/carver/roller/node.py                     +23  −42
  tilelang/carver/roller/policy/default.py           +28  −56
  tilelang/carver/roller/policy/tensorcore.py        +16  −39
  tilelang/carver/roller/rasterization.py            +0   −2
  tilelang/carver/roller/shape_inference/common.py   +1   −4
  tilelang/carver/roller/shape_inference/tir.py      +9   −32
  tilelang/carver/template/base.py                   +9   −4
  tilelang/carver/template/conv.py                   +17  −8
  tilelang/carver/template/flashattention.py         +1   −5
  tilelang/carver/template/gemv.py                   +3   −6
  tilelang/carver/template/general_reduce.py         +4   −6
  tilelang/carver/template/matmul.py                 +3   −6
  tilelang/carver/utils.py                           +10  −17
  tilelang/contrib/cc.py                             +12  −22
  tilelang/contrib/dlpack.py                         +3   −6
tilelang/carver/arch/metal.py

@@ -8,7 +8,6 @@ def is_metal_arch(arch: TileDevice) -> bool:
 class METAL(TileDevice):
-
     def __init__(self, target: Target | str):
         if isinstance(target, str):
             target = Target(target)

@@ -16,6 +15,6 @@ class METAL(TileDevice):
 __all__ = [
-    'is_metal_arch',
-    'METAL',
+    "is_metal_arch",
+    "METAL",
 ]
tilelang/carver/common_schedules.py

@@ -19,6 +19,7 @@
 # Modifications Copyright (c) Microsoft.
 # The code below is mostly copied from apache/tvm common_schedules.py in dlight.
 """Common schedule strategies for TIR."""
+
 from typing import Callable
 from tvm import tir
tilelang/carver/matmul_analysis.py

 # pylint: disable=missing-docstring, invalid-name
 """A GEMM schedule rule for GPU operators."""
+
 from __future__ import annotations
 from dataclasses import dataclass
 from enum import Enum

@@ -157,8 +158,7 @@ def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> Block
     return block

 def find_arg_idx_from_buffer_chain(sch: tir.Schedule, main_block: tir.schedule.BlockRV, buffer: tir.Buffer) -> int:
     """traverse to find the arg index from the buffer"""
     producers = sch.get_producers(main_block)

@@ -226,9 +226,7 @@ def make_iter_fusion_index_map(
         else:
             fused_iters[trait.kind] = v_i
     final_indices: list[tir.PrimExpr] = [
         fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order
     ]
     return tir.IndexMap(input_iters, final_indices, None)

@@ -307,8 +305,7 @@ def detect_iter_traits(block: tir.Block) -> tuple[list[IterTrait]] | None:
     return A_traits, B_traits, C_traits, block_traits

 def get_index_map(block: tir.Block, layout: list[str] | None = None) -> tuple[tir.IndexMap, ...] | None:
     """Get index maps for the block
     Parameters

@@ -343,10 +340,7 @@ def get_index_map(block: tir.Block,
         return axes

     def is_common_reduce(var: Var) -> bool:
-        for iter_var in block.iter_vars:
-            if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce:
-                return True
-        return False
+        return any(
+            iter_var.var == var and iter_var.iter_type == IterVar.CommReduce for iter_var in block.iter_vars
+        )

     def has_common_reduce(var: Var) -> bool:
         vars = collect_vars_from_expr(var)

@@ -384,17 +378,17 @@ def get_index_map(block: tir.Block,
         if kind == "C":
             return [IterKind.kIter_S, primary_iter, secondary_iter]
         else:
-            return ([IterKind.kIter_S, spatial_iter, reduction_iter]
-                    if check_last_trait(region) else [IterKind.kIter_S, reduction_iter, spatial_iter])
+            return (
+                [IterKind.kIter_S, spatial_iter, reduction_iter]
+                if check_last_trait(region)
+                else [IterKind.kIter_S, reduction_iter, spatial_iter]
+            )
     else:
         raise ValueError(f"Unknown layout {layout}")
     A_index_map = make_iter_fusion_index_map(A_traits, infer_layout(layout[0], block.reads[0].region, kind="A"))
     B_index_map = make_iter_fusion_index_map(B_traits, infer_layout(layout[1], block.reads[1].region, kind="B"))
     C_index_map = make_iter_fusion_index_map(C_traits, infer_layout(layout[2], block.writes[0].region, kind="C"))
     matmul_index_map = make_iter_fusion_index_map(block_traits,

@@ -429,8 +423,7 @@ def get_dequantize_block(sch, blocks) -> BlockRV | None:
         has_uint_input = any("uint" in str(region.buffer.dtype) for region in block_stmt.reads)
         if not has_uint_input:
             return False
         return not (len(block_stmt.writes) != 1 or "float" not in str(block_stmt.writes[0].buffer.dtype))

     dequantize_blocks = [block for block in blocks if is_dequantize(block)]
     return dequantize_blocks[0] if len(dequantize_blocks) == 1 else None

@@ -452,8 +445,7 @@ def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool:
                 return None
             axes.extend(undefined_vars(r.min))
         # remove trivial axis
         trivial_vars = set(iter_var.var for iter_var in block_stmt.iter_vars if _is_one(iter_var.dom.extent))
         axes = [axis for axis in axes if axis not in trivial_vars]
         # remove duplicate axis
         axes = [var for i, var in enumerate(axes) if i == 0 or var != axes[i - 1]]

@@ -462,8 +454,7 @@ def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool:
     lhs_access_vars = get_access_vars(block_stmt.reads[0].region)[-2:]
     rhs_access_vars = get_access_vars(block_stmt.writes[0].region)[-2:]
     is_identity = list(lhs_access_vars) == list(rhs_access_vars)
     is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set(lhs_access_vars) == set(rhs_access_vars)
     return is_identity, is_transpose

@@ -491,9 +482,7 @@ def inline_transpose_block(sch: tir.Schedule, blocks: list[tir.schedule.BlockRV]
     return result_blocks

 def normalize_to_matmul(sch: tir.Schedule, main_block: BlockRV, layout: list[str] | None = None) -> tir.Schedule | None:
     if layout is None:
         layout = ["n", "t", "n"]
     block_stmt = sch.get(main_block)

@@ -543,10 +532,7 @@ def get_tensorized_func_and_tags(
         conditions = []
         conditions.append(len(block_stmt.reads) == 2)
         conditions.append(len(block_stmt.writes) == 1)
         conditions.append(len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region)) > 0)
         return all(conditions)

     # step2. transform function to tensorcore matmul (e.g. conv2d with im2col)

@@ -592,10 +578,7 @@ def get_tensorized_func_and_tags(
         return axes

     def is_common_reduce(var: Var) -> bool:
-        for iter_var in block_stmt.iter_vars:
-            if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce:
-                return True
-        return False
+        return any(
+            iter_var.var == var and iter_var.iter_type == IterVar.CommReduce for iter_var in block_stmt.iter_vars
+        )

     def has_common_reduce(var: Var) -> bool:
         vars = collect_vars_from_expr(var)

@@ -626,7 +609,7 @@ def get_tensorized_func_and_tags(
     # When the func is a dequantize like ops, we should consider the M
     require_block_reduce = False
     # And we only support float16 for now
-    if (hasattr(func.attrs, "dequantize_info") and in_dtype in ["bfloat16", "float16"]):
+    if hasattr(func.attrs, "dequantize_info") and in_dtype in ["bfloat16", "float16"]:
         for arg in func.params:
             inp_shape = func.buffer_map[arg].shape
             M = inp_shape[0]

@@ -645,9 +628,7 @@ def get_tensorized_func_and_tags(
     if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70:
         in_dtype, out_dtype = get_in_out_dtypes(block_stmt)
         if not is_tensorcore_supported_precision(in_dtype, out_dtype, arch=get_arch(target)):
             logger.debug(f"The input and output dtype ({in_dtype}, {out_dtype})is not supported by tensorcore")
             return func, None

     # reindex and transform functions

@@ -676,7 +657,7 @@ def get_tensorized_func_and_tags(
             else:
                 raise ValueError(f"Unknown IterVar type {iter_type}")
-            if (isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold):
+            if isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold:
                 return func, None
     tags = analysis_tensorcore_tags(sch, main_block, target)
     return sch.mod["main"], tags

@@ -686,8 +667,10 @@ def get_tensorized_func_and_tags(
 def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", index_dtype="int32"):
     from bitblas.tl.mma_layout import (  # pylint: disable=import-outside-toplevel
         ldmatrix_32x8_to_shared_16x16_layout,
         ldmatrix_trans_32x8_to_shared_16x16_layout,
         ldmatrix_32x16_to_shared_16x32_layout_a,
         ldmatrix_32x16_to_shared_16x32_layout_b,
     )

     assert dtype in [

@@ -727,9 +710,7 @@ def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", inde
         return ldmatrix_layout(thread_id, local_id)

     if dtype in ["bfloat16", "float16"]:
-        ldmatrix_index_map = (
-            ldmatrix_trans_permutation_16x16_32x8_16x16
-            if trans else ldmatrix_permutation_16x16_32x8_16x16)
+        ldmatrix_index_map = ldmatrix_trans_permutation_16x16_32x8_16x16 if trans else ldmatrix_permutation_16x16_32x8_16x16
     else:
         ldmatrix_index_map = ldmatrix_permutation_16x32_32x16_32x16

@@ -744,7 +725,6 @@ def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", inde
 # Ladder weight propagation, which can be used to avoid the ldmatrix
 # Instructions.
 def get_ladder_stage3_map(dtype="float16", index_dtype="int32"):
-
     def shared_32x8_to_mma_32x8_layout(i, j):
         thread_id = (i % 8) * 4 + (j // 2)
         local_id = (i // 8) * 2 + (j % 2)

@@ -837,8 +817,7 @@ def layout_propagate_chain(
                 scaling_factor = 1
                 for i, j in zip(write.buffer.shape, read.buffer.shape):
                     scaling_factor *= i // j
                 final_indices = list(index_map.map_indices(tmp_index_map.map_indices(write_indices)))
                 final_indices[-1] = final_indices[-1] // scaling_factor
                 index_map = IndexMap(write_indices,
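Not every hunk in this file is whitespace-only: both `is_common_reduce` helpers above are rewritten from an explicit loop into a single `any(...)` expression, which looks like a lint autofix (for example ruff's SIM110 rule) rather than output of the formatter itself; the page does not show which rule produced it. A minimal sketch of that transformation, using hypothetical names:

    # loop form (before)
    def has_negative(values: list) -> bool:
        for v in values:
            if v < 0:
                return True
        return False

    # equivalent any() form (after)
    def has_negative(values: list) -> bool:
        return any(v < 0 for v in values)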
tilelang/carver/roller/bestfit.py

@@ -2,7 +2,6 @@
 class Block:
-
     def __init__(self, start, end, is_free):
         self.start = start
         self.end = end

@@ -21,7 +20,6 @@ class Block:
 class BestFit:
-
     def __init__(self, align=32):
         self.limit = 0
         self.list = []

@@ -31,16 +29,14 @@ class BestFit:
         size = (size + self.align - 1) // self.align * self.align
         found = None
         for block in self.list:
             if block.is_free and block.size() >= size and (not found or found.size() > block.size()):
                 found = block
         if found:
             found.is_free = False
             remain = found.size() - size
             if remain != 0:
                 found.end -= remain
                 self.list.insert(self.list.index(found) + 1, Block(found.end, found.end + remain, True))
             return found
         elif len(self.list) > 0 and self.list[-1].is_free:
             add = size - self.list[-1].size()
tilelang/carver/roller/hint.py

 """Hint definition for schedule"""
+
 from tvm import DataType
 from . import PrimFuncNode
 import numpy as np

@@ -60,7 +61,7 @@ class Stride:
             strided_elem = original_shape
         else:
             assert self.ax < len(shape)
-            strided_elem = np.prod(shape[0:self.ax + 1]) * self.stride
+            strided_elem = np.prod(shape[0 : self.ax + 1]) * self.stride
             assert strided_elem >= original_shape
         return int(strided_elem)

@@ -217,7 +218,7 @@ class Hint:
         return dic

     @classmethod
-    def from_dict(cls, dic: dict) -> 'Hint':
+    def from_dict(cls, dic: dict) -> "Hint":
         hint = cls()
         for k, v in dic.items():
             setattr(hint, k, v)
tilelang/carver/roller/node.py

 """PrimFunc Wrapper and Block information Analaysis"""
+
 from __future__ import annotations
 import tvm

@@ -31,7 +32,6 @@ def pre_order_traverse(block_analyzer, blocks, func):
 class BlockAnalyzer:
-
     def __init__(self, sch) -> None:
         self.sch: tir.Schedule = sch
         self.block_infos: list[BlockInfo] = normalize_prim_func(self.sch)

@@ -92,7 +92,6 @@ class Edge:
 class Node:
-
     def __init__(self, tags: dict | None = None, name: str = "Node") -> None:
         self.name = name
         if tags is None:

@@ -177,7 +176,6 @@ class Node:
 class PlaceHolderNode(Node):
-
     def __init__(self, name=""):
         super().__init__(name="PlaceHolder_" + name)

@@ -189,11 +187,7 @@ class PlaceHolderNode(Node):
 class PrimFuncNode(Node):
-
     def __init__(self, prim_func: PrimFunc, tags: dict | None = None, name: str = "PrimFuncNode") -> None:
         super().__init__(tags, name=name)
         self.prim_func = self._specialize_func(prim_func)
         self.sch: tir.Schedule = tir.Schedule(self.prim_func)

@@ -227,7 +221,7 @@ class PrimFuncNode(Node):
         for dst_id, n in enumerate(inputs):
             if isinstance(n, Node):
                 n = (n, 0)
-            assert (len(n) == 2)
+            assert len(n) == 2
             src_node, src_id = n[0], n[1]
             edge = Edge(src_node, self, src_id, dst_id)
             self._in_edges.append(edge)

@@ -338,9 +332,8 @@ class PrimFuncNode(Node):
         if rstep is None:
             rstep = {}
         shape = {
             self.block_analyzer.get_output_buffers(block)[0].name: [tvm.arith.ConstIntBound(0, val - 1) for val in tile]
             for block in self.schedule_stages
         }
         return self.ana.infer(shape, rstep, targets)

@@ -356,10 +349,7 @@ class PrimFuncNode(Node):
                 results.append(shapes[arg.name])
                 continue
             # should not exceed original shape
             trimmed_shape = [self.extent_wrapper(i) for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape)))]
             results.append(trimmed_shape)
         return results

@@ -380,10 +370,8 @@ class PrimFuncNode(Node):
             propagate_shape = shapes[arg.name]
             buffer_shape = args[i].shape
             if len(buffer_shape) > len(propagate_shape):
-                buffer_shape = buffer_shape[-len(propagate_shape):]
+                buffer_shape = buffer_shape[-len(propagate_shape) :]
             trimmed_shape = [self.extent_wrapper(j) for j in list(map(min, zip(propagate_shape, buffer_shape)))]
             results.append(trimmed_shape)
         return results

@@ -412,10 +400,7 @@ class PrimFuncNode(Node):
     def get_reduce_inputs_dtype(self):
         if self.reduction_block is None:
             return {}
         return {b.name: tvm.DataType(b.dtype) for b in self.block_analyzer.get_input_buffers(self.reduction_block)}

     @functools.lru_cache
     def infer_tensorcore_axis(self) -> tuple[int]:

@@ -425,8 +410,7 @@ class PrimFuncNode(Node):
         C_ax_m, C_ax_n = self.get_tag("tensorcore_config")
         wmma_m, wmma_n, wmma_k = [16, 16, 16]
         # just for testing, any number is ok
-        output_buffer_shape = (
-            self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape)
+        output_buffer_shape = self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape
         valid_region = []
         for region in output_buffer_shape:
             if region.value == 1:

@@ -438,8 +422,7 @@ class PrimFuncNode(Node):
         def get_cl_shapes(c_ax_m, c_ax_n, num_nvalid_regions):
             spatial_dim = self.get_space_dim()
             assert len(valid_region) == len(spatial_dim), f"{valid_region} mismatch with {spatial_dim}"
             cl_shapes = [1] * len(spatial_dim)
             cl_shapes[c_ax_m - num_nvalid_regions] = wmma_m
             cl_shapes[c_ax_n - num_nvalid_regions] = wmma_n

@@ -467,9 +450,11 @@ class PrimFuncNode(Node):
         shapes, _ = self.propagate(shape, rstep)

         def is_broadcast_pattern(buffer, output_buffer):
-            return (buffer in self.args and
-                    len(shapes[output_buffer.name]) > len(shapes[buffer.name]) and
-                    np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name]))
+            return (
+                buffer in self.args
+                and len(shapes[output_buffer.name]) > len(shapes[buffer.name])
+                and np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name])
+            )

         def is_after_reduce_stage(block):
             if not self.reduction_block:

@@ -491,8 +476,8 @@ class PrimFuncNode(Node):
             output_buffer = self.block_analyzer.get_output_buffers(block)[0]
             for buffer in self.block_analyzer.get_input_buffers(block):
                 cache = buffer.name not in cached_tensor and (
                     is_broadcast_pattern(buffer, output_buffer)
                     or self.block_analyzer.get_block_info(block).is_reduction()
                 )
                 if not cache:
                     continue
                 cached_tensor.append(buffer.name)

@@ -500,8 +485,7 @@ class PrimFuncNode(Node):
                     continue
             # cache after reduce op can often reuse buffer in reduce stage
             if buffer.name in stride_map:
                 num_elem = stride_map[buffer.name].compute_elements_from_shape(shapes[buffer.name])
             else:
                 num_elem = np.prod(shapes[buffer.name])
             buffer_len = num_elem * int((tvm.DataType(buffer.dtype).bits + 7) // 8)

@@ -514,7 +498,6 @@ class PrimFuncNode(Node):
 class OutputNode(Node):
-
     def __init__(self, node, id=0):
         super().__init__(name="OutputNode")
         # connect node and output node

@@ -549,15 +532,16 @@ def topo_order(list_of_nodes) -> list[Node]:
                     input_ready_count[dst_node] = len(dst_node.inputs)
                     list_of_nodes.append(dst_node)
                 input_ready_count[dst_node] -= 1
-                assert (input_ready_count[dst_node] >= 0)
+                assert input_ready_count[dst_node] >= 0
                 if input_ready_count[dst_node] == 0:
                     ready.append(dst_node)
-    assert (len(list_of_nodes) == len(output_list))
+    assert len(list_of_nodes) == len(output_list)
     return output_list

 def find_topo_sort_priority(output_node_list) -> list[Node]:
     import sys
     sys.setrecursionlimit(10000)

     def topo_sort_get_layer(node, topo_layer):

@@ -576,9 +560,7 @@ def find_topo_sort_priority(output_node_list) -> list[Node]:
         if node in visited:
             return
         visited.add(node)
         ordered_input_nodes = sorted([edge.src_node for edge in node.inputs], key=lambda n: topo_layer[n], reverse=True)
         for n in ordered_input_nodes:
             topo_sort_dfs(n, visited, topo_order)
         topo_order.append(node)

@@ -591,7 +573,6 @@ def find_topo_sort_priority(output_node_list) -> list[Node]:
 def find_topo_sort(output_node_list) -> list[Node]:
-
     def topo_sort_dfs(node, visited, topo_order):
         if node in visited:
             return
tilelang/carver/roller/policy/default.py

 """Policy for cuda core schedule"""
+
 from __future__ import annotations
 import functools
 import math

@@ -36,20 +37,14 @@ class DefaultPolicy:
         self.rasterization = NoRasterization()

     @classmethod
     def from_prim_func(cls, func: tvm.tir.PrimFunc, arch: TileDevice, tags: dict | None = None, name: str = "PrimFuncNode"):
         return cls(arch, tags)._init_with_prim_func(func, name)

     @classmethod
     def from_output_nodes(cls, nodes: list[OutputNode], arch: TileDevice, tags: dict | None = None):
         return cls(arch, tags)._init_with_output_nodes(nodes)

     def _init_with_prim_func(self, func: tvm.tir.PrimFunc, name: str = "PrimFuncNode") -> DefaultPolicy:
         if func is not None and isinstance(func, tvm.tir.PrimFunc):
             self.func = func
             self.prim_func_node = PrimFuncNode(self.func, tags=self.tags, name=name)

@@ -60,9 +55,7 @@ class DefaultPolicy:
         return self

     def _init_with_output_nodes(self, output_nodes: list[OutputNode]):
         self.ordered_nodes = list(
             filter(lambda n: not n.is_placeholder() and not n.is_output(), find_topo_sort(output_nodes))
         )
         for node in self.ordered_nodes:
             node.update_tags(self.tags)

@@ -102,13 +95,14 @@ class DefaultPolicy:
     def dfs_smem_tile(self, init_tile, rstep_map) -> Iterable[TileDict]:
         _steps = [get_all_factors(n) for n in self.output_nodes[0].get_space_dim()]
-        steps = [step[step.index(t):] for step, t in zip(_steps, init_tile)]
+        steps = [step[step.index(t) :] for step, t in zip(_steps, init_tile)]
         for i in range(len(steps)):
             added = list(
                 filter(
                     lambda s: s < steps[i][-1] and s > steps[i][0] and s not in steps[i],
                     [2, 4, 8, 16, 32],
                 )
             )
             steps[i].extend(added)
             steps[i] = sorted(steps[i])
         visited_tiles = {}

@@ -190,10 +184,7 @@ class DefaultPolicy:
         """
         tile_map = {}
         for node in self.output_nodes:
             tile_map[node] = [
                 tile[i] * node.get_space_dim()[i] // self.output_nodes[0].get_space_dim()[i] for i in range(len(tile))
             ]
         return tile_map

     def compute_workload_per_item(self, output_tile) -> float:

@@ -304,8 +295,7 @@ class DefaultPolicy:
         score = 0
         shape = node.propagate_inputs(tile, rstep=rstep)
         for i, input_buffer in enumerate(node.input_buffers):
             read_transaction_elements = self.arch.transaction_size[1] // ((node.get_buffer_dtype(input_buffer).bits + 7) // 8)
             score += sim(
                 int(coalesced_factor(shape[i], input_buffer.shape)),
                 read_transaction_elements,

@@ -380,17 +370,13 @@ class DefaultPolicy:
                 return None
             return max(candidates, key=lambda x: x[1])[0]

         cur_rstep_id = {k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis}
         new_rstep_map = rstep_map.copy()
         while True:
             new_rstep_id = _enlarge(cur_rstep_id)
             if new_rstep_id is None:
                 break
             new_rstep_map[node] = {k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis}
             old_rstep_map = td.rstep_map
             td.rstep_map = new_rstep_map
             smem_usage, _ = self._compute_shared_memory_usage(td)

@@ -434,15 +420,14 @@ class DefaultPolicy:
             if edge.src_node.is_placeholder():
                 nbytes = (edge.src_node.get_dtype().bits + 7) // 8
                 read_transaction_elements = self.arch.transaction_size[1] // nbytes
                 traffic += coalesced_tensor_shape(input_shapes[i], edge.src_node.get_shape(), read_transaction_elements) * nbytes
         for edge in node.outputs:
             if edge.dst_node.is_output():
                 nbytes = (edge.src_node.get_dtype().bits + 7) // 8
                 write_transaction_elements = self.arch.transaction_size[0] // nbytes
-                traffic += coalesced_tensor_shape(output_shapes[edge.src_id], node.get_shape(edge.src_id),
-                                                  write_transaction_elements) * nbytes
+                traffic += (
+                    coalesced_tensor_shape(output_shapes[edge.src_id], node.get_shape(edge.src_id), write_transaction_elements)
+                    * nbytes
+                )
         return traffic, op_tile_map

@@ -487,10 +472,7 @@ class DefaultPolicy:
         cached_tensors_map = {}

         def can_free(node, out_id):
-            for edge in node.outputs:
-                if edge.src_id == out_id and edge.dst_node not in processed:
-                    return False
-            return True
+            return all(
+                not (edge.src_id == out_id and edge.dst_node not in processed) for edge in node.outputs
+            )

         for node in self.ordered_nodes:
             node_internal_bytes, cached_tensors_map[node] = self.infer_node_smem_usage(td, node)

@@ -528,9 +510,7 @@ class DefaultPolicy:
             Tuple[Dict, Dict]
             A tuple of dictionaries containing the output strides and tensor strides.
         """
         output_strides = {int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)}
         tensor_strides = {}
         return output_strides, tensor_strides

@@ -551,8 +531,7 @@ class DefaultPolicy:
         output_strides_map = {}
         tensor_strides_map = {}
         for node in self.ordered_nodes:
             output_strides_map[node], tensor_strides_map[node] = self.compute_node_stride_map(node, td)
         td.output_strides_map, td.tensor_strides_map = output_strides_map, tensor_strides_map

     def compute_tile_dict(self, output_tile: list[int], rstep_map) -> TileDict:

@@ -582,9 +561,7 @@ class DefaultPolicy:
         output_shape = self.output_nodes[0].get_space_dim()
         td.grid_size = int(np.prod([(y + x - 1) // x for x, y in zip(output_tile, output_shape)]))
         # estimated reg usage
         reg_usage = int(2 * max([np.prod(td.get_tile(node)) * node.get_dtype().bits / 32 for node in self.ordered_nodes]))
         if reg_usage > self.arch.reg_cap:
             td.valid = False
             return td

@@ -609,13 +586,10 @@ class DefaultPolicy:
         for node in self.ordered_nodes:
             if np.prod(td.get_tile(node)) == 0:
                 return False
             node_grid_size = np.prod([(y + x - 1) // x for x, y in zip(td.get_tile(node), node.get_space_dim())])
             if node_grid_size != td.grid_size:
                 return False
-            if (hasattr(node, "reduce_op") and node.reduce_op is not None and
-                    len(node.reduce_op.axis) == len(td.output_tile)):
+            if hasattr(node, "reduce_op") and node.reduce_op is not None and len(node.reduce_op.axis) == len(td.output_tile):
                 for i, tile_extent in enumerate(td.output_tile):
                     if node.reduce_op.axis[i].dom.extent % tile_extent:
                         return False

@@ -639,23 +613,22 @@ class DefaultPolicy:
         node_space_sizes = [int(np.prod(td.get_tile(node))) for node in self.ordered_nodes]
         max_block_size = functools.reduce(math.gcd, node_space_sizes)

         if max_block_size < self.arch.warp_size * self.arch.sm_partition and max_block_size == min(node_space_sizes):
             node_reduce_sizes = [int(np.prod(list(td.get_rstep(node).values()))) for node in self.ordered_nodes]
             total_sizes = [x * y for x, y in zip(node_space_sizes, node_reduce_sizes)]
             max_possible_size = functools.reduce(math.gcd, total_sizes)
             possible_block_sizes = list(
                 filter(
                     lambda x: x % max_block_size == 0 and x <= 1024,
                     get_all_factors(max_possible_size),
                 )
             )
             possible_block_sizes = list(
                 filter(
                     # either be a factor of space or cover fully cover the space
                     lambda x: all([x % s == 0 or s % x == 0 for s in node_space_sizes]),
                     possible_block_sizes,
                 )
             )
             factor_ordered = sorted(possible_block_sizes, key=self.score_block_size)
             return factor_ordered
         else:

@@ -821,8 +794,7 @@ class DefaultPolicy:
         vectorize_result = {}
         for tensor, shape in shapes.items():
             for v in vectorize_sizes:
-                if (is_shape_aligned(shape, block_size * v) and is_cont(shape, v) and
-                        is_type_allowed(dtypes[tensor], v)):
+                if is_shape_aligned(shape, block_size * v) and is_cont(shape, v) and is_type_allowed(dtypes[tensor], v):
                     vectorize_result[tensor] = v
                     break
         return vectorize_result
tilelang/carver/roller/policy/tensorcore.py

 """Policy for tensorcore schedule"""
+
 from __future__ import annotations
 import tvm
 import numpy as np

@@ -13,7 +14,6 @@ logger = logging.getLogger(__name__)
 class TensorCorePolicy(DefaultPolicy):
-
     # this is the trick for wmma.
     # However, for int8 mma, the wmma_k should be 32.
     wmma_k: int = 16

@@ -70,9 +70,9 @@ class TensorCorePolicy(DefaultPolicy):
         A_high_ax = min(A_ax_m, A_ax_k)
         B_high_ax = min(B_ax_n, B_ax_k)
         C_high_ax = min(C_ax_m, C_ax_n)
-        A_stride = Stride(stride=np.prod(AS_shape[A_high_ax + 1:]) + offset, ax=A_high_ax)
-        B_stride = Stride(stride=np.prod(BS_shape[B_high_ax + 1:]) + offset, ax=B_high_ax)
-        C_stride = Stride(stride=np.prod(CS_shape[C_high_ax + 1:]) + offset, ax=C_high_ax)
+        A_stride = Stride(stride=np.prod(AS_shape[A_high_ax + 1 :]) + offset, ax=A_high_ax)
+        B_stride = Stride(stride=np.prod(BS_shape[B_high_ax + 1 :]) + offset, ax=B_high_ax)
+        C_stride = Stride(stride=np.prod(CS_shape[C_high_ax + 1 :]) + offset, ax=C_high_ax)
         return A_stride, B_stride, C_stride

     def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode):

@@ -86,8 +86,7 @@ class TensorCorePolicy(DefaultPolicy):
         # get reduce input size
         target_transaction = self.arch.transaction_size[0] * 2
         # 512 bytes // type bits
         reduce_input_dtype = node.get_buffer_dtype(node.block_analyzer.get_input_buffers(node.reduction_block)[0])
         basic = (target_transaction * 8) // reduce_input_dtype.bits

         result = {}

@@ -95,7 +94,7 @@ class TensorCorePolicy(DefaultPolicy):
             iter_name = iter_info.var.name
             iter_dom = iter_info.dom.extent
             if iter_dom % 16 > 0:
-                result[iter_name] = (16 if iter_dom < basic else basic)  # for the case of padding
+                result[iter_name] = 16 if iter_dom < basic else basic  # for the case of padding
             elif iter_dom % basic == 0:
                 result[iter_name] = basic
             else:

@@ -114,7 +113,6 @@ class TensorCorePolicy(DefaultPolicy):
                 return False

         if _check_small_tile(td):
-
             smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap)
             rstep_map = td.rstep_map.copy()

@@ -127,13 +125,10 @@ class TensorCorePolicy(DefaultPolicy):
                 return rstep

             def _shared_memory_usage(td: TileDict):
                 return node.footprint(td.output_tile, new_rstep_map, td.tensor_strides_map[node])

             def _score(rstep_id):
                 rstep = {k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis}
                 score = 0
                 shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep)
                 input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block)

@@ -153,18 +148,13 @@ class TensorCorePolicy(DefaultPolicy):
                     return None
                 return max(candidates, key=lambda x: x[1])[0]

             cur_rstep_id = {k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis}
             new_rstep_map = rstep_map.copy()
             while True:
                 new_rstep_id = _enlarge(cur_rstep_id)
                 if new_rstep_id is None:
                     break
                 new_rstep_map = {k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis}
                 old_rstep_map = td.rstep_map
                 td.rstep_map = new_rstep_map
                 smem_usage, _ = _shared_memory_usage(td)

@@ -173,9 +163,7 @@ class TensorCorePolicy(DefaultPolicy):
                     break
                 else:
                     cur_rstep_id = new_rstep_id
             rstep = {k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis}
             return rstep

         for node in self.ordered_nodes:

@@ -206,11 +194,7 @@ class TensorCorePolicy(DefaultPolicy):
             return super().get_node_reduce_step_candidates(node)
         else:
             # must be a a multiple of wmma_k
             return {
                 k.var.name: [x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k)]
                 for k in node.raxis
             }

     def check_tile_shape_isvalid(self, td: TileDict):
         for node in self.ordered_nodes:

@@ -221,10 +205,7 @@ class TensorCorePolicy(DefaultPolicy):
                 td.tile_map[node][ax_n],
             )
             # check the tile size is valid
             wmma_invalid = [
                 block_m < wmma_m or block_n < wmma_n
                 for wmma_m, wmma_n in self.arch.get_avaliable_tensorintrin_shapes()
             ]
             if all(wmma_invalid):
                 return False
             if any([y % x for x, y in zip(td.tile_map[node], node.get_space_dim())]):

@@ -242,13 +223,10 @@ class TensorCorePolicy(DefaultPolicy):
             return super().compute_node_stride_map(node, td)
         use_layout = self._can_implement_layout(node, td)

         AS_stride, BS_stride, C_stride = self._compute_tc_strides(node, td.get_tile(node), td.get_rstep(node))
         A_stride, B_stride, _ = self._compute_tc_strides(node, td.get_tile(node))
         tensor_strides = {}
         output_strides = {int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers)}
         # when connected to shared input, should use full stride without rstep
         for i, (_, _) in enumerate(zip([AS_stride, BS_stride], [A_stride, B_stride])):

@@ -347,8 +325,7 @@ class TensorCorePolicy(DefaultPolicy):
         overall_gmem_size_in_bytes: int = 0
         for node in self.ordered_nodes:
             for buffer in node.input_buffers:
-                overall_gmem_size_in_bytes += (
-                    int(np.prod(buffer.shape)) * tvm.DataType(buffer.dtype).bits // 8)
+                overall_gmem_size_in_bytes += int(np.prod(buffer.shape)) * tvm.DataType(buffer.dtype).bits // 8
         return overall_gmem_size_in_bytes < self.arch.l2_cache_size_bytes

         conditions.append(_check_memory_size())
tilelang/carver/roller/rasterization.py

@@ -2,7 +2,6 @@
 class Rasterization:
-
     panel_width_ = None

     def __init__(self) -> None:

@@ -18,7 +17,6 @@ class Rasterization:
 class NoRasterization(Rasterization):
-
     def __init__(self) -> None:
         super().__init__()
tilelang/carver/roller/shape_inference/common.py

@@ -4,9 +4,7 @@ from tvm import arith
 class Statement:
-
     def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict, range_map: OrderedDict):
         self.output = output
         self.dependent_region = dependent_region
         self.var_map = var_map

@@ -18,7 +16,6 @@ def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound):
 class InputShapeInference:
-
     def __init__(self, deps: list[Statement]):
         self.deps = deps
tilelang/carver/roller/shape_inference/tir.py

@@ -5,7 +5,6 @@ from tvm import arith, tir
 class Statement:
-
     def __init__(self, block_analyzer, block: BlockRV):
         self.block_analyzer = block_analyzer
         self.block = block

@@ -21,9 +20,7 @@ class Statement:
         if len(self.dependent_region[input_name]) != 1:
             return None
         indices = self.dependent_region[input_name][0]
         iter_map_range = {_iter.var: _iter.dom for _iter in self.block_analyzer.get_spatial_axis(self.block)}
         iter_map_result = arith.detect_iter_map(
             indices,
             iter_map_range,

@@ -77,7 +74,6 @@ class TensorDepNode:
 class DependencyAnalysis:
-
     def __init__(self, deps):
         self.deps = deps
         # issue: duplicate name when we have two same ops.

@@ -112,8 +108,7 @@ class DependencyAnalysis:
     def traverse_dependencies(self, compute):
         if isinstance(compute, Statement):
             node = self.get_or_create_node(compute.block_analyzer.get_output_buffers(compute.block)[0].name)

             # Loop through input tensors
             for input_buffer in compute.block_analyzer.get_input_buffers(compute.block):
                 # Get the input node

@@ -167,7 +162,6 @@ class DependencyAnalysis:
 class InputShapeInference:
-
     def __init__(self, deps: list[Statement]):
         self.deps = deps
         self.target_mapping = {}

@@ -183,16 +177,11 @@ class InputShapeInference:
         if targets in self.target_mapping:
             return self.target_mapping[targets]
         # should be buffer name instead of block name
         name2dep = {dep.block_analyzer.get_output_buffers(dep.block)[0].name: dep for dep in self.deps}
         mapping = {}
         input_vars = []
         for target in targets:
             vars = [iter.var for iter in name2dep[target].block_analyzer.get_spatial_axis(name2dep[target].block)]
             input_vars.append(vars)
             mapping[target] = [vars]
         ana = arith.Analyzer()

@@ -221,13 +210,8 @@ class InputShapeInference:
             mapping[input_name] = []
             for indices in indices_list:
                 for region in regions:
                     vmap = {k: (tir.Cast(k.dtype, v) if v.dtype != k.dtype else v) for k, v in zip(ax_vars, indices)}
                     region = [ana.simplify(tir.stmt_functor.substitute(ax, vmap)) for ax in region]
                     if not region_exist_in_list(region, mapping[input_name]):
                         mapping[input_name].append(region)

         buffers = []

@@ -241,10 +225,7 @@ class InputShapeInference:
         self.target_mapping[targets] = input_vars, mapping
         return input_vars, mapping

     def infer(self, shape: dict[str, list[arith.ConstIntBound]], rstep: dict[str, int] = None, targets=None):
         if rstep is None:
             rstep = {}
         compute_targets = tuple(shape.keys())

@@ -258,8 +239,7 @@ class InputShapeInference:
         for ax in self.reduce_axes:
             # assume the dom.min is always 0, maybe we can extend the IterInfo to include the min value.
             if ax.var.name in rstep:
                 bound = arith.ConstIntBound(int(ax.dom.min), int(ax.dom.min + min(ax.dom.extent, rstep[ax.var.name]) - 1))
             else:
                 bound = arith.ConstIntBound(int(ax.dom.min), int(ax.dom.min + ax.dom.extent - 1))
             ana.update(ax.var, bound, True)

@@ -312,14 +292,11 @@ class InputShapeInference:
         for name, regions in mapping.items():
             region = regions[0]
             result[name] = [ana.simplify(tir.stmt_functor.substitute(index, vmap)) for index in region]
         return result

 def region_exist_in_list(a, list) -> bool:
-
     def expr_is_same(a, b) -> bool:
         if isinstance(a, tir.IntImm) and isinstance(b, tir.IntImm):
             return a.value == b.value
tilelang/carver/template/base.py

@@ -2,7 +2,12 @@
 from abc import ABC, abstractmethod  # For defining abstract base classes
 from dataclasses import dataclass, field  # For defining data classes
-from ..arch import (  # Import architecture-related utilities and classes
-    TileDevice, is_volta_arch, is_ampere_arch, is_cdna_arch, auto_infer_current_arch)
+from ..arch import (  # Import architecture-related utilities and classes
+    TileDevice,
+    is_volta_arch,
+    is_ampere_arch,
+    is_cdna_arch,
+    auto_infer_current_arch,
+)
 from ..roller.hint import Hint  # Import the Hint class
 from ..roller.node import OutputNode  # Import the OutputNode class
 from tvm.tir import PrimFunc  # Import PrimFunc for handling tensor IR functions

@@ -41,7 +46,7 @@ class BaseTemplate(ABC):
         """
         pass

-    def with_arch(self, arch: TileDevice) -> 'BaseTemplate':
+    def with_arch(self, arch: TileDevice) -> "BaseTemplate":
         """
         Sets the architecture for this template and returns itself.

@@ -109,7 +114,7 @@ class BaseTemplate(ABC):
         """
         raise NotImplementedError("initialize_function is not implemented")

-    def set_function(self, func: PrimFunc) -> 'BaseTemplate':
+    def set_function(self, func: PrimFunc) -> "BaseTemplate":
         """
         Sets the function for this template and returns itself.

@@ -122,7 +127,7 @@ class BaseTemplate(ABC):
         self._func = func
         return self

-    def set_output_nodes(self, output_nodes: list[OutputNode]) -> 'BaseTemplate':
+    def set_output_nodes(self, output_nodes: list[OutputNode]) -> "BaseTemplate":
         """
         Sets the output nodes for this template and returns itself.
tilelang/carver/template/conv.py

@@ -28,6 +28,7 @@ class ConvTemplate(BaseTemplate):
         accum_dtype (str): Data type used for accumulation.
         with_bias (bool): Whether to add a bias term.
     """
+
     # Operation-related configuration parameters
     N: int  # The number of input samples processed simultaneously in a batch.
     C: int  # The number of input feature maps.

@@ -69,12 +70,18 @@ class ConvTemplate(BaseTemplate):
             AssertionError: If N, C, H, W, F, K, S, D, P are not positive integers.
         """
         N, C, H, W, F, K, S, D, P = self.N, self.C, self.H, self.W, self.F, self.K, self.S, self.D, self.P
-        assert (isinstance(N, int) and isinstance(C, int) and isinstance(H, int) and isinstance(W, int) and
-                isinstance(F, int) and isinstance(K, int) and isinstance(S, int) and isinstance(D, int) and
-                isinstance(P, int)), "Only Support Integer Params"
-        assert (N > 0 and C > 0 and H > 0 and W > 0 and F > 0 and K > 0 and S > 0 and D > 0 and
-                P > 0), "Params should be positive"
+        assert (
+            isinstance(N, int)
+            and isinstance(C, int)
+            and isinstance(H, int)
+            and isinstance(W, int)
+            and isinstance(F, int)
+            and isinstance(K, int)
+            and isinstance(S, int)
+            and isinstance(D, int)
+            and isinstance(P, int)
+        ), "Only Support Integer Params"
+        assert N > 0 and C > 0 and H > 0 and W > 0 and F > 0 and K > 0 and S > 0 and D > 0 and P > 0, "Params should be positive"
         # Load configuration parameters
         in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype

@@ -123,8 +130,10 @@ class ConvTemplate(BaseTemplate):
                 te.if_then_else(
                     te.all(h_in >= 0, h_in < H, w_in >= 0, w_in < W),
                     A[n, h_in, w_in, c].astype(accum_dtype) * B[kh, kw, c, f].astype(accum_dtype),
-                    tir.const(0, accum_dtype)),
-                axis=[kh, kw, c])
+                    tir.const(0, accum_dtype),
+                ),
+                axis=[kh, kw, c],
+            )

         # Compute convolution result
         C = te.compute(
tilelang/carver/template/flashattention.py

@@ -9,7 +9,6 @@ from ..utils import get_roller_hints_from_output_nodes, get_tensorized_func_and_
 @dataclass
 class FlashAttentionTemplate(BaseTemplate):
-
     _output_nodes: list[OutputNode] = None

     # Operation-related configuration parameters

@@ -91,10 +90,7 @@ class FlashAttentionTemplate(BaseTemplate):
             """
             A_indices = [b, i, k]
             B_indices = [b, j, k]
             return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k)

         # Compute matrix multiplication result
         C = te.compute(
tilelang/carver/template/gemv.py

@@ -50,9 +50,8 @@ class GEMVTemplate(BaseTemplate):
         N, K = self.N, self.K

         # Ensure M, N, K are valid positive integers
-        assert (isinstance(M, int) and isinstance(N, int) and isinstance(K, int)), "Only Support Integer M, N, K"
-        assert (M > 0 and N > 0 and K > 0), "M, N, K should be positive"
+        assert isinstance(M, int) and isinstance(N, int) and isinstance(K, int), "Only Support Integer M, N, K"
+        assert M > 0 and N > 0 and K > 0, "M, N, K should be positive"

         # Load configuration parameters
         trans_B = self.trans_B

@@ -86,9 +85,7 @@ class GEMVTemplate(BaseTemplate):
             """
             A_indices = [i, k]
             B_indices = [k, j] if not trans_B else [j, k]
             return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k)

         # Compute matrix multiplication result
         C = te.compute(
tilelang/carver/template/general_reduce.py

@@ -9,15 +9,13 @@ from ..utils import get_roller_hints_from_func
 @dataclass
 class GeneralReductionTemplate(BaseTemplate):
-
     # OP Related Config
     structure: str | list[str] = None
     shape: list[int] = None
     dtype: str = "float16"

     def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
         roller_hints = get_roller_hints_from_func(self._func, arch=arch, topk=topk, allow_gemv=False)
         return roller_hints

     def initialize_function(self) -> None:

@@ -38,9 +36,9 @@ class GeneralReductionTemplate(BaseTemplate):
         spatial_axes = []
         reduce_axes = []
         for i, axis_type in enumerate(self.structure):
-            if axis_type.upper() == 'S':
+            if axis_type.upper() == "S":
                 spatial_axes.append((i, self.shape[i]))
-            elif axis_type.upper() == 'R':
+            elif axis_type.upper() == "R":
                 reduce_axes.append((i, self.shape[i]))
             else:
                 raise ValueError(f"Unrecognized axis type '{axis_type}', only 'S'/'R' allowed.")

@@ -90,7 +88,7 @@ class GeneralReductionTemplate(BaseTemplate):
         # Walk through the structure in order
         for axis_type in self.structure:
-            if axis_type.upper() == 'S':
+            if axis_type.upper() == "S":
                 # use the next spatial_indices item
                 full_index.append(spatial_indices[spatial_iter])
                 spatial_iter += 1
tilelang/carver/template/matmul.py

@@ -65,9 +65,8 @@ class MatmulTemplate(BaseTemplate):
         M, N, K = self.M, self.N, self.K

         # Ensure M, N, K are valid positive integers
-        assert (isinstance(M, int) and isinstance(N, int) and isinstance(K, int)), "Only Support Integer M, N, K"
-        assert (M > 0 and N > 0 and K > 0), "M, N, K should be positive"
+        assert isinstance(M, int) and isinstance(N, int) and isinstance(K, int), "Only Support Integer M, N, K"
+        assert M > 0 and N > 0 and K > 0, "M, N, K should be positive"

         # Load configuration parameters
         trans_A, trans_B = self.trans_A, self.trans_B

@@ -101,9 +100,7 @@ class MatmulTemplate(BaseTemplate):
             """
             A_indices = [i, k] if not trans_A else [k, i]  # Adjust indexing if A is transposed
             B_indices = [k, j] if not trans_B else [j, k]  # Adjust indexing if B is transposed
             return te.sum(A[tuple(A_indices)].astype(accum_dtype) * B[tuple(B_indices)].astype(accum_dtype), axis=k)

         # Compute matrix multiplication result
         C = te.compute(
tilelang/carver/utils.py

@@ -26,11 +26,9 @@ def get_rasterization_code(pannel_width: int = 8) -> str:
     """

 def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule, arch: TileDevice, topk: int = 10, tensorcore_only: bool = False, allow_gemv: bool = False) -> list[Hint] | None:
     func = None
     if isinstance(func_or_module, tir.PrimFunc):
         func = func_or_module

@@ -44,8 +42,7 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
     roller_hints = None
     if tensorcore_only:
         try:
             tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target, allow_gemv=allow_gemv)
         except Exception as e_msg:
             logger.debug("Get tensorized func and tags failed: ", e_msg)
             tags = None

@@ -58,8 +55,7 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
         policy = DefaultPolicy.from_prim_func(func=func, arch=arch)
         tensorized_func = None
         try:
             tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target, allow_gemv=allow_gemv)
         except Exception as e_msg:
             logger.debug("Get tensorized func and tags failed: ", e_msg)
             tags = None

@@ -69,10 +65,9 @@ def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
     return roller_hints

 def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode], arch: TileDevice, topk: int = 10, extra_tags: list[str] | None = None) -> list[Hint] | None:
     assert isinstance(output_nodes, list), "The input should be a list of functions."
     lints = []

@@ -80,8 +75,7 @@ def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode],
         policy = TensorCorePolicy.from_output_nodes(output_nodes, arch=arch, tags=None)
         lints = policy.emit_config(topk)
     except Exception as e_msg:
         logger.debug(f"Generate hints from output nodes failed: {e_msg}", "fallback to default policy")
     if len(lints) == 0:
         policy = DefaultPolicy.from_output_nodes(output_nodes, arch=arch, tags=None)

@@ -92,7 +86,6 @@ def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode],
 def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc:
     if not isinstance(ir_module, IRModule):
         raise ValueError("Not supported type: ", type(ir_module))
-    assert len(ir_module.get_global_vars()) == 1, (
-        "The optimized module should only have one global variable for default schedule.")
+    assert len(ir_module.get_global_vars()) == 1, "The optimized module should only have one global variable for default schedule."
     func = list(ir_module.functions.values())[0]
     return func
tilelang/contrib/cc.py

@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Util to invoke C/C++ compilers in the system."""
+
 import functools
 import os
 import shutil

@@ -30,8 +31,7 @@ from tvm.contrib import utils as _utils
 def _is_linux_like():
-    return (sys.platform == "darwin" or sys.platform.startswith("linux") or
-            sys.platform.startswith("freebsd"))
+    return sys.platform == "darwin" or sys.platform.startswith("linux") or sys.platform.startswith("freebsd")

 def _is_windows_like():

@@ -90,7 +90,7 @@ def get_cplus_compiler():
 def is_darwin():
-    return platform.system() == 'Darwin'
+    return platform.system() == "Darwin"

 def create_shared(output, objects, options=None, cc=None, cwd=None, ccache_env=None):

@@ -287,11 +287,7 @@ create_shared.output_format = "so" if sys.platform != "win32" else "dll"
 create_shared.get_target_triple = get_target_by_dump_machine(os.environ.get("CXX", get_cc()))

 def cross_compiler(compile_func, options=None, output_format=None, get_target_triple=None, add_files=None):
     """Create a cross compiler function by specializing compile_func with options.

     This function can be used to construct compile functions that

@@ -363,13 +359,7 @@ def cross_compiler(compile_func,
     return _fcompile

 def _linux_compile(output, objects, options, compile_cmd, cwd=None, ccache_env=None, compile_shared=False):
     cmd = [compile_cmd]
     if compile_cmd != "nvcc":
         if compile_shared or output.endswith(".so") or output.endswith(".dylib"):

@@ -430,15 +420,15 @@ def _windows_compile(output, objects, options, cwd=None, ccache_env=None):
             raise ValueError("ccache not found")
     try:
         proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=cwd, env=env)
         (out, _) = proc.communicate()
     except FileNotFoundError:
-        raise RuntimeError("Can not find the LLVM clang for Windows clang.exe)."
-                           "Make sure it's installed"
-                           " and the installation directory is in the %PATH% environment "
-                           "variable. Prebuilt binaries can be found at: https://llvm.org/") \
-            from None
+        raise RuntimeError(
+            "Can not find the LLVM clang for Windows clang.exe)."
+            "Make sure it's installed"
+            " and the installation directory is in the %PATH% environment "
+            "variable. Prebuilt binaries can be found at: https://llvm.org/"
+        ) from None
     if proc.returncode != 0:
         msg = "Compilation error:\n"
         msg += " ".join(cmd) + "\n"
tilelang/contrib/dlpack.py

@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Wrapping functions to bridge frameworks with DLPack support to TVM"""
+
 from tvm import runtime

@@ -45,12 +46,8 @@ def convert_func(tvm_func, tensor_type, to_dlpack_func):
     def adapt_tensor(arg):
         if isinstance(arg, tensor_type):
             if arg.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}:
                 return runtime.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(arg.shape, dtype=float8_dtype_map[arg.dtype])
             return runtime.from_dlpack(to_dlpack_func(arg))
         return arg