Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
bbbf4207
Unverified
Commit
bbbf4207
authored
Nov 14, 2025
by
guchaoyang
Committed by
GitHub
Nov 14, 2025
Browse files
Merge branch 'main' into dcu
parents
8f4628e0
5eb30a4f
Changes
286
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
351 additions
and
30 deletions
+351
-30
tilelang/transform/__init__.py
tilelang/transform/__init__.py
+11
-12
tilelang/transform/_ffi_api.py
tilelang/transform/_ffi_api.py
+2
-2
tilelang/utils/__init__.py
tilelang/utils/__init__.py
+6
-0
tilelang/utils/language.py
tilelang/utils/language.py
+310
-11
tilelang/utils/tensor.py
tilelang/utils/tensor.py
+3
-3
version_provider.py
version_provider.py
+19
-2
No files found.
tilelang/transform/__init__.py
View file @
bbbf4207
...
...
@@ -80,6 +80,17 @@ def FrontendLegalize():
return
_ffi_api
.
FrontendLegalize
()
# type: ignore
def LegalizeNegativeIndex():
    """Create the pass that legalizes negative indices in buffer loads.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    # Look up the registered FFI constructor, then invoke it to build the pass.
    make_pass = _ffi_api.LegalizeNegativeIndex  # type: ignore
    return make_pass()
def
InjectAssumes
():
"""Inject Assumes
...
...
@@ -330,18 +341,6 @@ def LowerDeviceStorageAccessInfo():
return
_ffi_api
.
LowerDeviceStorageAccessInfo
()
# type: ignore
def LoopVectorizeDynamic():
    """Create the pass that tries to vectorize loops with dynamic shape.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    # Look up the registered FFI constructor, then invoke it to build the pass.
    make_pass = _ffi_api.LoopVectorizeDynamic  # type: ignore
    return make_pass()
def
ConfigIndexBitwidth
():
"""Config index bitwidth.
...
...
tilelang/transform/_ffi_api.py
View file @
bbbf4207
"""FFI APIs for tilelang"""
import
tvm
.
ffi
import
tvm
_
ffi
# TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func);
tvm
.
ffi
.
_
init_api
(
"tl.transform"
,
__name__
)
# pylint: disable=protected-access
tvm
_
ffi
.
init_
ffi_
api
(
"tl.transform"
,
__name__
)
tilelang/utils/__init__.py
View file @
bbbf4207
...
...
@@ -6,8 +6,14 @@ from .language import (
is_global
,
# noqa: F401
is_shared
,
# noqa: F401
is_shared_dynamic
,
# noqa: F401
is_tensor_memory
,
# noqa: F401
is_fragment
,
# noqa: F401
is_local
,
# noqa: F401
array_reduce
,
# noqa: F401
retrieve_stride
,
# noqa: F401
retrieve_shape
,
# noqa: F401
retrive_ptr_from_buffer_region
,
# noqa: F401
is_full_region
,
# noqa: F401
to_buffer_region
,
# noqa: F401
)
from
.deprecated
import
deprecated
# noqa: F401
tilelang/utils/language.py
View file @
bbbf4207
from
__future__
import
annotations
from
tvm.tir
import
Buffer
from
tvm.tir
import
Buffer
,
BufferLoad
,
BufferRegion
,
PrimExpr
from
functools
import
reduce
from
tvm
import
IRModule
from
tvm.tir
import
PrimFunc
...
...
@@ -9,29 +9,50 @@ from tvm import ir, tir
# These utility functions check the memory scope of a given TVM buffer.
def
is_global
(
buffer
:
Buffer
)
->
bool
:
def _get_buffer(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> Buffer:
    """
    Return the Buffer behind a Buffer, BufferLoad, or BufferRegion.

    Args:
        buffer_or_load_or_region: A Buffer (returned unchanged), or a
            BufferLoad / BufferRegion whose `.buffer` attribute is returned.

    Returns:
        Buffer: The underlying buffer object.

    Raises:
        TypeError: If the argument is none of the three accepted types.
    """
    # A plain Buffer is already the answer.
    if isinstance(buffer_or_load_or_region, Buffer):
        return buffer_or_load_or_region

    # Loads and regions both expose their buffer through `.buffer`.
    if isinstance(buffer_or_load_or_region, (tir.BufferLoad, tir.BufferRegion)):
        return buffer_or_load_or_region.buffer

    raise TypeError(
        f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}")
def is_global(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
    """
    Tell whether a buffer is in the global memory scope.

    Args:
        buffer: The TVM Buffer, BufferLoad, or BufferRegion to check.

    Returns:
        bool: True if the underlying buffer's scope is "global", False otherwise.
    """
    # Resolve loads/regions to their buffer, then compare the scope string.
    scope = _get_buffer(buffer).scope()
    return scope == "global"
def
is_shared
(
buffer
:
Buffer
,
allow_dynamic
:
bool
=
True
)
->
bool
:
def
is_shared
(
buffer
:
Buffer
|
BufferLoad
|
BufferRegion
,
allow_dynamic
:
bool
=
True
)
->
bool
:
"""
Check if the buffer is in the shared memory scope.
Args:
buffer
(Buffer)
: The TVM buffer to check.
buffer: The TVM buffer
, BufferLoad, or BufferRegion
to check.
Returns:
bool: True if the buffer is in shared memory, False otherwise.
"""
buffer
=
_get_buffer
(
buffer
)
conditions
=
[
False
]
conditions
.
append
(
buffer
.
scope
()
==
"shared"
)
if
allow_dynamic
:
...
...
@@ -39,42 +60,59 @@ def is_shared(buffer: Buffer, allow_dynamic: bool = True) -> bool:
return
any
(
conditions
)
def
is_shared_dynamic
(
buffer
:
Buffer
)
->
bool
:
def is_shared_dynamic(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
    """
    Tell whether a buffer is in the dynamic shared memory scope.

    Args:
        buffer: The TVM Buffer, BufferLoad, or BufferRegion to check.

    Returns:
        bool: True if the underlying buffer's scope is "shared.dyn", False otherwise.
    """
    # Resolve loads/regions to their buffer, then compare the scope string.
    scope = _get_buffer(buffer).scope()
    return scope == "shared.dyn"
def
is_local
(
buffer
:
Buffer
)
->
bool
:
def is_tensor_memory(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
    """
    Tell whether a buffer is in a tensor memory scope.

    Args:
        buffer: The TVM Buffer, BufferLoad, or BufferRegion to check.

    Returns:
        bool: True if the underlying buffer's scope string starts with
            "shared.tmem", False otherwise.
    """
    # Prefix match: any scope beginning with "shared.tmem" qualifies.
    scope = _get_buffer(buffer).scope()
    return scope.startswith("shared.tmem")
def is_local(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
    """
    Tell whether a buffer is in the local memory scope.

    Args:
        buffer: The TVM Buffer, BufferLoad, or BufferRegion to check.

    Returns:
        bool: True if the underlying buffer's scope is exactly "local", False otherwise.
    """
    # Exact match only; scopes such as "local.fragment" do not count here.
    scope = _get_buffer(buffer).scope()
    return scope == "local"
def
is_fragment
(
buffer
:
Buffer
)
->
bool
:
def is_fragment(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
    """
    Tell whether a buffer is a fragment (e.g., for matrix multiplication operations).

    Args:
        buffer: The TVM Buffer, BufferLoad, or BufferRegion to check.

    Returns:
        bool: True if the underlying buffer's scope string starts with
            "local.fragment", False otherwise.
    """
    # Prefix match: any scope beginning with "local.fragment" qualifies.
    scope = _get_buffer(buffer).scope()
    return scope.startswith("local.fragment")
...
...
@@ -144,3 +182,264 @@ def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> tir.BufferRegion
return
tir
.
BufferRegion
(
buffer
,
regions
)
else
:
return
None
def to_buffer_region(obj: Buffer | BufferLoad | BufferRegion) -> BufferRegion:
    """
    Normalize a Buffer, BufferRegion, or BufferLoad into a BufferRegion.

    - BufferRegion -> returned as-is
    - Buffer       -> region with min 0 and extent `shape[i]` in every dimension
    - BufferLoad   -> the region produced by `get_buffer_region_from_load`;
                      when that returns None, an extent-1 range at each index

    Raises:
        ValueError: If `obj` is none of the supported types.
    """
    if isinstance(obj, tir.BufferRegion):
        return obj

    if isinstance(obj, tir.Buffer):
        # One [0, extent) range per dimension of the buffer.
        full_ranges = [
            ir.Range.from_min_extent(tir.IntImm("int32", 0), extent) for extent in obj.shape
        ]
        return tir.BufferRegion(obj, full_ranges)

    if isinstance(obj, tir.BufferLoad):
        derived = get_buffer_region_from_load(obj)
        if derived is not None:
            return derived
        # Fallback: scalar load -> 1-sized ranges at indices
        unit_ranges = [
            ir.Range.from_min_extent(index, tir.IntImm("int32", 1)) for index in obj.indices
        ]
        return tir.BufferRegion(obj.buffer, unit_ranges)

    raise ValueError(f"Unsupported argument type for BufferRegion: {type(obj)}")
def retrieve_shape(obj: Buffer | BufferRegion | BufferLoad) -> list:
    """
    Return shape-like extents for a buffer-like object.

    - Buffer       -> its `shape`, returned as stored on the buffer
    - BufferRegion -> the `extent` of each range in its region
    - BufferLoad   -> the extents of `get_buffer_region_from_load(obj)`

    Raises:
        ValueError: If a BufferLoad yields no region, or if `obj` is none of
            the supported types.
    """
    if isinstance(obj, tir.Buffer):
        return obj.shape

    # Both remaining supported cases reduce to reading extents from a region.
    if isinstance(obj, tir.BufferRegion):
        source_region = obj
    elif isinstance(obj, tir.BufferLoad):
        source_region = get_buffer_region_from_load(obj)
        if source_region is None:
            raise ValueError("Cannot retrieve shape from scalar BufferLoad without region")
    else:
        raise ValueError(f"Unsupported retrieve_shape argument type: {type(obj)} for object {obj}")

    return [rng.extent for rng in source_region.region]
def retrieve_stride(obj: Buffer | BufferRegion | BufferLoad) -> list:
    """
    Return row-major strides for a buffer-like object.

    The strides are computed purely from the buffer's `shape`: the innermost
    dimension gets stride 1 and each outer dimension gets the product of all
    inner extents. For BufferRegion and BufferLoad the underlying buffer's
    `shape` is used.

    Raises:
        ValueError: If `obj` is none of the supported types.
    """
    if isinstance(obj, tir.Buffer):
        dims = obj.shape
    elif isinstance(obj, (tir.BufferRegion, tir.BufferLoad)):
        dims = obj.buffer.shape
    else:
        raise ValueError(f"Unsupported retrieve_stride argument type: {type(obj)} for object {obj}")

    # Walk innermost -> outermost collecting the running product, then flip
    # the result so it lines up with the dimension order of `shape`.
    running = 1
    innermost_first = []
    for extent in reversed(dims):
        innermost_first.append(running)
        running *= extent
    innermost_first.reverse()
    return innermost_first
def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion,
                                   access_type: str = "r") -> PrimExpr:
    """
    Build an access pointer to the first element addressed by a buffer-like object.

    - Buffer       -> `access_ptr(access_type)` with no offset
    - BufferLoad   -> pointer offset by the row-major linearization of the load
                      indices; a `tir.Ramp` index contributes its `base`
    - BufferRegion -> pointer offset by the row-major linearization of each
                      range's `min`

    The linearization is derived from the buffer's `shape` only.

    Args:
        buffer_or_load_or_region: The Buffer, BufferLoad, or BufferRegion to address.
        access_type: Access mask forwarded to `Buffer.access_ptr` (default "r").

    Returns:
        PrimExpr: The value returned by `Buffer.access_ptr`.

    Raises:
        ValueError: If the argument is not a Buffer, BufferLoad, or BufferRegion,
            or if a load index is not a PrimExpr.
    """
    if isinstance(buffer_or_load_or_region, Buffer):
        return buffer_or_load_or_region.access_ptr(access_type)

    if isinstance(buffer_or_load_or_region, BufferLoad):
        buffer = buffer_or_load_or_region.buffer
        starts = []
        for indice in buffer_or_load_or_region.indices:
            # Ramp must be tested before the generic PrimExpr case: Ramp is itself
            # a PrimExpr, so testing PrimExpr first would multiply the whole
            # vector index into the offset instead of using its scalar base.
            if isinstance(indice, tir.Ramp):
                starts.append(indice.base)
            elif isinstance(indice, tir.PrimExpr):
                starts.append(indice)
            else:
                raise ValueError(f"Unsupported index type: {type(indice)}")
    elif isinstance(buffer_or_load_or_region, BufferRegion):
        buffer = buffer_or_load_or_region.buffer
        starts = [rng.min for rng in buffer_or_load_or_region.region]
    else:
        raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")

    # Row-major linearization, walking from the innermost dimension outwards.
    offset, stride = 0, 1
    last = len(starts) - 1
    for i, extent in enumerate(reversed(buffer.shape)):
        offset += starts[last - i] * stride
        stride *= extent
    return buffer.access_ptr(access_type, offset=offset)
def retrieve_ptr(
    obj: Buffer | BufferRegion | BufferLoad,
    access_type: str = "r",
    ignore_last_ndim: int = 0,
) -> PrimExpr:
    """
    Retrieve a pointer to the start of a (possibly sliced) buffer region.

    - Buffer       -> base pointer (no offset)
    - BufferRegion -> pointer offset computed from the region minima
    - BufferLoad   -> pointer offset computed from the minima of the region
                      derived by `get_buffer_region_from_load`, or from the
                      load indices when no region can be derived

    The offset is expressed in elements, not bytes: it is the sum of
    `min[i] * stride[i]` using the row-major strides from `retrieve_stride`,
    with no scaling by the element size.

    Args:
        obj: Buffer-like object
        access_type: TVM Buffer access mask, e.g. "r", "w", "rw"
        ignore_last_ndim: do not offset the last N dimensions

    Returns:
        PrimExpr: The value returned by `Buffer.access_ptr`.

    Raises:
        ValueError: If `obj` is none of the supported types.
    """
    if isinstance(obj, tir.Buffer):
        return obj.access_ptr(access_type)

    # Collect the per-dimension start of the addressed region.
    if isinstance(obj, tir.BufferRegion):
        mins = [rng.min for rng in obj.region]
    elif isinstance(obj, tir.BufferLoad):
        derived = get_buffer_region_from_load(obj)
        if derived is not None:
            mins = [rng.min for rng in derived.region]
        else:
            mins = list(obj.indices)
    else:
        raise ValueError(f"Unsupported retrieve_ptr argument type: {type(obj)} for object {obj}")

    strides = retrieve_stride(obj)
    # offset only over the leading dims, optionally ignoring the tail dims
    upto = max(0, len(mins) - int(ignore_last_ndim))
    offset = 0
    for i in range(upto):
        offset += mins[i] * strides[i]
    return obj.buffer.access_ptr(access_type, offset=offset)
def retrieve_offset(obj: Buffer | BufferRegion | BufferLoad) -> list:
    """
    Return the per-dimension starting offsets of a buffer-like object.

    - Buffer       -> a list of zeros, one per dimension of `shape`
    - BufferRegion -> the `min` of each range in its region
    - BufferLoad   -> the minima of the region derived by
                      `get_buffer_region_from_load`, or the load indices when
                      no region can be derived

    Raises:
        ValueError: If `obj` is none of the supported types.
    """
    if isinstance(obj, tir.Buffer):
        return [0] * len(obj.shape)

    if isinstance(obj, tir.BufferRegion):
        return [rng.min for rng in obj.region]

    if not isinstance(obj, tir.BufferLoad):
        raise ValueError(
            f"Unsupported retrieve_offset argument type: {type(obj)} for object {obj}")

    derived = get_buffer_region_from_load(obj)
    if derived is None:
        # No region could be derived: the indices themselves are the offsets.
        return list(obj.indices)
    return [rng.min for rng in derived.region]
def prim_expr_equal(lhs, rhs) -> bool:
    """
    Compare two shape/extent values that may be Python ints or PrimExprs.

    Two Python ints are compared directly. Otherwise any int operand is wrapped
    as an int32 IntImm, and the pair is considered equal if structural_equal
    accepts it; failing that, the result of expr_deep_equal is returned.
    """
    lhs_is_int = isinstance(lhs, int)
    rhs_is_int = isinstance(rhs, int)

    # Pure-Python fast path: no TIR objects involved.
    if lhs_is_int and rhs_is_int:
        return lhs == rhs

    if lhs_is_int:
        lhs = tir.IntImm("int32", lhs)
    if rhs_is_int:
        rhs = tir.IntImm("int32", rhs)

    # Try structural equality first, then fall back to deep expression equality.
    return True if ir.structural_equal(lhs, rhs) else tir.analysis.expr_deep_equal(lhs, rhs)
def legalize_pairwise_extents(src_extents: list, dst_extents: list) -> tuple[list, list]:
    """
    Right-align two extent lists and broadcast them against each other.

    Early exit:
      If both lists contain the same number of non-1 extents, copies of the
      inputs are returned unchanged, so the existing per-dimension pairing of
      non-1 extents is left alone.

    Otherwise the lists are walked from the tail, pairing (x, y):
      - x == y  -> both kept
      - x == 1  -> x replaced by y
      - y == 1  -> y replaced by x
      - else    -> both replaced by tir.max(x, y)
    Leading dimensions without a partner in the other list are kept as-is.

    Returns:
        tuple[list, list]: New lists (src_new, dst_new); the inputs are not mutated.
    """
    src_new = list(src_extents)
    dst_new = list(dst_extents)

    def _count_non_unit(extents: list) -> int:
        # Number of extents that are not equal to 1.
        return sum(1 for extent in extents if not prim_expr_equal(extent, 1))

    # If both sides have the same number of non-1 extents, don't re-broadcast.
    if _count_non_unit(src_new) == _count_non_unit(dst_new):
        return src_new, dst_new

    overlap = min(len(src_new), len(dst_new))
    # Negative positions index from the tail: -1, -2, ..., -overlap.
    for pos in range(-1, -overlap - 1, -1):
        left, right = src_new[pos], dst_new[pos]
        if prim_expr_equal(left, right):
            continue
        if prim_expr_equal(left, 1):
            src_new[pos] = right
        elif prim_expr_equal(right, 1):
            dst_new[pos] = left
        else:
            # Dynamic mismatch: promote to max so downstream clamping/predicates remain safe
            widened = tir.max(left, right)
            src_new[pos] = widened
            dst_new[pos] = widened
    return src_new, dst_new
def is_full_region(buffer_region: BufferRegion) -> bool:
    """
    Tell whether a BufferRegion spans its entire buffer.

    The region is full when it has one range per buffer dimension and every
    range starts at 0 with an extent equal to that dimension of the buffer's
    shape.

    Args:
        buffer_region: The TVM BufferRegion to check.

    Returns:
        bool: True if the region is full; otherwise False.

    Raises:
        TypeError: If `buffer_region` is not a BufferRegion.
    """
    if not isinstance(buffer_region, tir.BufferRegion):
        raise TypeError(f"Expected BufferRegion, got {type(buffer_region)}")

    shape = buffer_region.buffer.shape
    region = buffer_region.region
    # A rank mismatch can never describe the full buffer.
    if len(shape) != len(region):
        return False

    same = tir.analysis.expr_deep_equal
    # start == 0 and extent == shape, for every dimension
    return all(same(rng.min, 0) and same(rng.extent, extent) for extent, rng in zip(shape, region))
tilelang/utils/tensor.py
View file @
bbbf4207
...
...
@@ -2,7 +2,7 @@ from __future__ import annotations
"""The profiler and convert to torch utils"""
from
enum
import
Enum
import
torch
from
tvm
.runtime
import
ndarray
from
tvm
import
runtime
from
tvm
import
tir
from
torch.utils.dlpack
import
to_dlpack
import
numpy
as
np
...
...
@@ -49,9 +49,9 @@ def adapt_torch2tvm(arg):
if
arg
.
dtype
in
{
torch
.
float8_e4m3fn
,
torch
.
float8_e4m3fnuz
,
torch
.
float8_e5m2
,
torch
.
float8_e5m2fnuz
}:
return
ndarray
.
from_dlpack
(
to_dlpack
(
arg
.
view
(
torch
.
int8
))).
_create_view
(
return
runtime
.
from_dlpack
(
to_dlpack
(
arg
.
view
(
torch
.
int8
))).
_create_view
(
shape
=
arg
.
shape
,
dtype
=
float8_dtype_map
[
arg
.
dtype
])
return
ndarray
.
from_dlpack
(
to_dlpack
(
arg
))
return
runtime
.
from_dlpack
(
to_dlpack
(
arg
))
return
arg
...
...
version_provider.py
View file @
bbbf4207
...
...
@@ -4,10 +4,17 @@ import os
import
platform
import
subprocess
from
pathlib
import
Path
from
functools
import
lru_cache
ROOT
=
Path
(
__file__
).
parent
base_version
=
(
ROOT
/
'VERSION'
).
read_text
().
strip
()
# When installing a sdist,
# the installed version needs to match the sdist version,
# so pip will complain when we install `tilelang-0.1.6.post2+gitxxxx.tar.gz`.
# To work around that, when building sdist,
# we do not add version label and use a file to store the git hash instead.
git_pin
=
ROOT
/
'.git_commit.txt'
def
_read_cmake_bool
(
i
:
str
|
None
,
default
=
False
):
...
...
@@ -16,6 +23,7 @@ def _read_cmake_bool(i: str | None, default=False):
return
i
.
lower
()
not
in
(
'0'
,
'false'
,
'off'
,
'no'
,
'n'
,
''
)
@
lru_cache
(
maxsize
=
1
)
def
get_git_commit_id
()
->
str
|
None
:
"""Get the current git commit hash by running git in the current file's directory."""
...
...
@@ -24,9 +32,13 @@ def get_git_commit_id() -> str | None:
capture_output
=
True
,
encoding
=
'utf-8'
)
if
r
.
returncode
==
0
:
return
r
.
stdout
.
strip
()
_git
=
r
.
stdout
.
strip
()
git_pin
.
write_text
(
_git
)
return
_git
elif
git_pin
.
exists
():
return
git_pin
.
read_text
().
strip
()
else
:
return
'unknown'
return
None
def
dynamic_metadata
(
...
...
@@ -37,6 +49,9 @@ def dynamic_metadata(
version
=
base_version
# generate git version for sdist
get_git_commit_id
()
if
not
_read_cmake_bool
(
os
.
environ
.
get
(
'NO_VERSION_LABEL'
)):
exts
=
[]
backend
=
None
...
...
@@ -66,6 +81,8 @@ def dynamic_metadata(
pass
elif
git_hash
:
=
get_git_commit_id
():
exts
.
append
(
f
'git
{
git_hash
[:
8
]
}
'
)
else
:
exts
.
append
(
'gitunknown'
)
if
exts
:
version
+=
'+'
+
'.'
.
join
(
exts
)
...
...
Prev
1
…
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment