Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
281 additions
and
328 deletions
+281
-328
tilelang/language/ast/_ffi_api.py
tilelang/language/ast/_ffi_api.py
+1
-0
tilelang/language/ast/ir.py
tilelang/language/ast/ir.py
+39
-48
tilelang/language/atomic.py
tilelang/language/atomic.py
+7
-19
tilelang/language/builtin.py
tilelang/language/builtin.py
+48
-62
tilelang/language/copy.py
tilelang/language/copy.py
+36
-22
tilelang/language/customize.py
tilelang/language/customize.py
+6
-6
tilelang/language/experimental/gemm_sp.py
tilelang/language/experimental/gemm_sp.py
+8
-5
tilelang/language/fill.py
tilelang/language/fill.py
+3
-4
tilelang/language/frame.py
tilelang/language/frame.py
+4
-4
tilelang/language/gemm.py
tilelang/language/gemm.py
+30
-6
tilelang/language/kernel.py
tilelang/language/kernel.py
+10
-18
tilelang/language/logical.py
tilelang/language/logical.py
+3
-4
tilelang/language/loop.py
tilelang/language/loop.py
+13
-12
tilelang/language/math_intrinsics.py
tilelang/language/math_intrinsics.py
+1
-1
tilelang/language/overrides/parser.py
tilelang/language/overrides/parser.py
+19
-6
tilelang/language/parser/entry.py
tilelang/language/parser/entry.py
+3
-5
tilelang/language/parser/operation.py
tilelang/language/parser/operation.py
+5
-7
tilelang/language/parser/parser.py
tilelang/language/parser/parser.py
+4
-8
tilelang/language/print.py
tilelang/language/print.py
+11
-29
tilelang/language/proxy.py
tilelang/language/proxy.py
+30
-62
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
tilelang/language/ast/_ffi_api.py
View file @
29051439
...
...
@@ -17,6 +17,7 @@
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
"""FFI APIs"""
import
tvm.ffi
tvm
.
ffi
.
_init_api
(
"script.ir_builder.tir"
,
__name__
)
# pylint: disable=protected-access
tilelang/language/ast/ir.py
View file @
29051439
...
...
@@ -558,7 +558,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable.
"""
return
_ffi_api
.
AxisSpatial
(
# type: ignore[attr-defined] # pylint: disable=no-member
_as_range
(
dom
),
binding
,
dtype
)
_as_range
(
dom
),
binding
,
dtype
)
@
staticmethod
def
reduce
(
...
...
@@ -585,7 +586,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable.
"""
return
_ffi_api
.
AxisReduce
(
# type: ignore[attr-defined] # pylint: disable=no-member
_as_range
(
dom
),
binding
,
dtype
)
_as_range
(
dom
),
binding
,
dtype
)
@
staticmethod
def
scan
(
...
...
@@ -612,7 +614,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable.
"""
return
_ffi_api
.
AxisScan
(
# type: ignore[attr-defined] # pylint: disable=no-member
_as_range
(
dom
),
binding
,
dtype
)
_as_range
(
dom
),
binding
,
dtype
)
@
staticmethod
def
opaque
(
...
...
@@ -639,7 +642,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable.
"""
return
_ffi_api
.
AxisOpaque
(
# type: ignore[attr-defined] # pylint: disable=no-member
_as_range
(
dom
),
binding
,
dtype
)
_as_range
(
dom
),
binding
,
dtype
)
@
staticmethod
def
remap
(
kinds
:
str
,
bindings
:
List
[
PrimExpr
],
dtype
:
str
=
"int32"
)
->
Union
[
List
[
Var
],
Var
]:
...
...
@@ -662,17 +666,15 @@ class axis: # pylint: disable=invalid-name
The iteration variables.
"""
iter_vars
=
_ffi_api
.
AxisRemap
(
# type: ignore[attr-defined] # pylint: disable=no-member
kinds
,
bindings
,
dtype
)
kinds
,
bindings
,
dtype
)
return
iter_vars
[
0
]
if
len
(
iter_vars
)
==
1
else
iter_vars
S
=
spatial
# pylint: disable=invalid-name
R
=
reduce
# pylint: disable=invalid-name
def
serial
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
def
serial
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
"""The serial For statement.
Parameters
...
...
@@ -700,10 +702,7 @@ def serial(start: PrimExpr,
return
_ffi_api
.
Serial
(
start
,
stop
,
annotations
)
# type: ignore[attr-defined] # pylint: disable=no-member
def
parallel
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
def
parallel
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
"""The parallel For statement.
Parameters
...
...
@@ -731,10 +730,7 @@ def parallel(start: PrimExpr,
return
_ffi_api
.
Parallel
(
start
,
stop
,
annotations
)
# type: ignore[attr-defined] # pylint: disable=no-member
def
vectorized
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
def
vectorized
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
"""The vectorized For statement.
Parameters
...
...
@@ -762,10 +758,7 @@ def vectorized(start: PrimExpr,
return
_ffi_api
.
Vectorized
(
start
,
stop
,
annotations
)
# type: ignore[attr-defined] # pylint: disable=no-member
def
unroll
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
def
unroll
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
"""The unrolled For statement.
Parameters
...
...
@@ -837,7 +830,8 @@ def thread_binding(
else
:
start
=
0
return
_ffi_api
.
ThreadBinding
(
# type: ignore[attr-defined] # pylint: disable=no-member
start
,
stop
,
thread
,
annotations
)
start
,
stop
,
thread
,
annotations
)
def
grid
(
*
extents
:
PrimExpr
)
->
frame
.
ForFrame
:
...
...
@@ -878,10 +872,10 @@ def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: d
def
LetStmt
(
# pylint: disable=invalid-name
value
:
PrimExpr
,
type_annotation
:
Optional
[
Type
]
=
None
,
# pylint: disable=redefined-outer-name
*
,
var
:
Optional
[
Var
]
=
None
,
# pylint: disable=redefined-outer-name
value
:
PrimExpr
,
type_annotation
:
Optional
[
Type
]
=
None
,
# pylint: disable=redefined-outer-name
*
,
var
:
Optional
[
Var
]
=
None
,
# pylint: disable=redefined-outer-name
)
->
frame
.
LetFrame
:
"""Create a LetStmt binding
...
...
@@ -909,8 +903,8 @@ def LetStmt( # pylint: disable=invalid-name
def
Let
(
# pylint: disable=invalid-name
expr
:
PrimExpr
,
where
:
Dict
[
Var
,
PrimExpr
],
# pylint: disable=redefined-outer-name
expr
:
PrimExpr
,
where
:
Dict
[
Var
,
PrimExpr
],
# pylint: disable=redefined-outer-name
)
->
PrimExpr
:
"""Create a Let expression binding"""
assert
len
(
where
)
==
1
,
"T.Let only allows `where` to have exactly one element"
...
...
@@ -980,7 +974,8 @@ def realize(
The result RealizeFrame.
"""
return
_ffi_api
.
Realize
(
# type: ignore[attr-defined] # pylint: disable=no-member
buffer_slice
,
storage_scope
,
condition
)
buffer_slice
,
storage_scope
,
condition
)
def
allocate
(
...
...
@@ -1012,7 +1007,8 @@ def allocate(
if
isinstance
(
condition
,
bool
):
condition
=
IntImm
(
"bool"
,
condition
)
return
_ffi_api
.
Allocate
(
# type: ignore[attr-defined] # pylint: disable=no-member
extents
,
dtype
,
scope
,
condition
,
annotations
)
extents
,
dtype
,
scope
,
condition
,
annotations
)
def
allocate_const
(
...
...
@@ -1048,7 +1044,8 @@ def allocate_const(
np_data
=
np_data
.
reshape
(
extents
)
return
_ffi_api
.
AllocateConst
(
# type: ignore[attr-defined] # pylint: disable=no-member
ndarray
.
array
(
np_data
),
dtype
,
extents
,
annotations
)
ndarray
.
array
(
np_data
),
dtype
,
extents
,
annotations
)
def
attr
(
node
:
Any
,
attr_key
:
str
,
value
:
Union
[
PrimExpr
,
str
])
->
frame
.
AttrFrame
:
...
...
@@ -1297,7 +1294,8 @@ def buffer_store(
if
isinstance
(
value
,
bool
)
and
buffer
.
dtype
==
"bool"
:
value
=
IntImm
(
"bool"
,
value
)
return
_ffi_api
.
BufferStore
(
# type: ignore[attr-defined] # pylint: disable=no-member
buffer
,
value
,
expr_indices
)
buffer
,
value
,
expr_indices
)
def
prefetch
(
...
...
@@ -1464,10 +1462,7 @@ def boolean(expr: Optional[PrimExpr] = None, is_size_var: bool = False) -> PrimE
return
_ffi_api
.
Boolean
(
expr
,
is_size_var
)
# type: ignore[attr-defined] # pylint: disable=no-member
def
handle
(
dtype
:
Optional
[
str
]
=
None
,
storage_scope
:
str
=
"global"
,
*
,
is_size_var
:
bool
=
False
)
->
Var
:
def
handle
(
dtype
:
Optional
[
str
]
=
None
,
storage_scope
:
str
=
"global"
,
*
,
is_size_var
:
bool
=
False
)
->
Var
:
"""Create a TIR var that represents a pointer.
Parameters
...
...
@@ -1667,7 +1662,7 @@ def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer:
res
=
combiner
(
*
args
)
if
not
isinstance
(
res
,
tuple
):
res
=
(
res
,)
return
CommReducer
(
args
[:
num_args
//
2
],
args
[
num_args
//
2
:],
res
,
identity
)
return
CommReducer
(
args
[:
num_args
//
2
],
args
[
num_args
//
2
:],
res
,
identity
)
def
index_map
(
...
...
@@ -1700,16 +1695,15 @@ def target(
The target.
"""
if
not
isinstance
(
target_config
,
(
str
,
dict
)):
raise
ValueError
(
f
"T.target expected a config dict or string, but got
{
type
(
target_config
)
}
"
)
raise
ValueError
(
f
"T.target expected a config dict or string, but got
{
type
(
target_config
)
}
"
)
if
host
is
not
None
and
not
isinstance
(
host
,
(
str
,
dict
,
Target
)):
raise
ValueError
(
"T.target expected the host to be "
"a config dict, string, or T.target, "
f
"but got
{
type
(
host
)
}
"
)
raise
ValueError
(
f
"T.target expected the host to be a config dict, string, or T.target, but got
{
type
(
host
)
}
"
)
if
isinstance
(
target_config
,
dict
)
and
"host"
in
target_config
and
host
is
not
None
:
raise
ValueError
(
"T.target expects to either receive the host "
"as part of the target's config dictionary, "
"or as a separate argument, but not both."
)
raise
ValueError
(
"T.target expects to either receive the host "
"as part of the target's config dictionary, "
"or as a separate argument, but not both."
)
return
Target
(
target_config
,
host
)
...
...
@@ -1742,7 +1736,6 @@ class meta_var: # pylint: disable=invalid-name
self
.
value
=
value
def
__iter__
(
self
):
def
f
():
for
i
in
self
.
value
:
yield
meta_var
(
i
)
...
...
@@ -1754,7 +1747,6 @@ class meta_var: # pylint: disable=invalid-name
def
_op_wrapper
(
func
):
@
functools
.
wraps
(
func
)
def
wrapped
(
*
args
,
**
kwargs
):
if
"dtype"
in
kwargs
:
...
...
@@ -1874,7 +1866,6 @@ vscale = _op_wrapper(_tir_op.vscale)
def
_dtype_forward
(
func
):
@
functools
.
wraps
(
func
)
def
wrapped
(
*
args
,
**
kwargs
):
if
"dtype"
in
kwargs
:
...
...
tilelang/language/atomic.py
View file @
29051439
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
"""Atomic operations for tilelang."""
from
__future__
import
annotations
import
tilelang.language
as
T
...
...
@@ -18,10 +19,7 @@ _MEMORY_ORDER_ID_MAP = {
}
def
atomic_max
(
dst
:
Buffer
,
value
:
PrimExpr
,
memory_order
:
str
|
None
=
None
,
return_prev
:
bool
=
False
)
->
PrimExpr
:
def
atomic_max
(
dst
:
Buffer
,
value
:
PrimExpr
,
memory_order
:
str
|
None
=
None
,
return_prev
:
bool
=
False
)
->
PrimExpr
:
"""
Perform an atomic maximum on the value stored at dst with an optional memory-order.
...
...
@@ -64,10 +62,7 @@ def atomic_max(dst: Buffer,
return
T
.
call_extern
(
return_type
,
func_name
,
dst
,
value
,
_MEMORY_ORDER_ID_MAP
[
memory_order
])
def
atomic_min
(
dst
:
Buffer
,
value
:
PrimExpr
,
memory_order
:
str
|
None
=
None
,
return_prev
:
bool
=
False
)
->
PrimExpr
:
def
atomic_min
(
dst
:
Buffer
,
value
:
PrimExpr
,
memory_order
:
str
|
None
=
None
,
return_prev
:
bool
=
False
)
->
PrimExpr
:
"""
Atomically update the value at dst to the minimum of its current value and value.
...
...
@@ -112,11 +107,7 @@ def atomic_min(dst: Buffer,
return
T
.
call_extern
(
return_type
,
func_name
,
dst
,
value
,
_MEMORY_ORDER_ID_MAP
[
memory_order
])
def
atomic_add
(
dst
:
Buffer
,
value
:
PrimExpr
,
memory_order
:
str
|
None
=
None
,
return_prev
:
bool
=
False
,
use_tma
:
bool
=
False
)
->
PrimExpr
:
def
atomic_add
(
dst
:
Buffer
,
value
:
PrimExpr
,
memory_order
:
str
|
None
=
None
,
return_prev
:
bool
=
False
,
use_tma
:
bool
=
False
)
->
PrimExpr
:
"""
Atomically add `value` into `dst`, returning a handle to the operation.
...
...
@@ -191,8 +182,7 @@ def atomic_add(dst: Buffer,
if
memory_order
is
None
:
return
T
.
call_extern
(
return_type
,
func_name
,
dst
,
value
)
else
:
return
T
.
call_extern
(
return_type
,
func_name
,
dst
,
value
,
_MEMORY_ORDER_ID_MAP
[
memory_order
])
return
T
.
call_extern
(
return_type
,
func_name
,
dst
,
value
,
_MEMORY_ORDER_ID_MAP
[
memory_order
])
if
isinstance
(
dst
,
Buffer
)
and
isinstance
(
value
,
Buffer
):
ir
.
assert_structural_equal
(
dst
.
shape
,
value
.
shape
)
...
...
@@ -208,14 +198,12 @@ def atomic_add(dst: Buffer,
# Note: tile-region-based atomic operations don't support return_prev yet
# This would need to be implemented in the tile runtime
if
return_prev
:
raise
NotImplementedError
(
"return_prev is not supported for tile-region-based atomic operations"
)
raise
NotImplementedError
(
"return_prev is not supported for tile-region-based atomic operations"
)
if
memory_order
is
None
:
return
T
.
call_intrin
(
"handle"
,
op
.
Op
.
get
(
"tl.tileop.atomicadd"
),
value
,
dst
,
use_tma
,
0
)
else
:
return
T
.
call_intrin
(
"handle"
,
op
.
Op
.
get
(
"tl.tileop.atomicadd"
),
value
,
dst
,
use_tma
,
_MEMORY_ORDER_ID_MAP
[
memory_order
])
return
T
.
call_intrin
(
"handle"
,
op
.
Op
.
get
(
"tl.tileop.atomicadd"
),
value
,
dst
,
use_tma
,
_MEMORY_ORDER_ID_MAP
[
memory_order
])
def
atomic_addx2
(
dst
:
Buffer
,
value
:
PrimExpr
,
return_prev
:
bool
=
False
)
->
PrimExpr
:
...
...
tilelang/language/builtin.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
from
tilelang
import
tvm
as
tvm
...
...
@@ -179,38 +180,32 @@ def set_max_nreg(reg_count: int, is_inc: int):
def
inc_max_nreg
(
reg_count
:
int
):
"""Increment the maximum number of registers to use.
"""
"""Increment the maximum number of registers to use."""
return
set_max_nreg
(
reg_count
,
1
)
def
dec_max_nreg
(
reg_count
:
int
):
"""Decrement the maximum number of registers to use.
"""
"""Decrement the maximum number of registers to use."""
return
set_max_nreg
(
reg_count
,
0
)
def
annotate_producer_reg_dealloc
(
reg_count
:
int
=
24
):
"""Annotate the producer reg dealloc.
"""
"""Annotate the producer reg dealloc."""
return
dec_max_nreg
(
reg_count
)
def
annotate_consumer_reg_alloc
(
reg_count
:
int
=
240
):
"""Annotate the consumer reg alloc.
"""
"""Annotate the consumer reg alloc."""
return
inc_max_nreg
(
reg_count
)
def
no_set_max_nreg
():
"""Disable the maximum register limit setting.
"""
"""Disable the maximum register limit setting."""
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.no_set_max_nreg"
))
def
disable_warp_group_reg_alloc
():
"""Disable the warp group reg alloc.
"""
"""Disable the warp group reg alloc."""
return
no_set_max_nreg
()
...
...
@@ -325,7 +320,9 @@ def warpgroup_wait(num_mma: int):
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.warpgroup_wait"
),
num_mma
)
def
get_lane_idx
(
warp_size
:
int
|
PrimExpr
|
None
=
None
,)
->
PrimExpr
:
def
get_lane_idx
(
warp_size
:
int
|
PrimExpr
|
None
=
None
,
)
->
PrimExpr
:
"""Return the logical lane index of the calling thread within a warp.
Parameters
...
...
@@ -350,7 +347,9 @@ def get_lane_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
return
tir
.
call_intrin
(
"int32"
,
tir
.
op
.
Op
.
get
(
"tl.get_lane_idx"
),
warp_size_expr
)
def
get_warp_idx_sync
(
warp_size
:
int
|
PrimExpr
|
None
=
None
,)
->
PrimExpr
:
def
get_warp_idx_sync
(
warp_size
:
int
|
PrimExpr
|
None
=
None
,
)
->
PrimExpr
:
"""Return the canonical warp index, assuming the warp's threads are converged.
Parameters
...
...
@@ -374,7 +373,9 @@ def get_warp_idx_sync(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
return
tir
.
call_intrin
(
"int32"
,
tir
.
op
.
Op
.
get
(
"tl.get_warp_idx_sync"
),
warp_size_expr
)
def
get_warp_idx
(
warp_size
:
int
|
PrimExpr
|
None
=
None
,)
->
PrimExpr
:
def
get_warp_idx
(
warp_size
:
int
|
PrimExpr
|
None
=
None
,
)
->
PrimExpr
:
"""Return the canonical warp index without synchronizing the warp.
Parameters
...
...
@@ -429,8 +430,7 @@ def get_warp_group_idx(
args
.
append
(
warp_size_expr
)
if
warps_per_group_expr
is
not
None
:
if
warp_size_expr
is
None
:
raise
ValueError
(
"get_warp_group_idx expects `warp_size` when specifying "
"`warps_per_group`."
)
raise
ValueError
(
"get_warp_group_idx expects `warp_size` when specifying `warps_per_group`."
)
args
.
append
(
warps_per_group_expr
)
return
tir
.
call_intrin
(
"int32"
,
tir
.
op
.
Op
.
get
(
"tl.get_warp_group_idx"
),
*
args
)
...
...
@@ -459,10 +459,9 @@ def shuffle_elect(thread_extent: int) -> PrimExpr:
return
tir
.
call_intrin
(
"bool"
,
tir
.
op
.
Op
.
get
(
"tl.tl_shuffle_elect"
),
thread_extent
)
def
warpgroup_fence_operand
(
buffer_or_ptr
:
tir
.
Buffer
|
PrimExpr
,
offset
:
int
|
PrimExpr
=
0
,
num_regs
:
int
|
PrimExpr
|
None
=
None
,
dtype
:
str
|
None
=
None
):
def
warpgroup_fence_operand
(
buffer_or_ptr
:
tir
.
Buffer
|
PrimExpr
,
offset
:
int
|
PrimExpr
=
0
,
num_regs
:
int
|
PrimExpr
|
None
=
None
,
dtype
:
str
|
None
=
None
):
"""Insert a warpgroup fence for the destination accumulator registers.
This prevents NVCC from sinking uses of accumulator fragments past the corresponding
...
...
@@ -517,7 +516,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
data_ptr
,
convert
(
offset
),
convert
(
num_regs
),
))
)
)
if
isinstance
(
buffer_or_ptr
,
tir
.
Buffer
):
data_ptr
=
buffer_or_ptr
.
data
...
...
@@ -531,8 +531,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
if
isinstance
(
dim
,
tir
.
IntImm
):
total_elems
*=
int
(
dim
)
else
:
raise
ValueError
(
"warpgroup_fence_operand requires num_regs when buffer shape is symbolic."
)
raise
ValueError
(
"warpgroup_fence_operand requires num_regs when buffer shape is symbolic."
)
bits_per_elem
=
DataType
(
dtype
).
bits
num_regs
=
(
total_elems
*
bits_per_elem
+
31
)
//
32
elif
isinstance
(
buffer_or_ptr
,
BufferRegion
):
...
...
@@ -569,9 +568,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
bits_per_elem
=
DataType
(
dtype
).
bits
num_regs
=
(
total_elems
*
bits_per_elem
+
31
)
//
32
else
:
raise
ValueError
(
"warpgroup_fence_operand requires num_regs when BufferRegion extent is symbolic."
)
raise
ValueError
(
"warpgroup_fence_operand requires num_regs when BufferRegion extent is symbolic."
)
return
evaluate
(
tir
.
call_intrin
(
"handle"
,
...
...
@@ -580,7 +577,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
data_ptr
,
convert
(
offset
),
convert
(
num_regs
),
))
)
)
else
:
data_ptr
=
buffer_or_ptr
# Try to infer dtype from common pointer expressions when not provided
...
...
@@ -603,9 +601,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
except
Exception
:
inferred
=
None
if
inferred
is
None
:
raise
ValueError
(
"dtype must be provided when passing a pointer expression and cannot be inferred."
)
raise
ValueError
(
"dtype must be provided when passing a pointer expression and cannot be inferred."
)
dtype
=
inferred
if
num_regs
is
None
:
raise
ValueError
(
"num_regs must be provided when passing a pointer expression."
)
...
...
@@ -618,7 +614,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
data_ptr
,
convert
(
offset
),
convert
(
num_regs
),
))
)
)
def
wait_wgmma
(
id
:
int
):
...
...
@@ -673,7 +670,7 @@ def shfl_xor(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call
if
_IS_HIP_AVAILABLE
:
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_xor"
,
value
,
offset
)
else
:
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_xor_sync"
,
0x
ffffffff
,
value
,
offset
)
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_xor_sync"
,
0x
FFFFFFFF
,
value
,
offset
)
def
shfl_down
(
value
:
int
|
PrimExpr
|
tir
.
Call
,
offset
:
int
|
PrimExpr
|
tir
.
Call
):
...
...
@@ -686,7 +683,7 @@ def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Cal
if
_IS_HIP_AVAILABLE
:
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_down"
,
value
,
offset
)
else
:
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_down_sync"
,
0x
ffffffff
,
value
,
offset
)
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_down_sync"
,
0x
FFFFFFFF
,
value
,
offset
)
def
shfl_up
(
value
:
int
|
PrimExpr
|
tir
.
Call
,
offset
:
int
|
PrimExpr
|
tir
.
Call
):
...
...
@@ -699,12 +696,11 @@ def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call)
if
_IS_HIP_AVAILABLE
:
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_up"
,
value
,
offset
)
else
:
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_up_sync"
,
0x
ffffffff
,
value
,
offset
)
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_up_sync"
,
0x
FFFFFFFF
,
value
,
offset
)
def
sync_threads
(
barrier_id
:
int
=
None
,
arrive_count
:
int
=
None
):
"""Synchronize all threads in a block.
"""
"""Synchronize all threads in a block."""
args
=
[]
if
barrier_id
is
not
None
:
args
.
append
(
barrier_id
)
...
...
@@ -714,8 +710,7 @@ def sync_threads(barrier_id: int = None, arrive_count: int = None):
def
sync_global
():
"""Synchronize all threads in the entire grid.
"""
"""Synchronize all threads in the entire grid."""
tx
,
ty
,
tz
=
get_thread_bindings
()
ex
,
ey
,
ez
=
get_block_extents
()
print
(
tx
,
ty
,
tz
,
ex
,
ey
,
ez
)
...
...
@@ -724,8 +719,7 @@ def sync_global():
def
sync_grid
():
"""Synchronize all threads in a grid.
"""
"""Synchronize all threads in a grid."""
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.sync_grid"
))
...
...
@@ -741,12 +735,10 @@ def initialize_wgmma_descriptor(
if
not
isinstance
(
descriptor
,
(
BufferLoad
,
tir
.
Buffer
)):
raise
TypeError
(
"Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad."
)
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
(
len
(
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
):
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
(
len
(
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
):
raise
ValueError
(
"Descriptor must be a 1D buffer of size 1."
)
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
,
[
0
])
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
,
[
0
])
return
evaluate
(
tir
.
call_intrin
(
...
...
@@ -757,7 +749,8 @@ def initialize_wgmma_descriptor(
layout_type_
,
int
(
leading_byte_offset
),
int
(
stride_byte_offset
),
))
)
)
def
initialize_tcgen05_descriptor
(
...
...
@@ -774,12 +767,10 @@ def initialize_tcgen05_descriptor(
if
not
isinstance
(
descriptor
,
(
BufferLoad
,
tir
.
Buffer
)):
raise
TypeError
(
"Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad."
)
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
(
len
(
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
):
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
(
len
(
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
):
raise
ValueError
(
"Descriptor must be a 1D buffer of size 1."
)
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
,
[
0
])
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
,
[
0
])
return
evaluate
(
tir
.
call_intrin
(
...
...
@@ -792,7 +783,8 @@ def initialize_tcgen05_descriptor(
int
(
base_offset
),
tir
.
IntImm
(
"int32"
,
1
if
leading_is_absolute
else
0
),
int
(
swizzle_mode
),
))
)
)
def
increase_descriptor_offset
(
descriptor
:
PrimExpr
,
offset
:
PrimExpr
)
->
PrimExpr
:
...
...
@@ -809,27 +801,21 @@ def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimEx
if
not
isinstance
(
descriptor
,
(
BufferLoad
,
tir
.
Buffer
)):
raise
TypeError
(
"Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad."
)
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
len
(
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
:
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
len
(
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
:
raise
ValueError
(
"Descriptor must be a 1D buffer of size 1."
)
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
,
[
0
])
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
,
[
0
])
return
evaluate
(
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.increase_descriptor_offset"
),
descriptor
,
offset
))
return
evaluate
(
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.increase_descriptor_offset"
),
descriptor
,
offset
))
def
loop_break
():
"""Break out of the innermost loop.
"""
"""Break out of the innermost loop."""
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.loop_break"
))
def
cp_async_barrier_noinc
(
barrier_id
:
int
|
PrimExpr
|
tir
.
Call
):
"""Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
"""
"""Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc."""
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.ptx_cp_async_barrier_noinc"
),
barrier_id
)
...
...
tilelang/language/copy.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
from
typing
import
Literal
from
tilelang
import
language
as
T
...
...
@@ -10,11 +11,13 @@ from tilelang.utils.language import (
from
tvm
import
ir
,
tir
def
copy
(
src
:
tir
.
Buffer
|
tir
.
BufferLoad
|
tir
.
BufferRegion
,
dst
:
tir
.
Buffer
|
tir
.
BufferLoad
,
coalesced_width
:
int
|
None
=
None
,
disable_tma
:
bool
=
False
,
eviction_policy
:
Literal
[
"evict_normal"
,
"evict_first"
,
"evict_last"
]
|
None
=
None
):
def
copy
(
src
:
tir
.
Buffer
|
tir
.
BufferLoad
|
tir
.
BufferRegion
,
dst
:
tir
.
Buffer
|
tir
.
BufferLoad
,
coalesced_width
:
int
|
None
=
None
,
disable_tma
:
bool
=
False
,
eviction_policy
:
Literal
[
"evict_normal"
,
"evict_first"
,
"evict_last"
]
|
None
=
None
,
):
"""Copy data between memory regions.
Args:
...
...
@@ -65,8 +68,7 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
src_extent
=
get_extent
(
src
)
dst_extent
=
get_extent
(
dst
)
# Combine the nested if statements into a single if statement as suggested by SIM102
if
(
src_extent
is
None
and
dst_extent
is
None
and
isinstance
(
src
,
tir
.
BufferLoad
)
and
isinstance
(
dst
,
tir
.
BufferLoad
)):
if
src_extent
is
None
and
dst_extent
is
None
and
isinstance
(
src
,
tir
.
BufferLoad
)
and
isinstance
(
dst
,
tir
.
BufferLoad
):
# check if the case is like this:
# copy(buffer_a[i], buffer_b[i]) where both are BufferLoad nodes
# In this case, lower it to a simple BufferStore: buffer_b[i] = buffer_a[i]
...
...
@@ -90,19 +92,20 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
eviction_policy
=
0
else
:
eviction_policy
=
{
"evict_normal"
:
0
,
"evict_first"
:
1
,
"evict_last"
:
2
}[
eviction_policy
]
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tileop.copy"
),
src
,
dst
,
coalesced_width
,
disable_tma
,
eviction_policy
)
def
c2d_im2col
(
img
:
tir
.
Buffer
,
col
:
tir
.
Buffer
,
nhw_step
:
tir
.
PrimExpr
,
c_step
:
tir
.
PrimExpr
,
kernel
:
int
,
stride
:
int
,
dilation
:
int
,
pad
:
int
,
eviction_policy
:
Literal
[
"evict_normal"
,
"evict_first"
,
"evict_last"
]
|
None
=
None
):
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tileop.copy"
),
src
,
dst
,
coalesced_width
,
disable_tma
,
eviction_policy
)
def
c2d_im2col
(
img
:
tir
.
Buffer
,
col
:
tir
.
Buffer
,
nhw_step
:
tir
.
PrimExpr
,
c_step
:
tir
.
PrimExpr
,
kernel
:
int
,
stride
:
int
,
dilation
:
int
,
pad
:
int
,
eviction_policy
:
Literal
[
"evict_normal"
,
"evict_first"
,
"evict_last"
]
|
None
=
None
,
):
"""Perform im2col transformation for 2D convolution.
Args:
...
...
@@ -124,5 +127,16 @@ def c2d_im2col(img: tir.Buffer,
eviction_policy
=
{
"evict_normal"
:
0
,
"evict_first"
:
1
,
"evict_last"
:
2
}[
eviction_policy
]
img_region
=
to_buffer_region
(
img
,
access_type
=
"r"
)
col_region
=
to_buffer_region
(
col
,
access_type
=
"w"
)
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tileop.c2d_im2col"
),
img_region
,
col_region
,
nhw_step
,
c_step
,
kernel
,
stride
,
dilation
,
pad
,
eviction_policy
)
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tileop.c2d_im2col"
),
img_region
,
col_region
,
nhw_step
,
c_step
,
kernel
,
stride
,
dilation
,
pad
,
eviction_policy
,
)
tilelang/language/customize.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
import
tilelang.language
as
T
from
tvm.tir
import
PrimExpr
,
Buffer
,
op
from
tilelang.utils.language
import
(
bits_product
,
prim_expr_equal
)
from
tilelang.utils.language
import
bits_product
,
prim_expr_equal
from
.atomic
import
atomic_max
,
atomic_min
,
atomic_add
,
atomic_addx2
,
atomic_addx4
,
atomic_load
,
atomic_store
# noqa: F401
...
...
@@ -46,9 +47,9 @@ def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer:
Returns:
Buffer: A new buffer view with the specified shape
"""
assert
prim_expr_equal
(
bits_product
(
shape
,
src
.
dty
pe
)
,
bits_product
(
src
.
shape
,
src
.
dtype
)
)
,
f
"T.reshape/view shape check failed. src
{
src
}
src.shape:
{
src
.
shape
}
, src.dtype:
{
src
.
dtype
}
, target shape:
{
shape
}
, target dtype:
{
src
.
dtype
}
"
assert
prim_expr_equal
(
bits_product
(
shape
,
src
.
dtype
),
bits_product
(
src
.
shape
,
src
.
dtype
)),
(
f
"T.reshape/view shape check failed. src
{
src
}
src.
shape
:
{
src
.
sha
pe
}
,
src.dtype:
{
src
.
dtype
}
, target shape:
{
shape
}
, target dtype:
{
src
.
dtype
}
"
)
return
T
.
Tensor
(
shape
,
src
.
dtype
,
src
.
data
)
...
...
@@ -61,8 +62,7 @@ def view(src: Buffer, shape: list[PrimExpr] | None = None, dtype: str | None = N
shape
=
src
.
shape
if
dtype
is
None
:
dtype
=
src
.
dtype
assert
prim_expr_equal
(
bits_product
(
shape
,
dtype
),
bits_product
(
src
.
shape
,
src
.
dtype
)),
"T.reshape/view shape check failed."
assert
prim_expr_equal
(
bits_product
(
shape
,
dtype
),
bits_product
(
src
.
shape
,
src
.
dtype
)),
"T.reshape/view shape check failed."
return
T
.
Tensor
(
shape
,
dtype
,
src
.
data
)
...
...
tilelang/language/experimental/gemm_sp.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
from
tilelang.primitives.gemm.base
import
GemmWarpPolicy
import
tilelang.language
as
T
...
...
@@ -11,7 +12,8 @@ from tilelang.utils.language import (
prim_expr_equal
,
)
from
tilelang.language.utils
import
(
buffer_region_to_tile_region
,)
buffer_region_to_tile_region
,
)
def
gemm_sp
(
...
...
@@ -169,18 +171,19 @@ def gemm_sp_v2(
assert
len
(
B_shape
)
>=
2
,
"current only support B as a 2D or higher-order tensor"
if
len
(
A_shape
)
>
2
:
for
i
in
range
(
len
(
A_shape
)
-
2
):
assert
A_shape
[
i
]
==
1
,
\
assert
A_shape
[
i
]
==
1
,
(
"current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
if
len
(
B_shape
)
>
2
:
for
i
in
range
(
len
(
B_shape
)
-
2
):
assert
B_shape
[
i
]
==
1
,
\
assert
B_shape
[
i
]
==
1
,
(
"current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
M
,
N
=
C_shape
K
=
2
*
(
A_shape
[
-
2
]
if
transpose_A
else
A_shape
[
-
1
])
K_B
=
B_shape
[
-
1
]
if
transpose_B
else
B_shape
[
-
2
]
assert
prim_expr_equal
(
K
,
K_B
),
f
"T.gemm_sp K shape check failed: K_A (wo sparse) =
{
K
}
, K_B =
{
K_B
}
"
assert
prim_expr_equal
(
K
,
K_B
),
f
"T.gemm_sp K shape check failed: K_A (wo sparse) =
{
K
}
, K_B =
{
K_B
}
"
stride_a
=
A_stride
[
-
2
]
stride_b
=
B_stride
[
-
2
]
...
...
tilelang/language/fill.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
from
tvm
import
tir
from
tilelang.language
import
has_let_value
,
get_let_value
...
...
@@ -32,8 +33,7 @@ def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.Prim
extents
=
[
tir
.
IntImm
(
"int32"
,
1
)
for
_
in
buffer
.
indices
]
else
:
extents
=
[]
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tileop.fill"
),
to_buffer_region
(
buffer
,
access_type
=
"w"
,
extents
=
extents
),
value
)
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tileop.fill"
),
to_buffer_region
(
buffer
,
access_type
=
"w"
,
extents
=
extents
),
value
)
def
clear
(
buffer
:
tir
.
Buffer
|
tir
.
Var
):
...
...
@@ -55,8 +55,7 @@ def clear(buffer: tir.Buffer | tir.Var):
elif
isinstance
(
buffer_region
,
tir
.
BufferLoad
):
region
=
get_buffer_region_from_load
(
buffer_region
)
if
region
is
None
:
raise
ValueError
(
f
"Invalid buffer region:
{
buffer_region
}
, type:
{
type
(
buffer_region
)
}
"
)
raise
ValueError
(
f
"Invalid buffer region:
{
buffer_region
}
, type:
{
type
(
buffer_region
)
}
"
)
return
fill
(
region
,
0
)
else
:
raise
ValueError
(
f
"Invalid buffer region:
{
buffer_region
}
, type:
{
type
(
buffer_region
)
}
"
)
...
...
tilelang/language/frame.py
View file @
29051439
"""Override the LetFrame to print a message when entering the frame."""
from
__future__
import
annotations
from
tvm.ffi
import
register_object
as
_register_object
from
tvm.tir
import
Var
,
PrimExpr
,
BufferLoad
,
BufferRegion
...
...
@@ -29,7 +30,7 @@ class FrameStack:
item: The frame object to push onto the stack
"""
self
.
_stack
.
append
(
item
)
if
hasattr
(
item
,
'
var
'
)
and
hasattr
(
item
,
'
value
'
):
if
hasattr
(
item
,
"
var
"
)
and
hasattr
(
item
,
"
value
"
):
self
.
_var_value_map
[
item
.
var
]
=
item
.
value
def
pop
(
self
):
...
...
@@ -43,7 +44,7 @@ class FrameStack:
"""
if
self
.
_stack
:
item
=
self
.
_stack
.
pop
()
if
hasattr
(
item
,
'
var
'
):
if
hasattr
(
item
,
"
var
"
):
self
.
_var_value_map
.
pop
(
item
.
var
,
None
)
return
item
raise
IndexError
(
f
"
{
self
.
__class__
.
__name__
}
is empty"
)
...
...
@@ -129,8 +130,7 @@ class LetFrame(TIRFrame):
is_block_load
=
True
break
if
is_block_load
:
self
.
value
=
BufferRegion
(
self
.
value
.
buffer
,
[
Range
(
x
.
base
,
x
.
lanes
)
for
x
in
indices
])
self
.
value
=
BufferRegion
(
self
.
value
.
buffer
,
[
Range
(
x
.
base
,
x
.
lanes
)
for
x
in
indices
])
_get_let_stack
().
push
(
self
)
return
self
.
var
...
...
tilelang/language/gemm.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
from
tilelang.primitives.gemm.base
import
GemmWarpPolicy
import
tilelang.language
as
T
...
...
@@ -11,7 +12,8 @@ from tilelang.utils.language import (
prim_expr_equal
,
)
from
tilelang.language.utils
import
(
buffer_region_to_tile_region
,)
buffer_region_to_tile_region
,
)
from
tilelang.env
import
env
as
_env
...
...
@@ -68,12 +70,14 @@ def _gemm_impl(
assert
len
(
B_shape
)
>=
2
,
"current only support B as a 2D or higher-order tensor"
if
len
(
A_shape
)
>
2
:
for
i
in
range
(
len
(
A_shape
)
-
2
):
assert
A_shape
[
i
]
==
1
,
\
assert
A_shape
[
i
]
==
1
,
(
"current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
if
len
(
B_shape
)
>
2
:
for
i
in
range
(
len
(
B_shape
)
-
2
):
assert
B_shape
[
i
]
==
1
,
\
assert
B_shape
[
i
]
==
1
,
(
"current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
M
,
N
=
C_shape
K
=
A_shape
[
-
2
]
if
transpose_A
else
A_shape
[
-
1
]
...
...
@@ -96,9 +100,29 @@ def _gemm_impl(
A_arg
=
buffer_region_to_tile_region
(
A_region
,
"r"
,
[
r
for
r
in
A_shape
])
B_arg
=
buffer_region_to_tile_region
(
B_region
,
"r"
,
[
r
for
r
in
B_shape
])
C_arg
=
buffer_region_to_tile_region
(
C_region
,
"rw"
,
[
r
for
r
in
C_shape
])
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
op_key
),
A_arg
,
B_arg
,
C_arg
,
transpose_A
,
transpose_B
,
M
,
N
,
K
,
policy
,
clear_accum
,
stride_a
,
stride_b
,
offset_a
,
offset_b
,
k_pack
,
wg_wait
,
mbar
,
C_coords
[
0
],
C_coords
[
1
])
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
op_key
),
A_arg
,
B_arg
,
C_arg
,
transpose_A
,
transpose_B
,
M
,
N
,
K
,
policy
,
clear_accum
,
stride_a
,
stride_b
,
offset_a
,
offset_b
,
k_pack
,
wg_wait
,
mbar
,
C_coords
[
0
],
C_coords
[
1
],
)
# Public wrappers
...
...
tilelang/language/kernel.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
from
collections
import
deque
from
tvm
import
tir
...
...
@@ -107,8 +108,7 @@ class KernelLaunchFrame(TIRFrame):
_get_current_stack
().
push
(
self
)
last_block_frame
=
self
.
frames
[
-
1
]
assert
isinstance
(
last_block_frame
,
BlockFrame
),
f
"Last frame must be a block frame, got
{
last_block_frame
}
"
assert
isinstance
(
last_block_frame
,
BlockFrame
),
f
"Last frame must be a block frame, got
{
last_block_frame
}
"
maybe_cpu
=
last_block_frame
.
annotations
.
get
(
"tilelang.is_cpu_kernel_frame"
,
False
)
...
...
@@ -303,56 +303,48 @@ def Kernel(
def
get_thread_binding
(
dim
:
int
=
0
)
->
Var
:
"""Returns the thread binding for the given dimension.
"""
"""Returns the thread binding for the given dimension."""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_thread_binding
(
dim
)
def
get_thread_bindings
()
->
list
[
Var
]:
"""Returns all three thread bindings.
"""
"""Returns all three thread bindings."""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_thread_bindings
()
def
get_block_binding
(
dim
:
int
=
0
)
->
Var
:
"""Returns the block binding for the given dimension.
"""
"""Returns the block binding for the given dimension."""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_block_binding
(
dim
)
def
get_block_bindings
()
->
list
[
Var
]:
"""Returns all three block bindings.
"""
"""Returns all three block bindings."""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_block_bindings
()
def
get_thread_extent
(
dim
:
int
=
0
)
->
int
:
"""Returns the thread extent for the given dimension.
"""
"""Returns the thread extent for the given dimension."""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_thread_extent
(
dim
)
def
get_thread_extents
()
->
list
[
int
]:
"""Returns all three thread extents.
"""
"""Returns all three thread extents."""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_thread_extents
()
def
get_block_extent
(
dim
:
int
=
0
)
->
int
:
"""Returns the block extent for the given dimension.
"""
"""Returns the block extent for the given dimension."""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_block_extent
(
dim
)
def
get_block_extents
()
->
list
[
int
]:
"""Returns all three block extents.
"""
"""Returns all three block extents."""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_block_extents
()
tilelang/language/logical.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
from
tilelang
import
language
as
T
...
...
@@ -36,8 +37,7 @@ def any_of(buffer: T.Tensor | BufferRegion):
)
new_region
.
append
(
r
.
min
)
buffer_load
=
BufferLoad
(
buffer
,
new_region
)
return
T
.
call_intrin
(
return_type
,
tir
.
op
.
Op
.
get
(
"tl.any_of"
),
T
.
address_of
(
buffer_load
),
extent
)
return
T
.
call_intrin
(
return_type
,
tir
.
op
.
Op
.
get
(
"tl.any_of"
),
T
.
address_of
(
buffer_load
),
extent
)
else
:
raise
ValueError
(
f
"Invalid buffer type:
{
type
(
buffer
)
}
"
)
...
...
@@ -71,7 +71,6 @@ def all_of(buffer: T.Tensor | BufferRegion):
)
new_region
.
append
(
r
.
min
)
buffer_load
=
BufferLoad
(
buffer
,
new_region
)
return
T
.
call_intrin
(
return_type
,
tir
.
op
.
Op
.
get
(
"tl.all_of"
),
T
.
address_of
(
buffer_load
),
extent
)
return
T
.
call_intrin
(
return_type
,
tir
.
op
.
Op
.
get
(
"tl.all_of"
),
T
.
address_of
(
buffer_load
),
extent
)
else
:
raise
ValueError
(
f
"Invalid buffer type:
{
type
(
buffer
)
}
"
)
tilelang/language/loop.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
from
typing
import
Any
from
tvm
import
tir
...
...
@@ -94,11 +95,9 @@ def Pipelined(
return
_ffi_api
.
Pipelined
(
start
,
stop
,
num_stages
,
order
,
stage
,
sync
,
group
)
def
serial
(
start
:
tir
.
PrimExpr
,
stop
:
tir
.
PrimExpr
|
None
=
None
,
step
:
tir
.
PrimExpr
|
None
=
None
,
*
,
annotations
:
dict
[
str
,
Any
]
|
None
=
None
)
->
frame
.
ForFrame
:
def
serial
(
start
:
tir
.
PrimExpr
,
stop
:
tir
.
PrimExpr
|
None
=
None
,
step
:
tir
.
PrimExpr
|
None
=
None
,
*
,
annotations
:
dict
[
str
,
Any
]
|
None
=
None
)
->
frame
.
ForFrame
:
step_is_one
=
False
step_is_one
|=
isinstance
(
step
,
int
)
and
step
==
1
step_is_one
|=
isinstance
(
step
,
IntImm
)
and
step
.
value
==
1
...
...
@@ -111,13 +110,15 @@ def serial(start: tir.PrimExpr,
return
SerialForWithStep
(
start
,
stop
,
step
,
annotations
=
annotations
)
def
unroll
(
start
:
tir
.
PrimExpr
,
stop
:
tir
.
PrimExpr
|
None
=
None
,
step
:
tir
.
PrimExpr
|
None
=
None
,
*
,
explicit
:
bool
=
False
,
unroll_factor
:
int
|
None
=
None
,
annotations
:
dict
[
str
,
Any
]
|
None
=
None
)
->
frame
.
ForFrame
:
def
unroll
(
start
:
tir
.
PrimExpr
,
stop
:
tir
.
PrimExpr
|
None
=
None
,
step
:
tir
.
PrimExpr
|
None
=
None
,
*
,
explicit
:
bool
=
False
,
unroll_factor
:
int
|
None
=
None
,
annotations
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
frame
.
ForFrame
:
"""The unrolled For statement.
Parameters
...
...
tilelang/language/math_intrinsics.py
View file @
29051439
...
...
@@ -3,7 +3,7 @@ from tvm import tir
def
_validate_rounding_mode
(
rounding_mode
):
"""Validate that the rounding mode is one of the supported IEEE modes"""
valid_modes
=
{
'
rn
'
,
'
rz
'
,
'
ru
'
,
'
rd
'
}
valid_modes
=
{
"
rn
"
,
"
rz
"
,
"
ru
"
,
"
rd
"
}
if
isinstance
(
rounding_mode
,
str
)
and
rounding_mode
in
valid_modes
:
return
raise
ValueError
(
f
"Invalid rounding mode '
{
rounding_mode
}
'. Must be one of:
{
valid_modes
}
"
)
...
...
tilelang/language/overrides/parser.py
View file @
29051439
"""TVMScript parser overrides tailored for TileLang."""
from
functools
import
partial
from
tvm.script.ir_builder
import
tir
as
T
...
...
@@ -58,8 +59,12 @@ def tilelang_visit_assign(self, node: doc.Assign) -> None: # pylint: disable=un
lhs
.
ctx
=
load_ctx
lhs_value
=
self
.
eval_expr
(
lhs
)
lhs
.
ctx
=
store_ctx
if
(
isinstance
(
lhs_value
,
BufferLoad
)
and
lhs_value
.
buffer
.
scope
()
==
"local.var"
and
len
(
lhs_value
.
indices
)
==
1
and
lhs_value
.
indices
[
0
]
==
0
):
if
(
isinstance
(
lhs_value
,
BufferLoad
)
and
lhs_value
.
buffer
.
scope
()
==
"local.var"
and
len
(
lhs_value
.
indices
)
==
1
and
lhs_value
.
indices
[
0
]
==
0
):
T
.
buffer_store
(
lhs_value
.
buffer
,
rhs
,
indices
=
[
0
])
continue
...
...
@@ -106,8 +111,12 @@ def tilelang_visit_aug_assign(self, node: doc.AugAssign) -> None: # pylint: dis
lhs
.
ctx
=
load_ctx
lhs_value
=
self
.
eval_expr
(
lhs
)
lhs
.
ctx
=
store_ctx
if
(
isinstance
(
lhs_value
,
BufferLoad
)
and
lhs_value
.
buffer
.
scope
()
==
"local.var"
and
len
(
lhs_value
.
indices
)
==
1
and
lhs_value
.
indices
[
0
]
==
0
):
if
(
isinstance
(
lhs_value
,
BufferLoad
)
and
lhs_value
.
buffer
.
scope
()
==
"local.var"
and
len
(
lhs_value
.
indices
)
==
1
and
lhs_value
.
indices
[
0
]
==
0
):
T
.
buffer_store
(
lhs_value
.
buffer
,
rhs
,
indices
=
[
0
])
return
...
...
@@ -131,8 +140,12 @@ def tilelang_visit_ann_assign(self, node: doc.AnnAssign) -> None: # pylint: dis
lhs
.
ctx
=
load_ctx
lhs_value
=
self
.
eval_expr
(
lhs
)
lhs
.
ctx
=
store_ctx
if
(
isinstance
(
lhs_value
,
BufferLoad
)
and
lhs_value
.
buffer
.
scope
()
==
"local.var"
and
len
(
lhs_value
.
indices
)
==
1
and
lhs_value
.
indices
[
0
]
==
0
):
if
(
isinstance
(
lhs_value
,
BufferLoad
)
and
lhs_value
.
buffer
.
scope
()
==
"local.var"
and
len
(
lhs_value
.
indices
)
==
1
and
lhs_value
.
indices
[
0
]
==
0
):
T
.
buffer_store
(
lhs_value
.
buffer
,
rhs
,
indices
=
[
0
])
return
...
...
tilelang/language/parser/entry.py
View file @
29051439
...
...
@@ -18,6 +18,7 @@
# which is part of the TVM project (https://tvm.apache.org/).
# ruff: noqa
"""The entry point of TVM parser for tir."""
import
inspect
from
typing
import
Callable
,
Optional
,
Union
...
...
@@ -29,9 +30,7 @@ from tvm.script.parser._core import parse, scan_macro, utils
from
tvm.script.parser.core.parser
import
Parser
,
ScriptMacro
def
prim_func
(
func
:
Optional
[
Callable
]
=
None
,
private
:
bool
=
False
,
check_well_formed
=
True
)
->
Union
[
PrimFunc
,
Callable
]:
def
prim_func
(
func
:
Optional
[
Callable
]
=
None
,
private
:
bool
=
False
,
check_well_formed
=
True
)
->
Union
[
PrimFunc
,
Callable
]:
"""The parsing method for tir prim func, by using `@prim_func` as decorator.
Parameters
...
...
@@ -149,8 +148,7 @@ def macro(*args, hygienic: bool = True) -> Callable:
if
len
(
args
)
==
1
and
inspect
.
isfunction
(
args
[
0
]):
return
_decorator
(
args
[
0
])
raise
ValueError
(
"Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])"
)
raise
ValueError
(
"Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])"
)
class
BufferProxy
:
...
...
tilelang/language/parser/operation.py
View file @
29051439
...
...
@@ -17,6 +17,7 @@
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
"""The tir expression operation registration"""
from
tvm
import
tir
from
tvm.ffi.runtime_ctypes
import
DataType
,
DataTypeCode
from
tvm.tir
import
IntImm
...
...
@@ -55,11 +56,9 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name
return
dtype
[
0
:
index
]
def
_auto_broadcast
(
a
,
b
,
op
):
if
isinstance
(
a
,
int
):
if
hasattr
(
b
,
"dtype"
):
if
(
DataType
(
b
.
dtype
).
type_code
==
DataTypeCode
.
INT
or
DataType
(
b
.
dtype
).
type_code
==
DataTypeCode
.
UINT
):
if
DataType
(
b
.
dtype
).
type_code
==
DataTypeCode
.
INT
or
DataType
(
b
.
dtype
).
type_code
==
DataTypeCode
.
UINT
:
a
=
IntImm
(
_get_type_str
(
b
.
dtype
),
a
)
elif
DataType
(
b
.
dtype
).
type_code
==
DataTypeCode
.
FLOAT
:
a
=
FloatImm
(
_get_type_str
(
b
.
dtype
),
a
)
...
...
@@ -75,8 +74,7 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name
assert
isinstance
(
a
,
tir
.
PrimExpr
),
"Operand should be a PrimExpr."
if
isinstance
(
b
,
int
):
if
(
DataType
(
a
.
dtype
).
type_code
==
DataTypeCode
.
INT
or
DataType
(
a
.
dtype
).
type_code
==
DataTypeCode
.
UINT
):
if
DataType
(
a
.
dtype
).
type_code
==
DataTypeCode
.
INT
or
DataType
(
a
.
dtype
).
type_code
==
DataTypeCode
.
UINT
:
b
=
IntImm
(
_get_type_str
(
a
.
dtype
),
b
)
elif
DataType
(
a
.
dtype
).
type_code
==
DataTypeCode
.
FLOAT
:
b
=
FloatImm
(
_get_type_str
(
a
.
dtype
),
b
)
...
...
@@ -85,10 +83,10 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name
if
DataType
(
a
.
dtype
).
lanes
==
DataType
(
b
.
dtype
).
lanes
:
return
op
(
a
,
b
)
elif
(
DataType
(
a
.
dtype
).
lanes
==
1
and
DataType
(
a
.
dtype
).
lanes
!=
DataType
(
b
.
dtype
).
lanes
)
:
elif
DataType
(
a
.
dtype
).
lanes
==
1
and
DataType
(
a
.
dtype
).
lanes
!=
DataType
(
b
.
dtype
).
lanes
:
broadcast_a
=
tir
.
Broadcast
(
a
,
DataType
(
b
.
dtype
).
lanes
)
return
op
(
broadcast_a
,
b
)
elif
(
DataType
(
b
.
dtype
).
lanes
==
1
and
DataType
(
a
.
dtype
).
lanes
!=
DataType
(
b
.
dtype
).
lanes
)
:
elif
DataType
(
b
.
dtype
).
lanes
==
1
and
DataType
(
a
.
dtype
).
lanes
!=
DataType
(
b
.
dtype
).
lanes
:
broadcast_b
=
tir
.
Broadcast
(
b
,
DataType
(
a
.
dtype
).
lanes
)
return
op
(
a
,
broadcast_b
)
else
:
...
...
tilelang/language/parser/parser.py
View file @
29051439
...
...
@@ -146,8 +146,7 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -
res
=
value
.
__enter__
()
IRBuilder
.
name
(
var_name
,
res
)
return
res
elif
isinstance
(
value
,
(
Buffer
,
IterVar
))
or
(
isinstance
(
value
,
Var
)
and
not
self
.
var_table
.
exist
(
value
)):
elif
isinstance
(
value
,
(
Buffer
,
IterVar
))
or
(
isinstance
(
value
,
Var
)
and
not
self
.
var_table
.
exist
(
value
)):
IRBuilder
.
name
(
var_name
,
value
)
return
value
else
:
...
...
@@ -191,8 +190,7 @@ def visit_for(self: Parser, node: doc.For) -> None:
if
not
isinstance
(
for_frame
,
T
.
frame
.
ForFrame
):
self
.
report_error
(
node
.
iter
,
"Expect the for loop to be one of the following: "
"range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding"
,
"Expect the for loop to be one of the following: range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding"
,
)
with
self
.
var_table
.
with_frame
():
with
for_frame
as
iters
:
...
...
@@ -361,8 +359,7 @@ def visit_with(self: Parser, node: doc.With) -> None:
for
item
in
node
.
items
:
frame
=
self
.
eval_expr
(
item
.
context_expr
)
if
not
isinstance
(
frame
,
Frame
):
self
.
report_error
(
item
.
context_expr
,
"Invalid context expression in the with-statement."
)
self
.
report_error
(
item
.
context_expr
,
"Invalid context expression in the with-statement."
)
rhs
=
stack
.
enter_context
(
frame
)
if
item
.
optional_vars
is
not
None
:
self
.
eval_assign
(
target
=
item
.
optional_vars
,
source
=
rhs
,
bind_value
=
bind_with_value
)
...
...
@@ -505,8 +502,7 @@ def visit_if(self: Parser, node: doc.If) -> None:
with
self
.
var_table
.
with_frame
():
self
.
visit_body
(
node
.
orelse
)
else
:
self
.
report_error
(
node
.
test
,
f
"If condition must be a boolean expression, but got
{
predicate
}
"
)
self
.
report_error
(
node
.
test
,
f
"If condition must be a boolean expression, but got
{
predicate
}
"
)
@
dispatch
.
register
(
token
=
"tir"
,
type_name
=
"Assert"
)
...
...
tilelang/language/print.py
View file @
29051439
...
...
@@ -26,9 +26,7 @@ def print_var(var: tir.PrimExpr, msg: str = "") -> tir.PrimExpr:
@
macro
def
print_var_with_condition
(
condition
:
tir
.
PrimExpr
,
var
:
tir
.
PrimExpr
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
def
print_var_with_condition
(
condition
:
tir
.
PrimExpr
,
var
:
tir
.
PrimExpr
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
"""
Conditionally prints a TIR primitive expression (PrimExpr) if a given condition is True.
...
...
@@ -44,10 +42,7 @@ def print_var_with_condition(condition: tir.PrimExpr,
@
macro
def
print_global_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
def
print_global_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
"""
...
...
@@ -55,17 +50,13 @@ def print_global_buffer_with_condition(condition: tir.PrimExpr,
# Iterate through the buffer elements and print each one.
for
i
in
serial
(
elems
):
coords
=
index_to_coordinates
(
i
,
buffer
.
shape
)
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
buffer
[
coords
])
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
buffer
[
coords
])
else
:
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
buffer
[
coords
])
@
macro
def
print_shared_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
def
print_shared_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
...
...
@@ -81,15 +72,11 @@ def print_shared_buffer_with_condition(condition: tir.PrimExpr,
# Iterate through the buffer elements and print each one.
for
i
in
serial
(
elems
):
coords
=
index_to_coordinates
(
i
,
buffer
.
shape
)
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
buffer
[
coords
])
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
buffer
[
coords
])
@
macro
def
print_fragment_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
def
print_fragment_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
...
...
@@ -111,10 +98,7 @@ def print_fragment_buffer_with_condition(condition: tir.PrimExpr,
@
macro
def
print_local_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
def
print_local_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
...
...
@@ -130,8 +114,7 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr,
# Iterate through the buffer elements and print each one.
for
i
in
serial
(
elems
):
coords
=
index_to_coordinates
(
i
,
buffer
.
shape
)
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
buffer
[
coords
])
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
buffer
[
coords
])
from
tilelang.utils.target
import
check_cuda_availability
...
...
@@ -201,7 +184,7 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
elems
*=
dim
# Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
condition
=
(
tx
==
main_lane
and
ty
==
0
and
tz
==
0
)
condition
=
tx
==
main_lane
and
ty
==
0
and
tz
==
0
if
not
msg
:
msg
=
f
"buffer<
{
buffer
.
name
}
,
{
buffer
.
dtype
}
>"
return
print_fragment_buffer_with_condition
(
condition
,
buffer
,
elems
,
msg
)
...
...
@@ -212,7 +195,7 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
elems
*=
dim
# Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
condition
=
(
tx
==
main_lane
and
ty
==
0
and
tz
==
0
)
condition
=
tx
==
main_lane
and
ty
==
0
and
tz
==
0
if
not
msg
:
msg
=
f
"buffer<
{
buffer
.
name
}
,
{
buffer
.
dtype
}
>"
return
print_shared_buffer_with_condition
(
condition
,
buffer
,
elems
,
msg
)
...
...
@@ -234,5 +217,4 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
else
:
# Unsupported object type.
raise
ValueError
(
f
"Unexpected type:
{
type
(
obj
)
}
. Supported types are tir.Buffer and tir.PrimExpr."
)
raise
ValueError
(
f
"Unexpected type:
{
type
(
obj
)
}
. Supported types are tir.Buffer and tir.PrimExpr."
)
tilelang/language/proxy.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
from
typing
import
Any
,
SupportsIndex
,
TYPE_CHECKING
,
Generic
,
TypeVar
...
...
@@ -51,11 +52,9 @@ class BufferProxy:
return
self
(
keys
)
return
self
(
*
keys
)
# type: ignore[attr-defined] # pylint: disable=no-member
def
from_ptr
(
self
,
pointer_var
:
Var
,
shape
:
tuple
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
Buffer
:
def
from_ptr
(
self
,
pointer_var
:
Var
,
shape
:
tuple
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
Buffer
:
"""Create a buffer from a pointer, shape, and data type.
Args:
...
...
@@ -76,6 +75,7 @@ class BaseTensorProxy:
customizable default values for scope, alignment, and offset factors. It implements
the core functionality for creating TIR buffers with specific memory configurations.
"""
default_scope
=
"global"
default_align
=
0
default_offset_factor
=
0
...
...
@@ -118,11 +118,9 @@ class BaseTensorProxy:
keys
=
(
keys
,)
return
self
(
*
keys
)
def
from_ptr
(
self
,
pointer_var
:
Var
,
shape
:
tuple
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
tir
.
Buffer
:
def
from_ptr
(
self
,
pointer_var
:
Var
,
shape
:
tuple
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
tir
.
Buffer
:
"""Create a buffer from a pointer, shape, and data type.
Args:
...
...
@@ -151,19 +149,10 @@ class TensorProxy(BaseTensorProxy):
strides
.
append
(
s
)
return
tuple
(
reversed
(
strides
))
def
__call__
(
self
,
shape
:
tuple
[
Any
]
|
PrimExpr
|
int
,
dtype
:
str
=
"float32"
,
data
=
None
,
scope
=
None
)
->
tir
.
Buffer
:
def
__call__
(
self
,
shape
:
tuple
[
Any
]
|
PrimExpr
|
int
,
dtype
:
str
=
"float32"
,
data
=
None
,
scope
=
None
)
->
tir
.
Buffer
:
if
isinstance
(
shape
,
(
int
,
PrimExpr
)):
shape
=
(
shape
,)
return
super
().
__call__
(
shape
,
dtype
=
dtype
,
strides
=
TensorProxy
.
_construct_strides
(
shape
),
data
=
data
,
scope
=
scope
)
return
super
().
__call__
(
shape
,
dtype
=
dtype
,
strides
=
TensorProxy
.
_construct_strides
(
shape
),
data
=
data
,
scope
=
scope
)
class
StridedTensorProxy
(
BaseTensorProxy
):
...
...
@@ -172,11 +161,7 @@ class StridedTensorProxy(BaseTensorProxy):
This class implements the default tensor proxy with global memory scope, with the stride information required.
"""
def
__call__
(
self
,
shape
:
tuple
[
Any
],
strides
:
tuple
[
Any
],
dtype
:
str
=
"float32"
,
scope
=
None
)
->
tir
.
Buffer
:
def
__call__
(
self
,
shape
:
tuple
[
Any
],
strides
:
tuple
[
Any
],
dtype
:
str
=
"float32"
,
scope
=
None
)
->
tir
.
Buffer
:
if
len
(
shape
)
!=
len
(
strides
):
raise
ValueError
(
"Invalid shape/strides' dimensions"
)
return
super
().
__call__
(
shape
,
dtype
=
dtype
,
strides
=
strides
,
scope
=
scope
)
...
...
@@ -188,6 +173,7 @@ class FragmentBufferProxy(BaseTensorProxy):
This class represents tensor proxies specifically for local fragment memory,
typically used in GPU tensor core operations.
"""
default_scope
=
"local.fragment"
...
...
@@ -197,6 +183,7 @@ class SharedBufferProxy(BaseTensorProxy):
This class represents tensor proxies for dynamic shared memory,
commonly used in GPU shared memory operations.
"""
default_scope
=
"shared.dyn"
...
...
@@ -206,6 +193,7 @@ class LocalBufferProxy(BaseTensorProxy):
This class represents tensor proxies for local memory scope,
typically used for temporary computations in GPU kernels.
"""
default_scope
=
"local"
...
...
@@ -216,15 +204,12 @@ Buffer = BufferProxy() # pylint: disable=invalid-name
if
TYPE_CHECKING
:
class
BaseTensor
:
def
__class_getitem__
(
cls
,
key
):
return
cls
def
__getitem__
(
self
,
key
)
->
Any
:
...
def
__getitem__
(
self
,
key
)
->
Any
:
...
def
__setitem__
(
self
,
key
,
value
)
->
None
:
...
def
__setitem__
(
self
,
key
,
value
)
->
None
:
...
def
__init__
(
self
,
...
...
@@ -238,36 +223,26 @@ if TYPE_CHECKING:
offset_factor
=
None
,
buffer_type
=
""
,
axis_separators
=
None
,
):
...
):
...
@
classmethod
def
from_ptr
(
cls
,
pointer_var
:
Var
,
shape
:
Sequence
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
Self
:
...
def
from_ptr
(
cls
,
pointer_var
:
Var
,
shape
:
Sequence
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
Self
:
...
class
Tensor
(
BaseTensor
):
...
class
Tensor
(
BaseTensor
):
...
class
StridedTensor
(
BaseTensor
):
...
class
StridedTensor
(
BaseTensor
):
...
class
FragmentBuffer
(
BaseTensor
):
...
class
FragmentBuffer
(
BaseTensor
):
...
class
SharedBuffer
(
BaseTensor
):
...
class
SharedBuffer
(
BaseTensor
):
...
class
LocalBuffer
(
BaseTensor
):
...
class
LocalBuffer
(
BaseTensor
):
...
_T
=
TypeVar
(
'
_T
'
)
_T
=
TypeVar
(
"
_T
"
)
class
Ref
(
Generic
[
_T
],
tir
.
Var
):
...
class
Ref
(
Generic
[
_T
],
tir
.
Var
):
...
else
:
Tensor
=
TensorProxy
()
# pylint: disable=invalid-name
StridedTensor
=
StridedTensorProxy
()
# pylint: disable=invalid-name
...
...
@@ -275,14 +250,10 @@ else:
SharedBuffer
=
SharedBufferProxy
()
# pylint: disable=invalid-name
LocalBuffer
=
LocalBufferProxy
()
# pylint: disable=invalid-name
class
Ref
:
...
class
Ref
:
...
def
ptr
(
dtype
:
str
|
None
=
None
,
storage_scope
:
str
=
"global"
,
*
,
is_size_var
:
bool
=
False
)
->
Var
:
def
ptr
(
dtype
:
str
|
None
=
None
,
storage_scope
:
str
=
"global"
,
*
,
is_size_var
:
bool
=
False
)
->
Var
:
"""Create a TIR var that represents a pointer.
Parameters
...
...
@@ -304,8 +275,5 @@ def ptr(dtype: str | None = None,
return
handle
(
dtype
=
dtype
,
storage_scope
=
storage_scope
,
is_size_var
=
is_size_var
)
def
make_tensor
(
ptr
:
Var
,
shape
:
tuple
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
tir
.
Buffer
:
def
make_tensor
(
ptr
:
Var
,
shape
:
tuple
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
tir
.
Buffer
:
return
Tensor
.
from_ptr
(
ptr
,
shape
,
dtype
,
strides
)
Prev
1
…
17
18
19
20
21
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment