Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
281 additions
and
328 deletions
+281
-328
tilelang/language/ast/_ffi_api.py
tilelang/language/ast/_ffi_api.py
+1
-0
tilelang/language/ast/ir.py
tilelang/language/ast/ir.py
+39
-48
tilelang/language/atomic.py
tilelang/language/atomic.py
+7
-19
tilelang/language/builtin.py
tilelang/language/builtin.py
+48
-62
tilelang/language/copy.py
tilelang/language/copy.py
+36
-22
tilelang/language/customize.py
tilelang/language/customize.py
+6
-6
tilelang/language/experimental/gemm_sp.py
tilelang/language/experimental/gemm_sp.py
+8
-5
tilelang/language/fill.py
tilelang/language/fill.py
+3
-4
tilelang/language/frame.py
tilelang/language/frame.py
+4
-4
tilelang/language/gemm.py
tilelang/language/gemm.py
+30
-6
tilelang/language/kernel.py
tilelang/language/kernel.py
+10
-18
tilelang/language/logical.py
tilelang/language/logical.py
+3
-4
tilelang/language/loop.py
tilelang/language/loop.py
+13
-12
tilelang/language/math_intrinsics.py
tilelang/language/math_intrinsics.py
+1
-1
tilelang/language/overrides/parser.py
tilelang/language/overrides/parser.py
+19
-6
tilelang/language/parser/entry.py
tilelang/language/parser/entry.py
+3
-5
tilelang/language/parser/operation.py
tilelang/language/parser/operation.py
+5
-7
tilelang/language/parser/parser.py
tilelang/language/parser/parser.py
+4
-8
tilelang/language/print.py
tilelang/language/print.py
+11
-29
tilelang/language/proxy.py
tilelang/language/proxy.py
+30
-62
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
tilelang/language/ast/_ffi_api.py
View file @
29051439
...
...
@@ -17,6 +17,7 @@
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
"""FFI APIs"""
import
tvm.ffi
tvm
.
ffi
.
_init_api
(
"script.ir_builder.tir"
,
__name__
)
# pylint: disable=protected-access
tilelang/language/ast/ir.py
View file @
29051439
...
...
@@ -558,7 +558,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable.
"""
return
_ffi_api
.
AxisSpatial
(
# type: ignore[attr-defined] # pylint: disable=no-member
_as_range
(
dom
),
binding
,
dtype
)
_as_range
(
dom
),
binding
,
dtype
)
@
staticmethod
def
reduce
(
...
...
@@ -585,7 +586,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable.
"""
return
_ffi_api
.
AxisReduce
(
# type: ignore[attr-defined] # pylint: disable=no-member
_as_range
(
dom
),
binding
,
dtype
)
_as_range
(
dom
),
binding
,
dtype
)
@
staticmethod
def
scan
(
...
...
@@ -612,7 +614,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable.
"""
return
_ffi_api
.
AxisScan
(
# type: ignore[attr-defined] # pylint: disable=no-member
_as_range
(
dom
),
binding
,
dtype
)
_as_range
(
dom
),
binding
,
dtype
)
@
staticmethod
def
opaque
(
...
...
@@ -639,7 +642,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable.
"""
return
_ffi_api
.
AxisOpaque
(
# type: ignore[attr-defined] # pylint: disable=no-member
_as_range
(
dom
),
binding
,
dtype
)
_as_range
(
dom
),
binding
,
dtype
)
@
staticmethod
def
remap
(
kinds
:
str
,
bindings
:
List
[
PrimExpr
],
dtype
:
str
=
"int32"
)
->
Union
[
List
[
Var
],
Var
]:
...
...
@@ -662,17 +666,15 @@ class axis: # pylint: disable=invalid-name
The iteration variables.
"""
iter_vars
=
_ffi_api
.
AxisRemap
(
# type: ignore[attr-defined] # pylint: disable=no-member
kinds
,
bindings
,
dtype
)
kinds
,
bindings
,
dtype
)
return
iter_vars
[
0
]
if
len
(
iter_vars
)
==
1
else
iter_vars
S
=
spatial
# pylint: disable=invalid-name
R
=
reduce
# pylint: disable=invalid-name
def
serial
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
def
serial
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
"""The serial For statement.
Parameters
...
...
@@ -700,10 +702,7 @@ def serial(start: PrimExpr,
return
_ffi_api
.
Serial
(
start
,
stop
,
annotations
)
# type: ignore[attr-defined] # pylint: disable=no-member
def
parallel
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
def
parallel
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
"""The parallel For statement.
Parameters
...
...
@@ -731,10 +730,7 @@ def parallel(start: PrimExpr,
return
_ffi_api
.
Parallel
(
start
,
stop
,
annotations
)
# type: ignore[attr-defined] # pylint: disable=no-member
def
vectorized
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
def
vectorized
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
"""The vectorized For statement.
Parameters
...
...
@@ -762,10 +758,7 @@ def vectorized(start: PrimExpr,
return
_ffi_api
.
Vectorized
(
start
,
stop
,
annotations
)
# type: ignore[attr-defined] # pylint: disable=no-member
def
unroll
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
def
unroll
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
"""The unrolled For statement.
Parameters
...
...
@@ -837,7 +830,8 @@ def thread_binding(
else
:
start
=
0
return
_ffi_api
.
ThreadBinding
(
# type: ignore[attr-defined] # pylint: disable=no-member
start
,
stop
,
thread
,
annotations
)
start
,
stop
,
thread
,
annotations
)
def
grid
(
*
extents
:
PrimExpr
)
->
frame
.
ForFrame
:
...
...
@@ -878,10 +872,10 @@ def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: d
def
LetStmt
(
# pylint: disable=invalid-name
value
:
PrimExpr
,
type_annotation
:
Optional
[
Type
]
=
None
,
# pylint: disable=redefined-outer-name
*
,
var
:
Optional
[
Var
]
=
None
,
# pylint: disable=redefined-outer-name
value
:
PrimExpr
,
type_annotation
:
Optional
[
Type
]
=
None
,
# pylint: disable=redefined-outer-name
*
,
var
:
Optional
[
Var
]
=
None
,
# pylint: disable=redefined-outer-name
)
->
frame
.
LetFrame
:
"""Create a LetStmt binding
...
...
@@ -909,8 +903,8 @@ def LetStmt( # pylint: disable=invalid-name
def
Let
(
# pylint: disable=invalid-name
expr
:
PrimExpr
,
where
:
Dict
[
Var
,
PrimExpr
],
# pylint: disable=redefined-outer-name
expr
:
PrimExpr
,
where
:
Dict
[
Var
,
PrimExpr
],
# pylint: disable=redefined-outer-name
)
->
PrimExpr
:
"""Create a Let expression binding"""
assert
len
(
where
)
==
1
,
"T.Let only allows `where` to have exactly one element"
...
...
@@ -980,7 +974,8 @@ def realize(
The result RealizeFrame.
"""
return
_ffi_api
.
Realize
(
# type: ignore[attr-defined] # pylint: disable=no-member
buffer_slice
,
storage_scope
,
condition
)
buffer_slice
,
storage_scope
,
condition
)
def
allocate
(
...
...
@@ -1012,7 +1007,8 @@ def allocate(
if
isinstance
(
condition
,
bool
):
condition
=
IntImm
(
"bool"
,
condition
)
return
_ffi_api
.
Allocate
(
# type: ignore[attr-defined] # pylint: disable=no-member
extents
,
dtype
,
scope
,
condition
,
annotations
)
extents
,
dtype
,
scope
,
condition
,
annotations
)
def
allocate_const
(
...
...
@@ -1048,7 +1044,8 @@ def allocate_const(
np_data
=
np_data
.
reshape
(
extents
)
return
_ffi_api
.
AllocateConst
(
# type: ignore[attr-defined] # pylint: disable=no-member
ndarray
.
array
(
np_data
),
dtype
,
extents
,
annotations
)
ndarray
.
array
(
np_data
),
dtype
,
extents
,
annotations
)
def
attr
(
node
:
Any
,
attr_key
:
str
,
value
:
Union
[
PrimExpr
,
str
])
->
frame
.
AttrFrame
:
...
...
@@ -1297,7 +1294,8 @@ def buffer_store(
if
isinstance
(
value
,
bool
)
and
buffer
.
dtype
==
"bool"
:
value
=
IntImm
(
"bool"
,
value
)
return
_ffi_api
.
BufferStore
(
# type: ignore[attr-defined] # pylint: disable=no-member
buffer
,
value
,
expr_indices
)
buffer
,
value
,
expr_indices
)
def
prefetch
(
...
...
@@ -1464,10 +1462,7 @@ def boolean(expr: Optional[PrimExpr] = None, is_size_var: bool = False) -> PrimE
return
_ffi_api
.
Boolean
(
expr
,
is_size_var
)
# type: ignore[attr-defined] # pylint: disable=no-member
def
handle
(
dtype
:
Optional
[
str
]
=
None
,
storage_scope
:
str
=
"global"
,
*
,
is_size_var
:
bool
=
False
)
->
Var
:
def
handle
(
dtype
:
Optional
[
str
]
=
None
,
storage_scope
:
str
=
"global"
,
*
,
is_size_var
:
bool
=
False
)
->
Var
:
"""Create a TIR var that represents a pointer.
Parameters
...
...
@@ -1667,7 +1662,7 @@ def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer:
res
=
combiner
(
*
args
)
if
not
isinstance
(
res
,
tuple
):
res
=
(
res
,)
return
CommReducer
(
args
[:
num_args
//
2
],
args
[
num_args
//
2
:],
res
,
identity
)
return
CommReducer
(
args
[:
num_args
//
2
],
args
[
num_args
//
2
:],
res
,
identity
)
def
index_map
(
...
...
@@ -1700,16 +1695,15 @@ def target(
The target.
"""
if
not
isinstance
(
target_config
,
(
str
,
dict
)):
raise
ValueError
(
f
"T.target expected a config dict or string, but got
{
type
(
target_config
)
}
"
)
raise
ValueError
(
f
"T.target expected a config dict or string, but got
{
type
(
target_config
)
}
"
)
if
host
is
not
None
and
not
isinstance
(
host
,
(
str
,
dict
,
Target
)):
raise
ValueError
(
"T.target expected the host to be "
"a config dict, string, or T.target, "
f
"but got
{
type
(
host
)
}
"
)
raise
ValueError
(
f
"T.target expected the host to be a config dict, string, or T.target, but got
{
type
(
host
)
}
"
)
if
isinstance
(
target_config
,
dict
)
and
"host"
in
target_config
and
host
is
not
None
:
raise
ValueError
(
"T.target expects to either receive the host "
"as part of the target's config dictionary, "
"or as a separate argument, but not both."
)
raise
ValueError
(
"T.target expects to either receive the host "
"as part of the target's config dictionary, "
"or as a separate argument, but not both."
)
return
Target
(
target_config
,
host
)
...
...
@@ -1742,7 +1736,6 @@ class meta_var: # pylint: disable=invalid-name
self
.
value
=
value
def
__iter__
(
self
):
def
f
():
for
i
in
self
.
value
:
yield
meta_var
(
i
)
...
...
@@ -1754,7 +1747,6 @@ class meta_var: # pylint: disable=invalid-name
def
_op_wrapper
(
func
):
@
functools
.
wraps
(
func
)
def
wrapped
(
*
args
,
**
kwargs
):
if
"dtype"
in
kwargs
:
...
...
@@ -1874,7 +1866,6 @@ vscale = _op_wrapper(_tir_op.vscale)
def
_dtype_forward
(
func
):
@
functools
.
wraps
(
func
)
def
wrapped
(
*
args
,
**
kwargs
):
if
"dtype"
in
kwargs
:
...
...
tilelang/language/atomic.py
View file @
29051439
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
"""Atomic operations for tilelang."""
from
__future__
import
annotations
import
tilelang.language
as
T
...
...
@@ -18,10 +19,7 @@ _MEMORY_ORDER_ID_MAP = {
}
def
atomic_max
(
dst
:
Buffer
,
value
:
PrimExpr
,
memory_order
:
str
|
None
=
None
,
return_prev
:
bool
=
False
)
->
PrimExpr
:
def
atomic_max
(
dst
:
Buffer
,
value
:
PrimExpr
,
memory_order
:
str
|
None
=
None
,
return_prev
:
bool
=
False
)
->
PrimExpr
:
"""
Perform an atomic maximum on the value stored at dst with an optional memory-order.
...
...
@@ -64,10 +62,7 @@ def atomic_max(dst: Buffer,
return
T
.
call_extern
(
return_type
,
func_name
,
dst
,
value
,
_MEMORY_ORDER_ID_MAP
[
memory_order
])
def
atomic_min
(
dst
:
Buffer
,
value
:
PrimExpr
,
memory_order
:
str
|
None
=
None
,
return_prev
:
bool
=
False
)
->
PrimExpr
:
def
atomic_min
(
dst
:
Buffer
,
value
:
PrimExpr
,
memory_order
:
str
|
None
=
None
,
return_prev
:
bool
=
False
)
->
PrimExpr
:
"""
Atomically update the value at dst to the minimum of its current value and value.
...
...
@@ -112,11 +107,7 @@ def atomic_min(dst: Buffer,
return
T
.
call_extern
(
return_type
,
func_name
,
dst
,
value
,
_MEMORY_ORDER_ID_MAP
[
memory_order
])
def
atomic_add
(
dst
:
Buffer
,
value
:
PrimExpr
,
memory_order
:
str
|
None
=
None
,
return_prev
:
bool
=
False
,
use_tma
:
bool
=
False
)
->
PrimExpr
:
def
atomic_add
(
dst
:
Buffer
,
value
:
PrimExpr
,
memory_order
:
str
|
None
=
None
,
return_prev
:
bool
=
False
,
use_tma
:
bool
=
False
)
->
PrimExpr
:
"""
Atomically add `value` into `dst`, returning a handle to the operation.
...
...
@@ -191,8 +182,7 @@ def atomic_add(dst: Buffer,
if
memory_order
is
None
:
return
T
.
call_extern
(
return_type
,
func_name
,
dst
,
value
)
else
:
return
T
.
call_extern
(
return_type
,
func_name
,
dst
,
value
,
_MEMORY_ORDER_ID_MAP
[
memory_order
])
return
T
.
call_extern
(
return_type
,
func_name
,
dst
,
value
,
_MEMORY_ORDER_ID_MAP
[
memory_order
])
if
isinstance
(
dst
,
Buffer
)
and
isinstance
(
value
,
Buffer
):
ir
.
assert_structural_equal
(
dst
.
shape
,
value
.
shape
)
...
...
@@ -208,14 +198,12 @@ def atomic_add(dst: Buffer,
# Note: tile-region-based atomic operations don't support return_prev yet
# This would need to be implemented in the tile runtime
if
return_prev
:
raise
NotImplementedError
(
"return_prev is not supported for tile-region-based atomic operations"
)
raise
NotImplementedError
(
"return_prev is not supported for tile-region-based atomic operations"
)
if
memory_order
is
None
:
return
T
.
call_intrin
(
"handle"
,
op
.
Op
.
get
(
"tl.tileop.atomicadd"
),
value
,
dst
,
use_tma
,
0
)
else
:
return
T
.
call_intrin
(
"handle"
,
op
.
Op
.
get
(
"tl.tileop.atomicadd"
),
value
,
dst
,
use_tma
,
_MEMORY_ORDER_ID_MAP
[
memory_order
])
return
T
.
call_intrin
(
"handle"
,
op
.
Op
.
get
(
"tl.tileop.atomicadd"
),
value
,
dst
,
use_tma
,
_MEMORY_ORDER_ID_MAP
[
memory_order
])
def
atomic_addx2
(
dst
:
Buffer
,
value
:
PrimExpr
,
return_prev
:
bool
=
False
)
->
PrimExpr
:
...
...
tilelang/language/builtin.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
from
tilelang
import
tvm
as
tvm
...
...
@@ -179,38 +180,32 @@ def set_max_nreg(reg_count: int, is_inc: int):
def
inc_max_nreg
(
reg_count
:
int
):
"""Increment the maximum number of registers to use.
"""
"""Increment the maximum number of registers to use."""
return
set_max_nreg
(
reg_count
,
1
)
def
dec_max_nreg
(
reg_count
:
int
):
"""Decrement the maximum number of registers to use.
"""
"""Decrement the maximum number of registers to use."""
return
set_max_nreg
(
reg_count
,
0
)
def
annotate_producer_reg_dealloc
(
reg_count
:
int
=
24
):
"""Annotate the producer reg dealloc.
"""
"""Annotate the producer reg dealloc."""
return
dec_max_nreg
(
reg_count
)
def
annotate_consumer_reg_alloc
(
reg_count
:
int
=
240
):
"""Annotate the consumer reg alloc.
"""
"""Annotate the consumer reg alloc."""
return
inc_max_nreg
(
reg_count
)
def
no_set_max_nreg
():
"""Disable the maximum register limit setting.
"""
"""Disable the maximum register limit setting."""
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.no_set_max_nreg"
))
def
disable_warp_group_reg_alloc
():
"""Disable the warp group reg alloc.
"""
"""Disable the warp group reg alloc."""
return
no_set_max_nreg
()
...
...
@@ -325,7 +320,9 @@ def warpgroup_wait(num_mma: int):
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.warpgroup_wait"
),
num_mma
)
def
get_lane_idx
(
warp_size
:
int
|
PrimExpr
|
None
=
None
,)
->
PrimExpr
:
def
get_lane_idx
(
warp_size
:
int
|
PrimExpr
|
None
=
None
,
)
->
PrimExpr
:
"""Return the logical lane index of the calling thread within a warp.
Parameters
...
...
@@ -350,7 +347,9 @@ def get_lane_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
return
tir
.
call_intrin
(
"int32"
,
tir
.
op
.
Op
.
get
(
"tl.get_lane_idx"
),
warp_size_expr
)
def
get_warp_idx_sync
(
warp_size
:
int
|
PrimExpr
|
None
=
None
,)
->
PrimExpr
:
def
get_warp_idx_sync
(
warp_size
:
int
|
PrimExpr
|
None
=
None
,
)
->
PrimExpr
:
"""Return the canonical warp index, assuming the warp's threads are converged.
Parameters
...
...
@@ -374,7 +373,9 @@ def get_warp_idx_sync(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
return
tir
.
call_intrin
(
"int32"
,
tir
.
op
.
Op
.
get
(
"tl.get_warp_idx_sync"
),
warp_size_expr
)
def
get_warp_idx
(
warp_size
:
int
|
PrimExpr
|
None
=
None
,)
->
PrimExpr
:
def
get_warp_idx
(
warp_size
:
int
|
PrimExpr
|
None
=
None
,
)
->
PrimExpr
:
"""Return the canonical warp index without synchronizing the warp.
Parameters
...
...
@@ -429,8 +430,7 @@ def get_warp_group_idx(
args
.
append
(
warp_size_expr
)
if
warps_per_group_expr
is
not
None
:
if
warp_size_expr
is
None
:
raise
ValueError
(
"get_warp_group_idx expects `warp_size` when specifying "
"`warps_per_group`."
)
raise
ValueError
(
"get_warp_group_idx expects `warp_size` when specifying `warps_per_group`."
)
args
.
append
(
warps_per_group_expr
)
return
tir
.
call_intrin
(
"int32"
,
tir
.
op
.
Op
.
get
(
"tl.get_warp_group_idx"
),
*
args
)
...
...
@@ -459,10 +459,9 @@ def shuffle_elect(thread_extent: int) -> PrimExpr:
return
tir
.
call_intrin
(
"bool"
,
tir
.
op
.
Op
.
get
(
"tl.tl_shuffle_elect"
),
thread_extent
)
def
warpgroup_fence_operand
(
buffer_or_ptr
:
tir
.
Buffer
|
PrimExpr
,
offset
:
int
|
PrimExpr
=
0
,
num_regs
:
int
|
PrimExpr
|
None
=
None
,
dtype
:
str
|
None
=
None
):
def
warpgroup_fence_operand
(
buffer_or_ptr
:
tir
.
Buffer
|
PrimExpr
,
offset
:
int
|
PrimExpr
=
0
,
num_regs
:
int
|
PrimExpr
|
None
=
None
,
dtype
:
str
|
None
=
None
):
"""Insert a warpgroup fence for the destination accumulator registers.
This prevents NVCC from sinking uses of accumulator fragments past the corresponding
...
...
@@ -517,7 +516,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
data_ptr
,
convert
(
offset
),
convert
(
num_regs
),
))
)
)
if
isinstance
(
buffer_or_ptr
,
tir
.
Buffer
):
data_ptr
=
buffer_or_ptr
.
data
...
...
@@ -531,8 +531,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
if
isinstance
(
dim
,
tir
.
IntImm
):
total_elems
*=
int
(
dim
)
else
:
raise
ValueError
(
"warpgroup_fence_operand requires num_regs when buffer shape is symbolic."
)
raise
ValueError
(
"warpgroup_fence_operand requires num_regs when buffer shape is symbolic."
)
bits_per_elem
=
DataType
(
dtype
).
bits
num_regs
=
(
total_elems
*
bits_per_elem
+
31
)
//
32
elif
isinstance
(
buffer_or_ptr
,
BufferRegion
):
...
...
@@ -569,9 +568,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
bits_per_elem
=
DataType
(
dtype
).
bits
num_regs
=
(
total_elems
*
bits_per_elem
+
31
)
//
32
else
:
raise
ValueError
(
"warpgroup_fence_operand requires num_regs when BufferRegion extent is symbolic."
)
raise
ValueError
(
"warpgroup_fence_operand requires num_regs when BufferRegion extent is symbolic."
)
return
evaluate
(
tir
.
call_intrin
(
"handle"
,
...
...
@@ -580,7 +577,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
data_ptr
,
convert
(
offset
),
convert
(
num_regs
),
))
)
)
else
:
data_ptr
=
buffer_or_ptr
# Try to infer dtype from common pointer expressions when not provided
...
...
@@ -603,9 +601,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
except
Exception
:
inferred
=
None
if
inferred
is
None
:
raise
ValueError
(
"dtype must be provided when passing a pointer expression and cannot be inferred."
)
raise
ValueError
(
"dtype must be provided when passing a pointer expression and cannot be inferred."
)
dtype
=
inferred
if
num_regs
is
None
:
raise
ValueError
(
"num_regs must be provided when passing a pointer expression."
)
...
...
@@ -618,7 +614,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
data_ptr
,
convert
(
offset
),
convert
(
num_regs
),
))
)
)
def
wait_wgmma
(
id
:
int
):
...
...
@@ -673,7 +670,7 @@ def shfl_xor(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call
if
_IS_HIP_AVAILABLE
:
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_xor"
,
value
,
offset
)
else
:
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_xor_sync"
,
0x
ffffffff
,
value
,
offset
)
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_xor_sync"
,
0x
FFFFFFFF
,
value
,
offset
)
def
shfl_down
(
value
:
int
|
PrimExpr
|
tir
.
Call
,
offset
:
int
|
PrimExpr
|
tir
.
Call
):
...
...
@@ -686,7 +683,7 @@ def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Cal
if
_IS_HIP_AVAILABLE
:
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_down"
,
value
,
offset
)
else
:
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_down_sync"
,
0x
ffffffff
,
value
,
offset
)
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_down_sync"
,
0x
FFFFFFFF
,
value
,
offset
)
def
shfl_up
(
value
:
int
|
PrimExpr
|
tir
.
Call
,
offset
:
int
|
PrimExpr
|
tir
.
Call
):
...
...
@@ -699,12 +696,11 @@ def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call)
if
_IS_HIP_AVAILABLE
:
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_up"
,
value
,
offset
)
else
:
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_up_sync"
,
0x
ffffffff
,
value
,
offset
)
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_up_sync"
,
0x
FFFFFFFF
,
value
,
offset
)
def
sync_threads
(
barrier_id
:
int
=
None
,
arrive_count
:
int
=
None
):
"""Synchronize all threads in a block.
"""
"""Synchronize all threads in a block."""
args
=
[]
if
barrier_id
is
not
None
:
args
.
append
(
barrier_id
)
...
...
@@ -714,8 +710,7 @@ def sync_threads(barrier_id: int = None, arrive_count: int = None):
def
sync_global
():
"""Synchronize all threads in the entire grid.
"""
"""Synchronize all threads in the entire grid."""
tx
,
ty
,
tz
=
get_thread_bindings
()
ex
,
ey
,
ez
=
get_block_extents
()
print
(
tx
,
ty
,
tz
,
ex
,
ey
,
ez
)
...
...
@@ -724,8 +719,7 @@ def sync_global():
def
sync_grid
():
"""Synchronize all threads in a grid.
"""
"""Synchronize all threads in a grid."""
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.sync_grid"
))
...
...
@@ -741,12 +735,10 @@ def initialize_wgmma_descriptor(
if
not
isinstance
(
descriptor
,
(
BufferLoad
,
tir
.
Buffer
)):
raise
TypeError
(
"Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad."
)
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
(
len
(
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
):
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
(
len
(
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
):
raise
ValueError
(
"Descriptor must be a 1D buffer of size 1."
)
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
,
[
0
])
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
,
[
0
])
return
evaluate
(
tir
.
call_intrin
(
...
...
@@ -757,7 +749,8 @@ def initialize_wgmma_descriptor(
layout_type_
,
int
(
leading_byte_offset
),
int
(
stride_byte_offset
),
))
)
)
def
initialize_tcgen05_descriptor
(
...
...
@@ -774,12 +767,10 @@ def initialize_tcgen05_descriptor(
if
not
isinstance
(
descriptor
,
(
BufferLoad
,
tir
.
Buffer
)):
raise
TypeError
(
"Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad."
)
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
(
len
(
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
):
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
(
len
(
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
):
raise
ValueError
(
"Descriptor must be a 1D buffer of size 1."
)
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
,
[
0
])
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
,
[
0
])
return
evaluate
(
tir
.
call_intrin
(
...
...
@@ -792,7 +783,8 @@ def initialize_tcgen05_descriptor(
int
(
base_offset
),
tir
.
IntImm
(
"int32"
,
1
if
leading_is_absolute
else
0
),
int
(
swizzle_mode
),
))
)
)
def
increase_descriptor_offset
(
descriptor
:
PrimExpr
,
offset
:
PrimExpr
)
->
PrimExpr
:
...
...
@@ -809,27 +801,21 @@ def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimEx
if
not
isinstance
(
descriptor
,
(
BufferLoad
,
tir
.
Buffer
)):
raise
TypeError
(
"Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad."
)
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
len
(
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
:
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
len
(
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
:
raise
ValueError
(
"Descriptor must be a 1D buffer of size 1."
)
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
,
[
0
])
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
,
[
0
])
return
evaluate
(
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.increase_descriptor_offset"
),
descriptor
,
offset
))
return
evaluate
(
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.increase_descriptor_offset"
),
descriptor
,
offset
))
def
loop_break
():
"""Break out of the innermost loop.
"""
"""Break out of the innermost loop."""
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.loop_break"
))
def
cp_async_barrier_noinc
(
barrier_id
:
int
|
PrimExpr
|
tir
.
Call
):
"""Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
"""
"""Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc."""
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.ptx_cp_async_barrier_noinc"
),
barrier_id
)
...
...
tilelang/language/copy.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
from
typing
import
Literal
from
tilelang
import
language
as
T
...
...
@@ -10,11 +11,13 @@ from tilelang.utils.language import (
from
tvm
import
ir
,
tir
def
copy
(
src
:
tir
.
Buffer
|
tir
.
BufferLoad
|
tir
.
BufferRegion
,
dst
:
tir
.
Buffer
|
tir
.
BufferLoad
,
coalesced_width
:
int
|
None
=
None
,
disable_tma
:
bool
=
False
,
eviction_policy
:
Literal
[
"evict_normal"
,
"evict_first"
,
"evict_last"
]
|
None
=
None
):
def
copy
(
src
:
tir
.
Buffer
|
tir
.
BufferLoad
|
tir
.
BufferRegion
,
dst
:
tir
.
Buffer
|
tir
.
BufferLoad
,
coalesced_width
:
int
|
None
=
None
,
disable_tma
:
bool
=
False
,
eviction_policy
:
Literal
[
"evict_normal"
,
"evict_first"
,
"evict_last"
]
|
None
=
None
,
):
"""Copy data between memory regions.
Args:
...
...
@@ -65,8 +68,7 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
src_extent
=
get_extent
(
src
)
dst_extent
=
get_extent
(
dst
)
# Combine the nested if statements into a single if statement as suggested by SIM102
if
(
src_extent
is
None
and
dst_extent
is
None
and
isinstance
(
src
,
tir
.
BufferLoad
)
and
isinstance
(
dst
,
tir
.
BufferLoad
)):
if
src_extent
is
None
and
dst_extent
is
None
and
isinstance
(
src
,
tir
.
BufferLoad
)
and
isinstance
(
dst
,
tir
.
BufferLoad
):
# check if the case is like this:
# copy(buffer_a[i], buffer_b[i]) where both are BufferLoad nodes
# In this case, lower it to a simple BufferStore: buffer_b[i] = buffer_a[i]
...
...
@@ -90,19 +92,20 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
eviction_policy
=
0
else
:
eviction_policy
=
{
"evict_normal"
:
0
,
"evict_first"
:
1
,
"evict_last"
:
2
}[
eviction_policy
]
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tileop.copy"
),
src
,
dst
,
coalesced_width
,
disable_tma
,
eviction_policy
)
def
c2d_im2col
(
img
:
tir
.
Buffer
,
col
:
tir
.
Buffer
,
nhw_step
:
tir
.
PrimExpr
,
c_step
:
tir
.
PrimExpr
,
kernel
:
int
,
stride
:
int
,
dilation
:
int
,
pad
:
int
,
eviction_policy
:
Literal
[
"evict_normal"
,
"evict_first"
,
"evict_last"
]
|
None
=
None
):
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tileop.copy"
),
src
,
dst
,
coalesced_width
,
disable_tma
,
eviction_policy
)
def
c2d_im2col
(
img
:
tir
.
Buffer
,
col
:
tir
.
Buffer
,
nhw_step
:
tir
.
PrimExpr
,
c_step
:
tir
.
PrimExpr
,
kernel
:
int
,
stride
:
int
,
dilation
:
int
,
pad
:
int
,
eviction_policy
:
Literal
[
"evict_normal"
,
"evict_first"
,
"evict_last"
]
|
None
=
None
,
):
"""Perform im2col transformation for 2D convolution.
Args:
...
...
@@ -124,5 +127,16 @@ def c2d_im2col(img: tir.Buffer,
eviction_policy
=
{
"evict_normal"
:
0
,
"evict_first"
:
1
,
"evict_last"
:
2
}[
eviction_policy
]
img_region
=
to_buffer_region
(
img
,
access_type
=
"r"
)
col_region
=
to_buffer_region
(
col
,
access_type
=
"w"
)
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tileop.c2d_im2col"
),
img_region
,
col_region
,
nhw_step
,
c_step
,
kernel
,
stride
,
dilation
,
pad
,
eviction_policy
)
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tileop.c2d_im2col"
),
img_region
,
col_region
,
nhw_step
,
c_step
,
kernel
,
stride
,
dilation
,
pad
,
eviction_policy
,
)
tilelang/language/customize.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
import
tilelang.language
as
T
from
tvm.tir
import
PrimExpr
,
Buffer
,
op
from
tilelang.utils.language
import
(
bits_product
,
prim_expr_equal
)
from
tilelang.utils.language
import
bits_product
,
prim_expr_equal
from
.atomic
import
atomic_max
,
atomic_min
,
atomic_add
,
atomic_addx2
,
atomic_addx4
,
atomic_load
,
atomic_store
# noqa: F401
...
...
@@ -46,9 +47,9 @@ def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer:
Returns:
Buffer: A new buffer view with the specified shape
"""
assert
prim_expr_equal
(
bits_product
(
shape
,
src
.
dty
pe
)
,
bits_product
(
src
.
shape
,
src
.
dtype
)
)
,
f
"T.reshape/view shape check failed. src
{
src
}
src.shape:
{
src
.
shape
}
, src.dtype:
{
src
.
dtype
}
, target shape:
{
shape
}
, target dtype:
{
src
.
dtype
}
"
assert
prim_expr_equal
(
bits_product
(
shape
,
src
.
dtype
),
bits_product
(
src
.
shape
,
src
.
dtype
)),
(
f
"T.reshape/view shape check failed. src
{
src
}
src.
shape
:
{
src
.
sha
pe
}
,
src.dtype:
{
src
.
dtype
}
, target shape:
{
shape
}
, target dtype:
{
src
.
dtype
}
"
)
return
T
.
Tensor
(
shape
,
src
.
dtype
,
src
.
data
)
...
...
@@ -61,8 +62,7 @@ def view(src: Buffer, shape: list[PrimExpr] | None = None, dtype: str | None = N
shape
=
src
.
shape
if
dtype
is
None
:
dtype
=
src
.
dtype
assert
prim_expr_equal
(
bits_product
(
shape
,
dtype
),
bits_product
(
src
.
shape
,
src
.
dtype
)),
"T.reshape/view shape check failed."
assert
prim_expr_equal
(
bits_product
(
shape
,
dtype
),
bits_product
(
src
.
shape
,
src
.
dtype
)),
"T.reshape/view shape check failed."
return
T
.
Tensor
(
shape
,
dtype
,
src
.
data
)
...
...
tilelang/language/experimental/gemm_sp.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
from
tilelang.primitives.gemm.base
import
GemmWarpPolicy
import
tilelang.language
as
T
...
...
@@ -11,7 +12,8 @@ from tilelang.utils.language import (
prim_expr_equal
,
)
from
tilelang.language.utils
import
(
buffer_region_to_tile_region
,)
buffer_region_to_tile_region
,
)
def
gemm_sp
(
...
...
@@ -169,18 +171,19 @@ def gemm_sp_v2(
assert
len
(
B_shape
)
>=
2
,
"current only support B as a 2D or higher-order tensor"
if
len
(
A_shape
)
>
2
:
for
i
in
range
(
len
(
A_shape
)
-
2
):
assert
A_shape
[
i
]
==
1
,
\
assert
A_shape
[
i
]
==
1
,
(
"current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
if
len
(
B_shape
)
>
2
:
for
i
in
range
(
len
(
B_shape
)
-
2
):
assert
B_shape
[
i
]
==
1
,
\
assert
B_shape
[
i
]
==
1
,
(
"current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
M
,
N
=
C_shape
K
=
2
*
(
A_shape
[
-
2
]
if
transpose_A
else
A_shape
[
-
1
])
K_B
=
B_shape
[
-
1
]
if
transpose_B
else
B_shape
[
-
2
]
assert
prim_expr_equal
(
K
,
K_B
),
f
"T.gemm_sp K shape check failed: K_A (wo sparse) =
{
K
}
, K_B =
{
K_B
}
"
assert
prim_expr_equal
(
K
,
K_B
),
f
"T.gemm_sp K shape check failed: K_A (wo sparse) =
{
K
}
, K_B =
{
K_B
}
"
stride_a
=
A_stride
[
-
2
]
stride_b
=
B_stride
[
-
2
]
...
...
tilelang/language/fill.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
from
tvm
import
tir
from
tilelang.language
import
has_let_value
,
get_let_value
...
...
@@ -32,8 +33,7 @@ def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.Prim
extents
=
[
tir
.
IntImm
(
"int32"
,
1
)
for
_
in
buffer
.
indices
]
else
:
extents
=
[]
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tileop.fill"
),
to_buffer_region
(
buffer
,
access_type
=
"w"
,
extents
=
extents
),
value
)
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tileop.fill"
),
to_buffer_region
(
buffer
,
access_type
=
"w"
,
extents
=
extents
),
value
)
def
clear
(
buffer
:
tir
.
Buffer
|
tir
.
Var
):
...
...
@@ -55,8 +55,7 @@ def clear(buffer: tir.Buffer | tir.Var):
elif
isinstance
(
buffer_region
,
tir
.
BufferLoad
):
region
=
get_buffer_region_from_load
(
buffer_region
)
if
region
is
None
:
raise
ValueError
(
f
"Invalid buffer region:
{
buffer_region
}
, type:
{
type
(
buffer_region
)
}
"
)
raise
ValueError
(
f
"Invalid buffer region:
{
buffer_region
}
, type:
{
type
(
buffer_region
)
}
"
)
return
fill
(
region
,
0
)
else
:
raise
ValueError
(
f
"Invalid buffer region:
{
buffer_region
}
, type:
{
type
(
buffer_region
)
}
"
)
...
...
tilelang/language/frame.py
View file @
29051439
"""Override the LetFrame to print a message when entering the frame."""
from
__future__
import
annotations
from
tvm.ffi
import
register_object
as
_register_object
from
tvm.tir
import
Var
,
PrimExpr
,
BufferLoad
,
BufferRegion
...
...
@@ -29,7 +30,7 @@ class FrameStack:
item: The frame object to push onto the stack
"""
self
.
_stack
.
append
(
item
)
if
hasattr
(
item
,
'
var
'
)
and
hasattr
(
item
,
'
value
'
):
if
hasattr
(
item
,
"
var
"
)
and
hasattr
(
item
,
"
value
"
):
self
.
_var_value_map
[
item
.
var
]
=
item
.
value
def
pop
(
self
):
...
...
@@ -43,7 +44,7 @@ class FrameStack:
"""
if
self
.
_stack
:
item
=
self
.
_stack
.
pop
()
if
hasattr
(
item
,
'
var
'
):
if
hasattr
(
item
,
"
var
"
):
self
.
_var_value_map
.
pop
(
item
.
var
,
None
)
return
item
raise
IndexError
(
f
"
{
self
.
__class__
.
__name__
}
is empty"
)
...
...
@@ -129,8 +130,7 @@ class LetFrame(TIRFrame):
is_block_load
=
True
break
if
is_block_load
:
self
.
value
=
BufferRegion
(
self
.
value
.
buffer
,
[
Range
(
x
.
base
,
x
.
lanes
)
for
x
in
indices
])
self
.
value
=
BufferRegion
(
self
.
value
.
buffer
,
[
Range
(
x
.
base
,
x
.
lanes
)
for
x
in
indices
])
_get_let_stack
().
push
(
self
)
return
self
.
var
...
...
tilelang/language/gemm.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
from
tilelang.primitives.gemm.base
import
GemmWarpPolicy
import
tilelang.language
as
T
...
...
@@ -11,7 +12,8 @@ from tilelang.utils.language import (
prim_expr_equal
,
)
from
tilelang.language.utils
import
(
buffer_region_to_tile_region
,)
buffer_region_to_tile_region
,
)
from
tilelang.env
import
env
as
_env
...
...
@@ -68,12 +70,14 @@ def _gemm_impl(
assert
len
(
B_shape
)
>=
2
,
"current only support B as a 2D or higher-order tensor"
if
len
(
A_shape
)
>
2
:
for
i
in
range
(
len
(
A_shape
)
-
2
):
assert
A_shape
[
i
]
==
1
,
\
assert
A_shape
[
i
]
==
1
,
(
"current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
if
len
(
B_shape
)
>
2
:
for
i
in
range
(
len
(
B_shape
)
-
2
):
assert
B_shape
[
i
]
==
1
,
\
assert
B_shape
[
i
]
==
1
,
(
"current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
M
,
N
=
C_shape
K
=
A_shape
[
-
2
]
if
transpose_A
else
A_shape
[
-
1
]
...
...
@@ -96,9 +100,29 @@ def _gemm_impl(
A_arg
=
buffer_region_to_tile_region
(
A_region
,
"r"
,
[
r
for
r
in
A_shape
])
B_arg
=
buffer_region_to_tile_region
(
B_region
,
"r"
,
[
r
for
r
in
B_shape
])
C_arg
=
buffer_region_to_tile_region
(
C_region
,
"rw"
,
[
r
for
r
in
C_shape
])
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
op_key
),
A_arg
,
B_arg
,
C_arg
,
transpose_A
,
transpose_B
,
M
,
N
,
K
,
policy
,
clear_accum
,
stride_a
,
stride_b
,
offset_a
,
offset_b
,
k_pack
,
wg_wait
,
mbar
,
C_coords
[
0
],
C_coords
[
1
])
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
op_key
),
A_arg
,
B_arg
,
C_arg
,
transpose_A
,
transpose_B
,
M
,
N
,
K
,
policy
,
clear_accum
,
stride_a
,
stride_b
,
offset_a
,
offset_b
,
k_pack
,
wg_wait
,
mbar
,
C_coords
[
0
],
C_coords
[
1
],
)
# Public wrappers
...
...
tilelang/language/kernel.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
from
collections
import
deque
from
tvm
import
tir
...
...
@@ -107,8 +108,7 @@ class KernelLaunchFrame(TIRFrame):
_get_current_stack
().
push
(
self
)
last_block_frame
=
self
.
frames
[
-
1
]
assert
isinstance
(
last_block_frame
,
BlockFrame
),
f
"Last frame must be a block frame, got
{
last_block_frame
}
"
assert
isinstance
(
last_block_frame
,
BlockFrame
),
f
"Last frame must be a block frame, got
{
last_block_frame
}
"
maybe_cpu
=
last_block_frame
.
annotations
.
get
(
"tilelang.is_cpu_kernel_frame"
,
False
)
...
...
@@ -303,56 +303,48 @@ def Kernel(
def
get_thread_binding
(
dim
:
int
=
0
)
->
Var
:
"""Returns the thread binding for the given dimension.
"""
"""Returns the thread binding for the given dimension."""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_thread_binding
(
dim
)
def
get_thread_bindings
()
->
list
[
Var
]:
"""Returns all three thread bindings.
"""
"""Returns all three thread bindings."""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_thread_bindings
()
def
get_block_binding
(
dim
:
int
=
0
)
->
Var
:
"""Returns the block binding for the given dimension.
"""
"""Returns the block binding for the given dimension."""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_block_binding
(
dim
)
def
get_block_bindings
()
->
list
[
Var
]:
"""Returns all three block bindings.
"""
"""Returns all three block bindings."""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_block_bindings
()
def
get_thread_extent
(
dim
:
int
=
0
)
->
int
:
"""Returns the thread extent for the given dimension.
"""
"""Returns the thread extent for the given dimension."""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_thread_extent
(
dim
)
def
get_thread_extents
()
->
list
[
int
]:
"""Returns all three thread extents.
"""
"""Returns all three thread extents."""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_thread_extents
()
def
get_block_extent
(
dim
:
int
=
0
)
->
int
:
"""Returns the block extent for the given dimension.
"""
"""Returns the block extent for the given dimension."""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_block_extent
(
dim
)
def
get_block_extents
()
->
list
[
int
]:
"""Returns all three block extents.
"""
"""Returns all three block extents."""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_block_extents
()
tilelang/language/logical.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
from
tilelang
import
language
as
T
...
...
@@ -36,8 +37,7 @@ def any_of(buffer: T.Tensor | BufferRegion):
)
new_region
.
append
(
r
.
min
)
buffer_load
=
BufferLoad
(
buffer
,
new_region
)
return
T
.
call_intrin
(
return_type
,
tir
.
op
.
Op
.
get
(
"tl.any_of"
),
T
.
address_of
(
buffer_load
),
extent
)
return
T
.
call_intrin
(
return_type
,
tir
.
op
.
Op
.
get
(
"tl.any_of"
),
T
.
address_of
(
buffer_load
),
extent
)
else
:
raise
ValueError
(
f
"Invalid buffer type:
{
type
(
buffer
)
}
"
)
...
...
@@ -71,7 +71,6 @@ def all_of(buffer: T.Tensor | BufferRegion):
)
new_region
.
append
(
r
.
min
)
buffer_load
=
BufferLoad
(
buffer
,
new_region
)
return
T
.
call_intrin
(
return_type
,
tir
.
op
.
Op
.
get
(
"tl.all_of"
),
T
.
address_of
(
buffer_load
),
extent
)
return
T
.
call_intrin
(
return_type
,
tir
.
op
.
Op
.
get
(
"tl.all_of"
),
T
.
address_of
(
buffer_load
),
extent
)
else
:
raise
ValueError
(
f
"Invalid buffer type:
{
type
(
buffer
)
}
"
)
tilelang/language/loop.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
from
typing
import
Any
from
tvm
import
tir
...
...
@@ -94,11 +95,9 @@ def Pipelined(
return
_ffi_api
.
Pipelined
(
start
,
stop
,
num_stages
,
order
,
stage
,
sync
,
group
)
def
serial
(
start
:
tir
.
PrimExpr
,
stop
:
tir
.
PrimExpr
|
None
=
None
,
step
:
tir
.
PrimExpr
|
None
=
None
,
*
,
annotations
:
dict
[
str
,
Any
]
|
None
=
None
)
->
frame
.
ForFrame
:
def
serial
(
start
:
tir
.
PrimExpr
,
stop
:
tir
.
PrimExpr
|
None
=
None
,
step
:
tir
.
PrimExpr
|
None
=
None
,
*
,
annotations
:
dict
[
str
,
Any
]
|
None
=
None
)
->
frame
.
ForFrame
:
step_is_one
=
False
step_is_one
|=
isinstance
(
step
,
int
)
and
step
==
1
step_is_one
|=
isinstance
(
step
,
IntImm
)
and
step
.
value
==
1
...
...
@@ -111,13 +110,15 @@ def serial(start: tir.PrimExpr,
return
SerialForWithStep
(
start
,
stop
,
step
,
annotations
=
annotations
)
def
unroll
(
start
:
tir
.
PrimExpr
,
stop
:
tir
.
PrimExpr
|
None
=
None
,
step
:
tir
.
PrimExpr
|
None
=
None
,
*
,
explicit
:
bool
=
False
,
unroll_factor
:
int
|
None
=
None
,
annotations
:
dict
[
str
,
Any
]
|
None
=
None
)
->
frame
.
ForFrame
:
def
unroll
(
start
:
tir
.
PrimExpr
,
stop
:
tir
.
PrimExpr
|
None
=
None
,
step
:
tir
.
PrimExpr
|
None
=
None
,
*
,
explicit
:
bool
=
False
,
unroll_factor
:
int
|
None
=
None
,
annotations
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
frame
.
ForFrame
:
"""The unrolled For statement.
Parameters
...
...
tilelang/language/math_intrinsics.py
View file @
29051439
...
...
@@ -3,7 +3,7 @@ from tvm import tir
def
_validate_rounding_mode
(
rounding_mode
):
"""Validate that the rounding mode is one of the supported IEEE modes"""
valid_modes
=
{
'
rn
'
,
'
rz
'
,
'
ru
'
,
'
rd
'
}
valid_modes
=
{
"
rn
"
,
"
rz
"
,
"
ru
"
,
"
rd
"
}
if
isinstance
(
rounding_mode
,
str
)
and
rounding_mode
in
valid_modes
:
return
raise
ValueError
(
f
"Invalid rounding mode '
{
rounding_mode
}
'. Must be one of:
{
valid_modes
}
"
)
...
...
tilelang/language/overrides/parser.py
View file @
29051439
"""TVMScript parser overrides tailored for TileLang."""
from
functools
import
partial
from
tvm.script.ir_builder
import
tir
as
T
...
...
@@ -58,8 +59,12 @@ def tilelang_visit_assign(self, node: doc.Assign) -> None: # pylint: disable=un
lhs
.
ctx
=
load_ctx
lhs_value
=
self
.
eval_expr
(
lhs
)
lhs
.
ctx
=
store_ctx
if
(
isinstance
(
lhs_value
,
BufferLoad
)
and
lhs_value
.
buffer
.
scope
()
==
"local.var"
and
len
(
lhs_value
.
indices
)
==
1
and
lhs_value
.
indices
[
0
]
==
0
):
if
(
isinstance
(
lhs_value
,
BufferLoad
)
and
lhs_value
.
buffer
.
scope
()
==
"local.var"
and
len
(
lhs_value
.
indices
)
==
1
and
lhs_value
.
indices
[
0
]
==
0
):
T
.
buffer_store
(
lhs_value
.
buffer
,
rhs
,
indices
=
[
0
])
continue
...
...
@@ -106,8 +111,12 @@ def tilelang_visit_aug_assign(self, node: doc.AugAssign) -> None: # pylint: dis
lhs
.
ctx
=
load_ctx
lhs_value
=
self
.
eval_expr
(
lhs
)
lhs
.
ctx
=
store_ctx
if
(
isinstance
(
lhs_value
,
BufferLoad
)
and
lhs_value
.
buffer
.
scope
()
==
"local.var"
and
len
(
lhs_value
.
indices
)
==
1
and
lhs_value
.
indices
[
0
]
==
0
):
if
(
isinstance
(
lhs_value
,
BufferLoad
)
and
lhs_value
.
buffer
.
scope
()
==
"local.var"
and
len
(
lhs_value
.
indices
)
==
1
and
lhs_value
.
indices
[
0
]
==
0
):
T
.
buffer_store
(
lhs_value
.
buffer
,
rhs
,
indices
=
[
0
])
return
...
...
@@ -131,8 +140,12 @@ def tilelang_visit_ann_assign(self, node: doc.AnnAssign) -> None: # pylint: dis
lhs
.
ctx
=
load_ctx
lhs_value
=
self
.
eval_expr
(
lhs
)
lhs
.
ctx
=
store_ctx
if
(
isinstance
(
lhs_value
,
BufferLoad
)
and
lhs_value
.
buffer
.
scope
()
==
"local.var"
and
len
(
lhs_value
.
indices
)
==
1
and
lhs_value
.
indices
[
0
]
==
0
):
if
(
isinstance
(
lhs_value
,
BufferLoad
)
and
lhs_value
.
buffer
.
scope
()
==
"local.var"
and
len
(
lhs_value
.
indices
)
==
1
and
lhs_value
.
indices
[
0
]
==
0
):
T
.
buffer_store
(
lhs_value
.
buffer
,
rhs
,
indices
=
[
0
])
return
...
...
tilelang/language/parser/entry.py
View file @
29051439
...
...
@@ -18,6 +18,7 @@
# which is part of the TVM project (https://tvm.apache.org/).
# ruff: noqa
"""The entry point of TVM parser for tir."""
import
inspect
from
typing
import
Callable
,
Optional
,
Union
...
...
@@ -29,9 +30,7 @@ from tvm.script.parser._core import parse, scan_macro, utils
from
tvm.script.parser.core.parser
import
Parser
,
ScriptMacro
def
prim_func
(
func
:
Optional
[
Callable
]
=
None
,
private
:
bool
=
False
,
check_well_formed
=
True
)
->
Union
[
PrimFunc
,
Callable
]:
def
prim_func
(
func
:
Optional
[
Callable
]
=
None
,
private
:
bool
=
False
,
check_well_formed
=
True
)
->
Union
[
PrimFunc
,
Callable
]:
"""The parsing method for tir prim func, by using `@prim_func` as decorator.
Parameters
...
...
@@ -149,8 +148,7 @@ def macro(*args, hygienic: bool = True) -> Callable:
if
len
(
args
)
==
1
and
inspect
.
isfunction
(
args
[
0
]):
return
_decorator
(
args
[
0
])
raise
ValueError
(
"Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])"
)
raise
ValueError
(
"Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])"
)
class
BufferProxy
:
...
...
tilelang/language/parser/operation.py
View file @
29051439
...
...
@@ -17,6 +17,7 @@
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
"""The tir expression operation registration"""
from
tvm
import
tir
from
tvm.ffi.runtime_ctypes
import
DataType
,
DataTypeCode
from
tvm.tir
import
IntImm
...
...
@@ -55,11 +56,9 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name
return
dtype
[
0
:
index
]
def
_auto_broadcast
(
a
,
b
,
op
):
if
isinstance
(
a
,
int
):
if
hasattr
(
b
,
"dtype"
):
if
(
DataType
(
b
.
dtype
).
type_code
==
DataTypeCode
.
INT
or
DataType
(
b
.
dtype
).
type_code
==
DataTypeCode
.
UINT
):
if
DataType
(
b
.
dtype
).
type_code
==
DataTypeCode
.
INT
or
DataType
(
b
.
dtype
).
type_code
==
DataTypeCode
.
UINT
:
a
=
IntImm
(
_get_type_str
(
b
.
dtype
),
a
)
elif
DataType
(
b
.
dtype
).
type_code
==
DataTypeCode
.
FLOAT
:
a
=
FloatImm
(
_get_type_str
(
b
.
dtype
),
a
)
...
...
@@ -75,8 +74,7 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name
assert
isinstance
(
a
,
tir
.
PrimExpr
),
"Operand should be a PrimExpr."
if
isinstance
(
b
,
int
):
if
(
DataType
(
a
.
dtype
).
type_code
==
DataTypeCode
.
INT
or
DataType
(
a
.
dtype
).
type_code
==
DataTypeCode
.
UINT
):
if
DataType
(
a
.
dtype
).
type_code
==
DataTypeCode
.
INT
or
DataType
(
a
.
dtype
).
type_code
==
DataTypeCode
.
UINT
:
b
=
IntImm
(
_get_type_str
(
a
.
dtype
),
b
)
elif
DataType
(
a
.
dtype
).
type_code
==
DataTypeCode
.
FLOAT
:
b
=
FloatImm
(
_get_type_str
(
a
.
dtype
),
b
)
...
...
@@ -85,10 +83,10 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name
if
DataType
(
a
.
dtype
).
lanes
==
DataType
(
b
.
dtype
).
lanes
:
return
op
(
a
,
b
)
elif
(
DataType
(
a
.
dtype
).
lanes
==
1
and
DataType
(
a
.
dtype
).
lanes
!=
DataType
(
b
.
dtype
).
lanes
)
:
elif
DataType
(
a
.
dtype
).
lanes
==
1
and
DataType
(
a
.
dtype
).
lanes
!=
DataType
(
b
.
dtype
).
lanes
:
broadcast_a
=
tir
.
Broadcast
(
a
,
DataType
(
b
.
dtype
).
lanes
)
return
op
(
broadcast_a
,
b
)
elif
(
DataType
(
b
.
dtype
).
lanes
==
1
and
DataType
(
a
.
dtype
).
lanes
!=
DataType
(
b
.
dtype
).
lanes
)
:
elif
DataType
(
b
.
dtype
).
lanes
==
1
and
DataType
(
a
.
dtype
).
lanes
!=
DataType
(
b
.
dtype
).
lanes
:
broadcast_b
=
tir
.
Broadcast
(
b
,
DataType
(
a
.
dtype
).
lanes
)
return
op
(
a
,
broadcast_b
)
else
:
...
...
tilelang/language/parser/parser.py
View file @
29051439
...
...
@@ -146,8 +146,7 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -
res
=
value
.
__enter__
()
IRBuilder
.
name
(
var_name
,
res
)
return
res
elif
isinstance
(
value
,
(
Buffer
,
IterVar
))
or
(
isinstance
(
value
,
Var
)
and
not
self
.
var_table
.
exist
(
value
)):
elif
isinstance
(
value
,
(
Buffer
,
IterVar
))
or
(
isinstance
(
value
,
Var
)
and
not
self
.
var_table
.
exist
(
value
)):
IRBuilder
.
name
(
var_name
,
value
)
return
value
else
:
...
...
@@ -191,8 +190,7 @@ def visit_for(self: Parser, node: doc.For) -> None:
if
not
isinstance
(
for_frame
,
T
.
frame
.
ForFrame
):
self
.
report_error
(
node
.
iter
,
"Expect the for loop to be one of the following: "
"range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding"
,
"Expect the for loop to be one of the following: range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding"
,
)
with
self
.
var_table
.
with_frame
():
with
for_frame
as
iters
:
...
...
@@ -361,8 +359,7 @@ def visit_with(self: Parser, node: doc.With) -> None:
for
item
in
node
.
items
:
frame
=
self
.
eval_expr
(
item
.
context_expr
)
if
not
isinstance
(
frame
,
Frame
):
self
.
report_error
(
item
.
context_expr
,
"Invalid context expression in the with-statement."
)
self
.
report_error
(
item
.
context_expr
,
"Invalid context expression in the with-statement."
)
rhs
=
stack
.
enter_context
(
frame
)
if
item
.
optional_vars
is
not
None
:
self
.
eval_assign
(
target
=
item
.
optional_vars
,
source
=
rhs
,
bind_value
=
bind_with_value
)
...
...
@@ -505,8 +502,7 @@ def visit_if(self: Parser, node: doc.If) -> None:
with
self
.
var_table
.
with_frame
():
self
.
visit_body
(
node
.
orelse
)
else
:
self
.
report_error
(
node
.
test
,
f
"If condition must be a boolean expression, but got
{
predicate
}
"
)
self
.
report_error
(
node
.
test
,
f
"If condition must be a boolean expression, but got
{
predicate
}
"
)
@
dispatch
.
register
(
token
=
"tir"
,
type_name
=
"Assert"
)
...
...
tilelang/language/print.py
View file @
29051439
...
...
@@ -26,9 +26,7 @@ def print_var(var: tir.PrimExpr, msg: str = "") -> tir.PrimExpr:
@
macro
def
print_var_with_condition
(
condition
:
tir
.
PrimExpr
,
var
:
tir
.
PrimExpr
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
def
print_var_with_condition
(
condition
:
tir
.
PrimExpr
,
var
:
tir
.
PrimExpr
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
"""
Conditionally prints a TIR primitive expression (PrimExpr) if a given condition is True.
...
...
@@ -44,10 +42,7 @@ def print_var_with_condition(condition: tir.PrimExpr,
@
macro
def
print_global_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
def
print_global_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
"""
...
...
@@ -55,17 +50,13 @@ def print_global_buffer_with_condition(condition: tir.PrimExpr,
# Iterate through the buffer elements and print each one.
for
i
in
serial
(
elems
):
coords
=
index_to_coordinates
(
i
,
buffer
.
shape
)
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
buffer
[
coords
])
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
buffer
[
coords
])
else
:
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
buffer
[
coords
])
@
macro
def
print_shared_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
def
print_shared_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
...
...
@@ -81,15 +72,11 @@ def print_shared_buffer_with_condition(condition: tir.PrimExpr,
# Iterate through the buffer elements and print each one.
for
i
in
serial
(
elems
):
coords
=
index_to_coordinates
(
i
,
buffer
.
shape
)
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
buffer
[
coords
])
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
buffer
[
coords
])
@
macro
def
print_fragment_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
def
print_fragment_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
...
...
@@ -111,10 +98,7 @@ def print_fragment_buffer_with_condition(condition: tir.PrimExpr,
@
macro
def
print_local_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
def
print_local_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
...
...
@@ -130,8 +114,7 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr,
# Iterate through the buffer elements and print each one.
for
i
in
serial
(
elems
):
coords
=
index_to_coordinates
(
i
,
buffer
.
shape
)
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
buffer
[
coords
])
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
buffer
[
coords
])
from
tilelang.utils.target
import
check_cuda_availability
...
...
@@ -201,7 +184,7 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
elems
*=
dim
# Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
condition
=
(
tx
==
main_lane
and
ty
==
0
and
tz
==
0
)
condition
=
tx
==
main_lane
and
ty
==
0
and
tz
==
0
if
not
msg
:
msg
=
f
"buffer<
{
buffer
.
name
}
,
{
buffer
.
dtype
}
>"
return
print_fragment_buffer_with_condition
(
condition
,
buffer
,
elems
,
msg
)
...
...
@@ -212,7 +195,7 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
elems
*=
dim
# Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
condition
=
(
tx
==
main_lane
and
ty
==
0
and
tz
==
0
)
condition
=
tx
==
main_lane
and
ty
==
0
and
tz
==
0
if
not
msg
:
msg
=
f
"buffer<
{
buffer
.
name
}
,
{
buffer
.
dtype
}
>"
return
print_shared_buffer_with_condition
(
condition
,
buffer
,
elems
,
msg
)
...
...
@@ -234,5 +217,4 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
else
:
# Unsupported object type.
raise
ValueError
(
f
"Unexpected type:
{
type
(
obj
)
}
. Supported types are tir.Buffer and tir.PrimExpr."
)
raise
ValueError
(
f
"Unexpected type:
{
type
(
obj
)
}
. Supported types are tir.Buffer and tir.PrimExpr."
)
tilelang/language/proxy.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
from
typing
import
Any
,
SupportsIndex
,
TYPE_CHECKING
,
Generic
,
TypeVar
...
...
@@ -51,11 +52,9 @@ class BufferProxy:
return
self
(
keys
)
return
self
(
*
keys
)
# type: ignore[attr-defined] # pylint: disable=no-member
def
from_ptr
(
self
,
pointer_var
:
Var
,
shape
:
tuple
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
Buffer
:
def
from_ptr
(
self
,
pointer_var
:
Var
,
shape
:
tuple
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
Buffer
:
"""Create a buffer from a pointer, shape, and data type.
Args:
...
...
@@ -76,6 +75,7 @@ class BaseTensorProxy:
customizable default values for scope, alignment, and offset factors. It implements
the core functionality for creating TIR buffers with specific memory configurations.
"""
default_scope
=
"global"
default_align
=
0
default_offset_factor
=
0
...
...
@@ -118,11 +118,9 @@ class BaseTensorProxy:
keys
=
(
keys
,)
return
self
(
*
keys
)
def
from_ptr
(
self
,
pointer_var
:
Var
,
shape
:
tuple
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
tir
.
Buffer
:
def
from_ptr
(
self
,
pointer_var
:
Var
,
shape
:
tuple
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
tir
.
Buffer
:
"""Create a buffer from a pointer, shape, and data type.
Args:
...
...
@@ -151,19 +149,10 @@ class TensorProxy(BaseTensorProxy):
strides
.
append
(
s
)
return
tuple
(
reversed
(
strides
))
def
__call__
(
self
,
shape
:
tuple
[
Any
]
|
PrimExpr
|
int
,
dtype
:
str
=
"float32"
,
data
=
None
,
scope
=
None
)
->
tir
.
Buffer
:
def
__call__
(
self
,
shape
:
tuple
[
Any
]
|
PrimExpr
|
int
,
dtype
:
str
=
"float32"
,
data
=
None
,
scope
=
None
)
->
tir
.
Buffer
:
if
isinstance
(
shape
,
(
int
,
PrimExpr
)):
shape
=
(
shape
,)
return
super
().
__call__
(
shape
,
dtype
=
dtype
,
strides
=
TensorProxy
.
_construct_strides
(
shape
),
data
=
data
,
scope
=
scope
)
return
super
().
__call__
(
shape
,
dtype
=
dtype
,
strides
=
TensorProxy
.
_construct_strides
(
shape
),
data
=
data
,
scope
=
scope
)
class
StridedTensorProxy
(
BaseTensorProxy
):
...
...
@@ -172,11 +161,7 @@ class StridedTensorProxy(BaseTensorProxy):
This class implements the default tensor proxy with global memory scope, with the stride information required.
"""
def
__call__
(
self
,
shape
:
tuple
[
Any
],
strides
:
tuple
[
Any
],
dtype
:
str
=
"float32"
,
scope
=
None
)
->
tir
.
Buffer
:
def
__call__
(
self
,
shape
:
tuple
[
Any
],
strides
:
tuple
[
Any
],
dtype
:
str
=
"float32"
,
scope
=
None
)
->
tir
.
Buffer
:
if
len
(
shape
)
!=
len
(
strides
):
raise
ValueError
(
"Invalid shape/strides' dimensions"
)
return
super
().
__call__
(
shape
,
dtype
=
dtype
,
strides
=
strides
,
scope
=
scope
)
...
...
@@ -188,6 +173,7 @@ class FragmentBufferProxy(BaseTensorProxy):
This class represents tensor proxies specifically for local fragment memory,
typically used in GPU tensor core operations.
"""
default_scope
=
"local.fragment"
...
...
@@ -197,6 +183,7 @@ class SharedBufferProxy(BaseTensorProxy):
This class represents tensor proxies for dynamic shared memory,
commonly used in GPU shared memory operations.
"""
default_scope
=
"shared.dyn"
...
...
@@ -206,6 +193,7 @@ class LocalBufferProxy(BaseTensorProxy):
This class represents tensor proxies for local memory scope,
typically used for temporary computations in GPU kernels.
"""
default_scope
=
"local"
...
...
@@ -216,15 +204,12 @@ Buffer = BufferProxy() # pylint: disable=invalid-name
if
TYPE_CHECKING
:
class
BaseTensor
:
def
__class_getitem__
(
cls
,
key
):
return
cls
def
__getitem__
(
self
,
key
)
->
Any
:
...
def
__getitem__
(
self
,
key
)
->
Any
:
...
def
__setitem__
(
self
,
key
,
value
)
->
None
:
...
def
__setitem__
(
self
,
key
,
value
)
->
None
:
...
def
__init__
(
self
,
...
...
@@ -238,36 +223,26 @@ if TYPE_CHECKING:
offset_factor
=
None
,
buffer_type
=
""
,
axis_separators
=
None
,
):
...
):
...
@
classmethod
def
from_ptr
(
cls
,
pointer_var
:
Var
,
shape
:
Sequence
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
Self
:
...
def
from_ptr
(
cls
,
pointer_var
:
Var
,
shape
:
Sequence
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
Self
:
...
class
Tensor
(
BaseTensor
):
...
class
Tensor
(
BaseTensor
):
...
class
StridedTensor
(
BaseTensor
):
...
class
StridedTensor
(
BaseTensor
):
...
class
FragmentBuffer
(
BaseTensor
):
...
class
FragmentBuffer
(
BaseTensor
):
...
class
SharedBuffer
(
BaseTensor
):
...
class
SharedBuffer
(
BaseTensor
):
...
class
LocalBuffer
(
BaseTensor
):
...
class
LocalBuffer
(
BaseTensor
):
...
_T
=
TypeVar
(
'
_T
'
)
_T
=
TypeVar
(
"
_T
"
)
class
Ref
(
Generic
[
_T
],
tir
.
Var
):
...
class
Ref
(
Generic
[
_T
],
tir
.
Var
):
...
else
:
Tensor
=
TensorProxy
()
# pylint: disable=invalid-name
StridedTensor
=
StridedTensorProxy
()
# pylint: disable=invalid-name
...
...
@@ -275,14 +250,10 @@ else:
SharedBuffer
=
SharedBufferProxy
()
# pylint: disable=invalid-name
LocalBuffer
=
LocalBufferProxy
()
# pylint: disable=invalid-name
class
Ref
:
...
class
Ref
:
...
def
ptr
(
dtype
:
str
|
None
=
None
,
storage_scope
:
str
=
"global"
,
*
,
is_size_var
:
bool
=
False
)
->
Var
:
def
ptr
(
dtype
:
str
|
None
=
None
,
storage_scope
:
str
=
"global"
,
*
,
is_size_var
:
bool
=
False
)
->
Var
:
"""Create a TIR var that represents a pointer.
Parameters
...
...
@@ -304,8 +275,5 @@ def ptr(dtype: str | None = None,
return
handle
(
dtype
=
dtype
,
storage_scope
=
storage_scope
,
is_size_var
=
is_size_var
)
def
make_tensor
(
ptr
:
Var
,
shape
:
tuple
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
tir
.
Buffer
:
def
make_tensor
(
ptr
:
Var
,
shape
:
tuple
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
tir
.
Buffer
:
return
Tensor
.
from_ptr
(
ptr
,
shape
,
dtype
,
strides
)
Prev
1
…
17
18
19
20
21
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment