Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
281 additions
and
328 deletions
+281
-328
tilelang/language/ast/_ffi_api.py
tilelang/language/ast/_ffi_api.py
+1
-0
tilelang/language/ast/ir.py
tilelang/language/ast/ir.py
+39
-48
tilelang/language/atomic.py
tilelang/language/atomic.py
+7
-19
tilelang/language/builtin.py
tilelang/language/builtin.py
+48
-62
tilelang/language/copy.py
tilelang/language/copy.py
+36
-22
tilelang/language/customize.py
tilelang/language/customize.py
+6
-6
tilelang/language/experimental/gemm_sp.py
tilelang/language/experimental/gemm_sp.py
+8
-5
tilelang/language/fill.py
tilelang/language/fill.py
+3
-4
tilelang/language/frame.py
tilelang/language/frame.py
+4
-4
tilelang/language/gemm.py
tilelang/language/gemm.py
+30
-6
tilelang/language/kernel.py
tilelang/language/kernel.py
+10
-18
tilelang/language/logical.py
tilelang/language/logical.py
+3
-4
tilelang/language/loop.py
tilelang/language/loop.py
+13
-12
tilelang/language/math_intrinsics.py
tilelang/language/math_intrinsics.py
+1
-1
tilelang/language/overrides/parser.py
tilelang/language/overrides/parser.py
+19
-6
tilelang/language/parser/entry.py
tilelang/language/parser/entry.py
+3
-5
tilelang/language/parser/operation.py
tilelang/language/parser/operation.py
+5
-7
tilelang/language/parser/parser.py
tilelang/language/parser/parser.py
+4
-8
tilelang/language/print.py
tilelang/language/print.py
+11
-29
tilelang/language/proxy.py
tilelang/language/proxy.py
+30
-62
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
tilelang/language/ast/_ffi_api.py
View file @
29051439
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
# This file is modified from the original version,
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
# which is part of the TVM project (https://tvm.apache.org/).
"""FFI APIs"""
"""FFI APIs"""
import
tvm.ffi
import
tvm.ffi
tvm
.
ffi
.
_init_api
(
"script.ir_builder.tir"
,
__name__
)
# pylint: disable=protected-access
tvm
.
ffi
.
_init_api
(
"script.ir_builder.tir"
,
__name__
)
# pylint: disable=protected-access
tilelang/language/ast/ir.py
View file @
29051439
...
@@ -558,7 +558,8 @@ class axis: # pylint: disable=invalid-name
...
@@ -558,7 +558,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable.
The iteration variable.
"""
"""
return
_ffi_api
.
AxisSpatial
(
# type: ignore[attr-defined] # pylint: disable=no-member
return
_ffi_api
.
AxisSpatial
(
# type: ignore[attr-defined] # pylint: disable=no-member
_as_range
(
dom
),
binding
,
dtype
)
_as_range
(
dom
),
binding
,
dtype
)
@
staticmethod
@
staticmethod
def
reduce
(
def
reduce
(
...
@@ -585,7 +586,8 @@ class axis: # pylint: disable=invalid-name
...
@@ -585,7 +586,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable.
The iteration variable.
"""
"""
return
_ffi_api
.
AxisReduce
(
# type: ignore[attr-defined] # pylint: disable=no-member
return
_ffi_api
.
AxisReduce
(
# type: ignore[attr-defined] # pylint: disable=no-member
_as_range
(
dom
),
binding
,
dtype
)
_as_range
(
dom
),
binding
,
dtype
)
@
staticmethod
@
staticmethod
def
scan
(
def
scan
(
...
@@ -612,7 +614,8 @@ class axis: # pylint: disable=invalid-name
...
@@ -612,7 +614,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable.
The iteration variable.
"""
"""
return
_ffi_api
.
AxisScan
(
# type: ignore[attr-defined] # pylint: disable=no-member
return
_ffi_api
.
AxisScan
(
# type: ignore[attr-defined] # pylint: disable=no-member
_as_range
(
dom
),
binding
,
dtype
)
_as_range
(
dom
),
binding
,
dtype
)
@
staticmethod
@
staticmethod
def
opaque
(
def
opaque
(
...
@@ -639,7 +642,8 @@ class axis: # pylint: disable=invalid-name
...
@@ -639,7 +642,8 @@ class axis: # pylint: disable=invalid-name
The iteration variable.
The iteration variable.
"""
"""
return
_ffi_api
.
AxisOpaque
(
# type: ignore[attr-defined] # pylint: disable=no-member
return
_ffi_api
.
AxisOpaque
(
# type: ignore[attr-defined] # pylint: disable=no-member
_as_range
(
dom
),
binding
,
dtype
)
_as_range
(
dom
),
binding
,
dtype
)
@
staticmethod
@
staticmethod
def
remap
(
kinds
:
str
,
bindings
:
List
[
PrimExpr
],
dtype
:
str
=
"int32"
)
->
Union
[
List
[
Var
],
Var
]:
def
remap
(
kinds
:
str
,
bindings
:
List
[
PrimExpr
],
dtype
:
str
=
"int32"
)
->
Union
[
List
[
Var
],
Var
]:
...
@@ -662,17 +666,15 @@ class axis: # pylint: disable=invalid-name
...
@@ -662,17 +666,15 @@ class axis: # pylint: disable=invalid-name
The iteration variables.
The iteration variables.
"""
"""
iter_vars
=
_ffi_api
.
AxisRemap
(
# type: ignore[attr-defined] # pylint: disable=no-member
iter_vars
=
_ffi_api
.
AxisRemap
(
# type: ignore[attr-defined] # pylint: disable=no-member
kinds
,
bindings
,
dtype
)
kinds
,
bindings
,
dtype
)
return
iter_vars
[
0
]
if
len
(
iter_vars
)
==
1
else
iter_vars
return
iter_vars
[
0
]
if
len
(
iter_vars
)
==
1
else
iter_vars
S
=
spatial
# pylint: disable=invalid-name
S
=
spatial
# pylint: disable=invalid-name
R
=
reduce
# pylint: disable=invalid-name
R
=
reduce
# pylint: disable=invalid-name
def
serial
(
start
:
PrimExpr
,
def
serial
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
"""The serial For statement.
"""The serial For statement.
Parameters
Parameters
...
@@ -700,10 +702,7 @@ def serial(start: PrimExpr,
...
@@ -700,10 +702,7 @@ def serial(start: PrimExpr,
return
_ffi_api
.
Serial
(
start
,
stop
,
annotations
)
# type: ignore[attr-defined] # pylint: disable=no-member
return
_ffi_api
.
Serial
(
start
,
stop
,
annotations
)
# type: ignore[attr-defined] # pylint: disable=no-member
def
parallel
(
start
:
PrimExpr
,
def
parallel
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
"""The parallel For statement.
"""The parallel For statement.
Parameters
Parameters
...
@@ -731,10 +730,7 @@ def parallel(start: PrimExpr,
...
@@ -731,10 +730,7 @@ def parallel(start: PrimExpr,
return
_ffi_api
.
Parallel
(
start
,
stop
,
annotations
)
# type: ignore[attr-defined] # pylint: disable=no-member
return
_ffi_api
.
Parallel
(
start
,
stop
,
annotations
)
# type: ignore[attr-defined] # pylint: disable=no-member
def
vectorized
(
start
:
PrimExpr
,
def
vectorized
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
"""The vectorized For statement.
"""The vectorized For statement.
Parameters
Parameters
...
@@ -762,10 +758,7 @@ def vectorized(start: PrimExpr,
...
@@ -762,10 +758,7 @@ def vectorized(start: PrimExpr,
return
_ffi_api
.
Vectorized
(
start
,
stop
,
annotations
)
# type: ignore[attr-defined] # pylint: disable=no-member
return
_ffi_api
.
Vectorized
(
start
,
stop
,
annotations
)
# type: ignore[attr-defined] # pylint: disable=no-member
def
unroll
(
start
:
PrimExpr
,
def
unroll
(
start
:
PrimExpr
,
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
stop
:
PrimExpr
=
None
,
*
,
annotations
:
Dict
[
str
,
Any
]
=
None
)
->
frame
.
ForFrame
:
"""The unrolled For statement.
"""The unrolled For statement.
Parameters
Parameters
...
@@ -837,7 +830,8 @@ def thread_binding(
...
@@ -837,7 +830,8 @@ def thread_binding(
else
:
else
:
start
=
0
start
=
0
return
_ffi_api
.
ThreadBinding
(
# type: ignore[attr-defined] # pylint: disable=no-member
return
_ffi_api
.
ThreadBinding
(
# type: ignore[attr-defined] # pylint: disable=no-member
start
,
stop
,
thread
,
annotations
)
start
,
stop
,
thread
,
annotations
)
def
grid
(
*
extents
:
PrimExpr
)
->
frame
.
ForFrame
:
def
grid
(
*
extents
:
PrimExpr
)
->
frame
.
ForFrame
:
...
@@ -878,10 +872,10 @@ def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: d
...
@@ -878,10 +872,10 @@ def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: d
def
LetStmt
(
# pylint: disable=invalid-name
def
LetStmt
(
# pylint: disable=invalid-name
value
:
PrimExpr
,
value
:
PrimExpr
,
type_annotation
:
Optional
[
Type
]
=
None
,
# pylint: disable=redefined-outer-name
type_annotation
:
Optional
[
Type
]
=
None
,
# pylint: disable=redefined-outer-name
*
,
*
,
var
:
Optional
[
Var
]
=
None
,
# pylint: disable=redefined-outer-name
var
:
Optional
[
Var
]
=
None
,
# pylint: disable=redefined-outer-name
)
->
frame
.
LetFrame
:
)
->
frame
.
LetFrame
:
"""Create a LetStmt binding
"""Create a LetStmt binding
...
@@ -909,8 +903,8 @@ def LetStmt( # pylint: disable=invalid-name
...
@@ -909,8 +903,8 @@ def LetStmt( # pylint: disable=invalid-name
def
Let
(
# pylint: disable=invalid-name
def
Let
(
# pylint: disable=invalid-name
expr
:
PrimExpr
,
expr
:
PrimExpr
,
where
:
Dict
[
Var
,
PrimExpr
],
# pylint: disable=redefined-outer-name
where
:
Dict
[
Var
,
PrimExpr
],
# pylint: disable=redefined-outer-name
)
->
PrimExpr
:
)
->
PrimExpr
:
"""Create a Let expression binding"""
"""Create a Let expression binding"""
assert
len
(
where
)
==
1
,
"T.Let only allows `where` to have exactly one element"
assert
len
(
where
)
==
1
,
"T.Let only allows `where` to have exactly one element"
...
@@ -980,7 +974,8 @@ def realize(
...
@@ -980,7 +974,8 @@ def realize(
The result RealizeFrame.
The result RealizeFrame.
"""
"""
return
_ffi_api
.
Realize
(
# type: ignore[attr-defined] # pylint: disable=no-member
return
_ffi_api
.
Realize
(
# type: ignore[attr-defined] # pylint: disable=no-member
buffer_slice
,
storage_scope
,
condition
)
buffer_slice
,
storage_scope
,
condition
)
def
allocate
(
def
allocate
(
...
@@ -1012,7 +1007,8 @@ def allocate(
...
@@ -1012,7 +1007,8 @@ def allocate(
if
isinstance
(
condition
,
bool
):
if
isinstance
(
condition
,
bool
):
condition
=
IntImm
(
"bool"
,
condition
)
condition
=
IntImm
(
"bool"
,
condition
)
return
_ffi_api
.
Allocate
(
# type: ignore[attr-defined] # pylint: disable=no-member
return
_ffi_api
.
Allocate
(
# type: ignore[attr-defined] # pylint: disable=no-member
extents
,
dtype
,
scope
,
condition
,
annotations
)
extents
,
dtype
,
scope
,
condition
,
annotations
)
def
allocate_const
(
def
allocate_const
(
...
@@ -1048,7 +1044,8 @@ def allocate_const(
...
@@ -1048,7 +1044,8 @@ def allocate_const(
np_data
=
np_data
.
reshape
(
extents
)
np_data
=
np_data
.
reshape
(
extents
)
return
_ffi_api
.
AllocateConst
(
# type: ignore[attr-defined] # pylint: disable=no-member
return
_ffi_api
.
AllocateConst
(
# type: ignore[attr-defined] # pylint: disable=no-member
ndarray
.
array
(
np_data
),
dtype
,
extents
,
annotations
)
ndarray
.
array
(
np_data
),
dtype
,
extents
,
annotations
)
def
attr
(
node
:
Any
,
attr_key
:
str
,
value
:
Union
[
PrimExpr
,
str
])
->
frame
.
AttrFrame
:
def
attr
(
node
:
Any
,
attr_key
:
str
,
value
:
Union
[
PrimExpr
,
str
])
->
frame
.
AttrFrame
:
...
@@ -1297,7 +1294,8 @@ def buffer_store(
...
@@ -1297,7 +1294,8 @@ def buffer_store(
if
isinstance
(
value
,
bool
)
and
buffer
.
dtype
==
"bool"
:
if
isinstance
(
value
,
bool
)
and
buffer
.
dtype
==
"bool"
:
value
=
IntImm
(
"bool"
,
value
)
value
=
IntImm
(
"bool"
,
value
)
return
_ffi_api
.
BufferStore
(
# type: ignore[attr-defined] # pylint: disable=no-member
return
_ffi_api
.
BufferStore
(
# type: ignore[attr-defined] # pylint: disable=no-member
buffer
,
value
,
expr_indices
)
buffer
,
value
,
expr_indices
)
def
prefetch
(
def
prefetch
(
...
@@ -1464,10 +1462,7 @@ def boolean(expr: Optional[PrimExpr] = None, is_size_var: bool = False) -> PrimE
...
@@ -1464,10 +1462,7 @@ def boolean(expr: Optional[PrimExpr] = None, is_size_var: bool = False) -> PrimE
return
_ffi_api
.
Boolean
(
expr
,
is_size_var
)
# type: ignore[attr-defined] # pylint: disable=no-member
return
_ffi_api
.
Boolean
(
expr
,
is_size_var
)
# type: ignore[attr-defined] # pylint: disable=no-member
def
handle
(
dtype
:
Optional
[
str
]
=
None
,
def
handle
(
dtype
:
Optional
[
str
]
=
None
,
storage_scope
:
str
=
"global"
,
*
,
is_size_var
:
bool
=
False
)
->
Var
:
storage_scope
:
str
=
"global"
,
*
,
is_size_var
:
bool
=
False
)
->
Var
:
"""Create a TIR var that represents a pointer.
"""Create a TIR var that represents a pointer.
Parameters
Parameters
...
@@ -1667,7 +1662,7 @@ def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer:
...
@@ -1667,7 +1662,7 @@ def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer:
res
=
combiner
(
*
args
)
res
=
combiner
(
*
args
)
if
not
isinstance
(
res
,
tuple
):
if
not
isinstance
(
res
,
tuple
):
res
=
(
res
,)
res
=
(
res
,)
return
CommReducer
(
args
[:
num_args
//
2
],
args
[
num_args
//
2
:],
res
,
identity
)
return
CommReducer
(
args
[:
num_args
//
2
],
args
[
num_args
//
2
:],
res
,
identity
)
def
index_map
(
def
index_map
(
...
@@ -1700,16 +1695,15 @@ def target(
...
@@ -1700,16 +1695,15 @@ def target(
The target.
The target.
"""
"""
if
not
isinstance
(
target_config
,
(
str
,
dict
)):
if
not
isinstance
(
target_config
,
(
str
,
dict
)):
raise
ValueError
(
raise
ValueError
(
f
"T.target expected a config dict or string, but got
{
type
(
target_config
)
}
"
)
f
"T.target expected a config dict or string, but got
{
type
(
target_config
)
}
"
)
if
host
is
not
None
and
not
isinstance
(
host
,
(
str
,
dict
,
Target
)):
if
host
is
not
None
and
not
isinstance
(
host
,
(
str
,
dict
,
Target
)):
raise
ValueError
(
"T.target expected the host to be "
raise
ValueError
(
f
"T.target expected the host to be a config dict, string, or T.target, but got
{
type
(
host
)
}
"
)
"a config dict, string, or T.target, "
f
"but got
{
type
(
host
)
}
"
)
if
isinstance
(
target_config
,
dict
)
and
"host"
in
target_config
and
host
is
not
None
:
if
isinstance
(
target_config
,
dict
)
and
"host"
in
target_config
and
host
is
not
None
:
raise
ValueError
(
"T.target expects to either receive the host "
raise
ValueError
(
"as part of the target's config dictionary, "
"T.target expects to either receive the host "
"or as a separate argument, but not both."
)
"as part of the target's config dictionary, "
"or as a separate argument, but not both."
)
return
Target
(
target_config
,
host
)
return
Target
(
target_config
,
host
)
...
@@ -1742,7 +1736,6 @@ class meta_var: # pylint: disable=invalid-name
...
@@ -1742,7 +1736,6 @@ class meta_var: # pylint: disable=invalid-name
self
.
value
=
value
self
.
value
=
value
def
__iter__
(
self
):
def
__iter__
(
self
):
def
f
():
def
f
():
for
i
in
self
.
value
:
for
i
in
self
.
value
:
yield
meta_var
(
i
)
yield
meta_var
(
i
)
...
@@ -1754,7 +1747,6 @@ class meta_var: # pylint: disable=invalid-name
...
@@ -1754,7 +1747,6 @@ class meta_var: # pylint: disable=invalid-name
def
_op_wrapper
(
func
):
def
_op_wrapper
(
func
):
@
functools
.
wraps
(
func
)
@
functools
.
wraps
(
func
)
def
wrapped
(
*
args
,
**
kwargs
):
def
wrapped
(
*
args
,
**
kwargs
):
if
"dtype"
in
kwargs
:
if
"dtype"
in
kwargs
:
...
@@ -1874,7 +1866,6 @@ vscale = _op_wrapper(_tir_op.vscale)
...
@@ -1874,7 +1866,6 @@ vscale = _op_wrapper(_tir_op.vscale)
def
_dtype_forward
(
func
):
def
_dtype_forward
(
func
):
@
functools
.
wraps
(
func
)
@
functools
.
wraps
(
func
)
def
wrapped
(
*
args
,
**
kwargs
):
def
wrapped
(
*
args
,
**
kwargs
):
if
"dtype"
in
kwargs
:
if
"dtype"
in
kwargs
:
...
...
tilelang/language/atomic.py
View file @
29051439
# Copyright (c) Tile-AI Corporation.
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
# Licensed under the MIT License.
"""Atomic operations for tilelang."""
"""Atomic operations for tilelang."""
from
__future__
import
annotations
from
__future__
import
annotations
import
tilelang.language
as
T
import
tilelang.language
as
T
...
@@ -18,10 +19,7 @@ _MEMORY_ORDER_ID_MAP = {
...
@@ -18,10 +19,7 @@ _MEMORY_ORDER_ID_MAP = {
}
}
def
atomic_max
(
dst
:
Buffer
,
def
atomic_max
(
dst
:
Buffer
,
value
:
PrimExpr
,
memory_order
:
str
|
None
=
None
,
return_prev
:
bool
=
False
)
->
PrimExpr
:
value
:
PrimExpr
,
memory_order
:
str
|
None
=
None
,
return_prev
:
bool
=
False
)
->
PrimExpr
:
"""
"""
Perform an atomic maximum on the value stored at dst with an optional memory-order.
Perform an atomic maximum on the value stored at dst with an optional memory-order.
...
@@ -64,10 +62,7 @@ def atomic_max(dst: Buffer,
...
@@ -64,10 +62,7 @@ def atomic_max(dst: Buffer,
return
T
.
call_extern
(
return_type
,
func_name
,
dst
,
value
,
_MEMORY_ORDER_ID_MAP
[
memory_order
])
return
T
.
call_extern
(
return_type
,
func_name
,
dst
,
value
,
_MEMORY_ORDER_ID_MAP
[
memory_order
])
def
atomic_min
(
dst
:
Buffer
,
def
atomic_min
(
dst
:
Buffer
,
value
:
PrimExpr
,
memory_order
:
str
|
None
=
None
,
return_prev
:
bool
=
False
)
->
PrimExpr
:
value
:
PrimExpr
,
memory_order
:
str
|
None
=
None
,
return_prev
:
bool
=
False
)
->
PrimExpr
:
"""
"""
Atomically update the value at dst to the minimum of its current value and value.
Atomically update the value at dst to the minimum of its current value and value.
...
@@ -112,11 +107,7 @@ def atomic_min(dst: Buffer,
...
@@ -112,11 +107,7 @@ def atomic_min(dst: Buffer,
return
T
.
call_extern
(
return_type
,
func_name
,
dst
,
value
,
_MEMORY_ORDER_ID_MAP
[
memory_order
])
return
T
.
call_extern
(
return_type
,
func_name
,
dst
,
value
,
_MEMORY_ORDER_ID_MAP
[
memory_order
])
def
atomic_add
(
dst
:
Buffer
,
def
atomic_add
(
dst
:
Buffer
,
value
:
PrimExpr
,
memory_order
:
str
|
None
=
None
,
return_prev
:
bool
=
False
,
use_tma
:
bool
=
False
)
->
PrimExpr
:
value
:
PrimExpr
,
memory_order
:
str
|
None
=
None
,
return_prev
:
bool
=
False
,
use_tma
:
bool
=
False
)
->
PrimExpr
:
"""
"""
Atomically add `value` into `dst`, returning a handle to the operation.
Atomically add `value` into `dst`, returning a handle to the operation.
...
@@ -191,8 +182,7 @@ def atomic_add(dst: Buffer,
...
@@ -191,8 +182,7 @@ def atomic_add(dst: Buffer,
if
memory_order
is
None
:
if
memory_order
is
None
:
return
T
.
call_extern
(
return_type
,
func_name
,
dst
,
value
)
return
T
.
call_extern
(
return_type
,
func_name
,
dst
,
value
)
else
:
else
:
return
T
.
call_extern
(
return_type
,
func_name
,
dst
,
value
,
return
T
.
call_extern
(
return_type
,
func_name
,
dst
,
value
,
_MEMORY_ORDER_ID_MAP
[
memory_order
])
_MEMORY_ORDER_ID_MAP
[
memory_order
])
if
isinstance
(
dst
,
Buffer
)
and
isinstance
(
value
,
Buffer
):
if
isinstance
(
dst
,
Buffer
)
and
isinstance
(
value
,
Buffer
):
ir
.
assert_structural_equal
(
dst
.
shape
,
value
.
shape
)
ir
.
assert_structural_equal
(
dst
.
shape
,
value
.
shape
)
...
@@ -208,14 +198,12 @@ def atomic_add(dst: Buffer,
...
@@ -208,14 +198,12 @@ def atomic_add(dst: Buffer,
# Note: tile-region-based atomic operations don't support return_prev yet
# Note: tile-region-based atomic operations don't support return_prev yet
# This would need to be implemented in the tile runtime
# This would need to be implemented in the tile runtime
if
return_prev
:
if
return_prev
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"return_prev is not supported for tile-region-based atomic operations"
)
"return_prev is not supported for tile-region-based atomic operations"
)
if
memory_order
is
None
:
if
memory_order
is
None
:
return
T
.
call_intrin
(
"handle"
,
op
.
Op
.
get
(
"tl.tileop.atomicadd"
),
value
,
dst
,
use_tma
,
0
)
return
T
.
call_intrin
(
"handle"
,
op
.
Op
.
get
(
"tl.tileop.atomicadd"
),
value
,
dst
,
use_tma
,
0
)
else
:
else
:
return
T
.
call_intrin
(
"handle"
,
op
.
Op
.
get
(
"tl.tileop.atomicadd"
),
value
,
dst
,
use_tma
,
return
T
.
call_intrin
(
"handle"
,
op
.
Op
.
get
(
"tl.tileop.atomicadd"
),
value
,
dst
,
use_tma
,
_MEMORY_ORDER_ID_MAP
[
memory_order
])
_MEMORY_ORDER_ID_MAP
[
memory_order
])
def
atomic_addx2
(
dst
:
Buffer
,
value
:
PrimExpr
,
return_prev
:
bool
=
False
)
->
PrimExpr
:
def
atomic_addx2
(
dst
:
Buffer
,
value
:
PrimExpr
,
return_prev
:
bool
=
False
)
->
PrimExpr
:
...
...
tilelang/language/builtin.py
View file @
29051439
"""The language interface for tl programs."""
"""The language interface for tl programs."""
from
__future__
import
annotations
from
__future__
import
annotations
from
tilelang
import
tvm
as
tvm
from
tilelang
import
tvm
as
tvm
...
@@ -179,38 +180,32 @@ def set_max_nreg(reg_count: int, is_inc: int):
...
@@ -179,38 +180,32 @@ def set_max_nreg(reg_count: int, is_inc: int):
def
inc_max_nreg
(
reg_count
:
int
):
def
inc_max_nreg
(
reg_count
:
int
):
"""Increment the maximum number of registers to use.
"""Increment the maximum number of registers to use."""
"""
return
set_max_nreg
(
reg_count
,
1
)
return
set_max_nreg
(
reg_count
,
1
)
def
dec_max_nreg
(
reg_count
:
int
):
def
dec_max_nreg
(
reg_count
:
int
):
"""Decrement the maximum number of registers to use.
"""Decrement the maximum number of registers to use."""
"""
return
set_max_nreg
(
reg_count
,
0
)
return
set_max_nreg
(
reg_count
,
0
)
def
annotate_producer_reg_dealloc
(
reg_count
:
int
=
24
):
def
annotate_producer_reg_dealloc
(
reg_count
:
int
=
24
):
"""Annotate the producer reg dealloc.
"""Annotate the producer reg dealloc."""
"""
return
dec_max_nreg
(
reg_count
)
return
dec_max_nreg
(
reg_count
)
def
annotate_consumer_reg_alloc
(
reg_count
:
int
=
240
):
def
annotate_consumer_reg_alloc
(
reg_count
:
int
=
240
):
"""Annotate the consumer reg alloc.
"""Annotate the consumer reg alloc."""
"""
return
inc_max_nreg
(
reg_count
)
return
inc_max_nreg
(
reg_count
)
def
no_set_max_nreg
():
def
no_set_max_nreg
():
"""Disable the maximum register limit setting.
"""Disable the maximum register limit setting."""
"""
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.no_set_max_nreg"
))
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.no_set_max_nreg"
))
def
disable_warp_group_reg_alloc
():
def
disable_warp_group_reg_alloc
():
"""Disable the warp group reg alloc.
"""Disable the warp group reg alloc."""
"""
return
no_set_max_nreg
()
return
no_set_max_nreg
()
...
@@ -325,7 +320,9 @@ def warpgroup_wait(num_mma: int):
...
@@ -325,7 +320,9 @@ def warpgroup_wait(num_mma: int):
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.warpgroup_wait"
),
num_mma
)
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.warpgroup_wait"
),
num_mma
)
def
get_lane_idx
(
warp_size
:
int
|
PrimExpr
|
None
=
None
,)
->
PrimExpr
:
def
get_lane_idx
(
warp_size
:
int
|
PrimExpr
|
None
=
None
,
)
->
PrimExpr
:
"""Return the logical lane index of the calling thread within a warp.
"""Return the logical lane index of the calling thread within a warp.
Parameters
Parameters
...
@@ -350,7 +347,9 @@ def get_lane_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
...
@@ -350,7 +347,9 @@ def get_lane_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
return
tir
.
call_intrin
(
"int32"
,
tir
.
op
.
Op
.
get
(
"tl.get_lane_idx"
),
warp_size_expr
)
return
tir
.
call_intrin
(
"int32"
,
tir
.
op
.
Op
.
get
(
"tl.get_lane_idx"
),
warp_size_expr
)
def
get_warp_idx_sync
(
warp_size
:
int
|
PrimExpr
|
None
=
None
,)
->
PrimExpr
:
def
get_warp_idx_sync
(
warp_size
:
int
|
PrimExpr
|
None
=
None
,
)
->
PrimExpr
:
"""Return the canonical warp index, assuming the warp's threads are converged.
"""Return the canonical warp index, assuming the warp's threads are converged.
Parameters
Parameters
...
@@ -374,7 +373,9 @@ def get_warp_idx_sync(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
...
@@ -374,7 +373,9 @@ def get_warp_idx_sync(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
return
tir
.
call_intrin
(
"int32"
,
tir
.
op
.
Op
.
get
(
"tl.get_warp_idx_sync"
),
warp_size_expr
)
return
tir
.
call_intrin
(
"int32"
,
tir
.
op
.
Op
.
get
(
"tl.get_warp_idx_sync"
),
warp_size_expr
)
def
get_warp_idx
(
warp_size
:
int
|
PrimExpr
|
None
=
None
,)
->
PrimExpr
:
def
get_warp_idx
(
warp_size
:
int
|
PrimExpr
|
None
=
None
,
)
->
PrimExpr
:
"""Return the canonical warp index without synchronizing the warp.
"""Return the canonical warp index without synchronizing the warp.
Parameters
Parameters
...
@@ -429,8 +430,7 @@ def get_warp_group_idx(
...
@@ -429,8 +430,7 @@ def get_warp_group_idx(
args
.
append
(
warp_size_expr
)
args
.
append
(
warp_size_expr
)
if
warps_per_group_expr
is
not
None
:
if
warps_per_group_expr
is
not
None
:
if
warp_size_expr
is
None
:
if
warp_size_expr
is
None
:
raise
ValueError
(
"get_warp_group_idx expects `warp_size` when specifying "
raise
ValueError
(
"get_warp_group_idx expects `warp_size` when specifying `warps_per_group`."
)
"`warps_per_group`."
)
args
.
append
(
warps_per_group_expr
)
args
.
append
(
warps_per_group_expr
)
return
tir
.
call_intrin
(
"int32"
,
tir
.
op
.
Op
.
get
(
"tl.get_warp_group_idx"
),
*
args
)
return
tir
.
call_intrin
(
"int32"
,
tir
.
op
.
Op
.
get
(
"tl.get_warp_group_idx"
),
*
args
)
...
@@ -459,10 +459,9 @@ def shuffle_elect(thread_extent: int) -> PrimExpr:
...
@@ -459,10 +459,9 @@ def shuffle_elect(thread_extent: int) -> PrimExpr:
return
tir
.
call_intrin
(
"bool"
,
tir
.
op
.
Op
.
get
(
"tl.tl_shuffle_elect"
),
thread_extent
)
return
tir
.
call_intrin
(
"bool"
,
tir
.
op
.
Op
.
get
(
"tl.tl_shuffle_elect"
),
thread_extent
)
def
warpgroup_fence_operand
(
buffer_or_ptr
:
tir
.
Buffer
|
PrimExpr
,
def
warpgroup_fence_operand
(
offset
:
int
|
PrimExpr
=
0
,
buffer_or_ptr
:
tir
.
Buffer
|
PrimExpr
,
offset
:
int
|
PrimExpr
=
0
,
num_regs
:
int
|
PrimExpr
|
None
=
None
,
dtype
:
str
|
None
=
None
num_regs
:
int
|
PrimExpr
|
None
=
None
,
):
dtype
:
str
|
None
=
None
):
"""Insert a warpgroup fence for the destination accumulator registers.
"""Insert a warpgroup fence for the destination accumulator registers.
This prevents NVCC from sinking uses of accumulator fragments past the corresponding
This prevents NVCC from sinking uses of accumulator fragments past the corresponding
...
@@ -517,7 +516,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
...
@@ -517,7 +516,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
data_ptr
,
data_ptr
,
convert
(
offset
),
convert
(
offset
),
convert
(
num_regs
),
convert
(
num_regs
),
))
)
)
if
isinstance
(
buffer_or_ptr
,
tir
.
Buffer
):
if
isinstance
(
buffer_or_ptr
,
tir
.
Buffer
):
data_ptr
=
buffer_or_ptr
.
data
data_ptr
=
buffer_or_ptr
.
data
...
@@ -531,8 +531,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
...
@@ -531,8 +531,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
if
isinstance
(
dim
,
tir
.
IntImm
):
if
isinstance
(
dim
,
tir
.
IntImm
):
total_elems
*=
int
(
dim
)
total_elems
*=
int
(
dim
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"warpgroup_fence_operand requires num_regs when buffer shape is symbolic."
)
"warpgroup_fence_operand requires num_regs when buffer shape is symbolic."
)
bits_per_elem
=
DataType
(
dtype
).
bits
bits_per_elem
=
DataType
(
dtype
).
bits
num_regs
=
(
total_elems
*
bits_per_elem
+
31
)
//
32
num_regs
=
(
total_elems
*
bits_per_elem
+
31
)
//
32
elif
isinstance
(
buffer_or_ptr
,
BufferRegion
):
elif
isinstance
(
buffer_or_ptr
,
BufferRegion
):
...
@@ -569,9 +568,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
...
@@ -569,9 +568,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
bits_per_elem
=
DataType
(
dtype
).
bits
bits_per_elem
=
DataType
(
dtype
).
bits
num_regs
=
(
total_elems
*
bits_per_elem
+
31
)
//
32
num_regs
=
(
total_elems
*
bits_per_elem
+
31
)
//
32
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"warpgroup_fence_operand requires num_regs when BufferRegion extent is symbolic."
)
"warpgroup_fence_operand requires num_regs when BufferRegion extent is symbolic."
)
return
evaluate
(
return
evaluate
(
tir
.
call_intrin
(
tir
.
call_intrin
(
"handle"
,
"handle"
,
...
@@ -580,7 +577,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
...
@@ -580,7 +577,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
data_ptr
,
data_ptr
,
convert
(
offset
),
convert
(
offset
),
convert
(
num_regs
),
convert
(
num_regs
),
))
)
)
else
:
else
:
data_ptr
=
buffer_or_ptr
data_ptr
=
buffer_or_ptr
# Try to infer dtype from common pointer expressions when not provided
# Try to infer dtype from common pointer expressions when not provided
...
@@ -603,9 +601,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
...
@@ -603,9 +601,7 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
except
Exception
:
except
Exception
:
inferred
=
None
inferred
=
None
if
inferred
is
None
:
if
inferred
is
None
:
raise
ValueError
(
raise
ValueError
(
"dtype must be provided when passing a pointer expression and cannot be inferred."
)
"dtype must be provided when passing a pointer expression and cannot be inferred."
)
dtype
=
inferred
dtype
=
inferred
if
num_regs
is
None
:
if
num_regs
is
None
:
raise
ValueError
(
"num_regs must be provided when passing a pointer expression."
)
raise
ValueError
(
"num_regs must be provided when passing a pointer expression."
)
...
@@ -618,7 +614,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
...
@@ -618,7 +614,8 @@ def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
data_ptr
,
data_ptr
,
convert
(
offset
),
convert
(
offset
),
convert
(
num_regs
),
convert
(
num_regs
),
))
)
)
def
wait_wgmma
(
id
:
int
):
def
wait_wgmma
(
id
:
int
):
...
@@ -673,7 +670,7 @@ def shfl_xor(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call
...
@@ -673,7 +670,7 @@ def shfl_xor(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call
if
_IS_HIP_AVAILABLE
:
if
_IS_HIP_AVAILABLE
:
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_xor"
,
value
,
offset
)
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_xor"
,
value
,
offset
)
else
:
else
:
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_xor_sync"
,
0x
ffffffff
,
value
,
offset
)
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_xor_sync"
,
0x
FFFFFFFF
,
value
,
offset
)
def
shfl_down
(
value
:
int
|
PrimExpr
|
tir
.
Call
,
offset
:
int
|
PrimExpr
|
tir
.
Call
):
def
shfl_down
(
value
:
int
|
PrimExpr
|
tir
.
Call
,
offset
:
int
|
PrimExpr
|
tir
.
Call
):
...
@@ -686,7 +683,7 @@ def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Cal
...
@@ -686,7 +683,7 @@ def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Cal
if
_IS_HIP_AVAILABLE
:
if
_IS_HIP_AVAILABLE
:
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_down"
,
value
,
offset
)
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_down"
,
value
,
offset
)
else
:
else
:
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_down_sync"
,
0x
ffffffff
,
value
,
offset
)
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_down_sync"
,
0x
FFFFFFFF
,
value
,
offset
)
def
shfl_up
(
value
:
int
|
PrimExpr
|
tir
.
Call
,
offset
:
int
|
PrimExpr
|
tir
.
Call
):
def
shfl_up
(
value
:
int
|
PrimExpr
|
tir
.
Call
,
offset
:
int
|
PrimExpr
|
tir
.
Call
):
...
@@ -699,12 +696,11 @@ def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call)
...
@@ -699,12 +696,11 @@ def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call)
if
_IS_HIP_AVAILABLE
:
if
_IS_HIP_AVAILABLE
:
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_up"
,
value
,
offset
)
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_up"
,
value
,
offset
)
else
:
else
:
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_up_sync"
,
0x
ffffffff
,
value
,
offset
)
return
tir
.
call_extern
(
value
.
dtype
,
"__shfl_up_sync"
,
0x
FFFFFFFF
,
value
,
offset
)
def
sync_threads
(
barrier_id
:
int
=
None
,
arrive_count
:
int
=
None
):
def
sync_threads
(
barrier_id
:
int
=
None
,
arrive_count
:
int
=
None
):
"""Synchronize all threads in a block.
"""Synchronize all threads in a block."""
"""
args
=
[]
args
=
[]
if
barrier_id
is
not
None
:
if
barrier_id
is
not
None
:
args
.
append
(
barrier_id
)
args
.
append
(
barrier_id
)
...
@@ -714,8 +710,7 @@ def sync_threads(barrier_id: int = None, arrive_count: int = None):
...
@@ -714,8 +710,7 @@ def sync_threads(barrier_id: int = None, arrive_count: int = None):
def
sync_global
():
def
sync_global
():
"""Synchronize all threads in the entire grid.
"""Synchronize all threads in the entire grid."""
"""
tx
,
ty
,
tz
=
get_thread_bindings
()
tx
,
ty
,
tz
=
get_thread_bindings
()
ex
,
ey
,
ez
=
get_block_extents
()
ex
,
ey
,
ez
=
get_block_extents
()
print
(
tx
,
ty
,
tz
,
ex
,
ey
,
ez
)
print
(
tx
,
ty
,
tz
,
ex
,
ey
,
ez
)
...
@@ -724,8 +719,7 @@ def sync_global():
...
@@ -724,8 +719,7 @@ def sync_global():
def
sync_grid
():
def
sync_grid
():
"""Synchronize all threads in a grid.
"""Synchronize all threads in a grid."""
"""
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.sync_grid"
))
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.sync_grid"
))
...
@@ -741,12 +735,10 @@ def initialize_wgmma_descriptor(
...
@@ -741,12 +735,10 @@ def initialize_wgmma_descriptor(
if
not
isinstance
(
descriptor
,
(
BufferLoad
,
tir
.
Buffer
)):
if
not
isinstance
(
descriptor
,
(
BufferLoad
,
tir
.
Buffer
)):
raise
TypeError
(
"Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad."
)
raise
TypeError
(
"Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad."
)
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
(
len
(
descriptor
.
shape
)
!=
1
or
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
(
len
(
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
):
descriptor
.
shape
[
0
]
!=
1
):
raise
ValueError
(
"Descriptor must be a 1D buffer of size 1."
)
raise
ValueError
(
"Descriptor must be a 1D buffer of size 1."
)
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
,
[
0
])
descriptor
,
[
0
])
return
evaluate
(
return
evaluate
(
tir
.
call_intrin
(
tir
.
call_intrin
(
...
@@ -757,7 +749,8 @@ def initialize_wgmma_descriptor(
...
@@ -757,7 +749,8 @@ def initialize_wgmma_descriptor(
layout_type_
,
layout_type_
,
int
(
leading_byte_offset
),
int
(
leading_byte_offset
),
int
(
stride_byte_offset
),
int
(
stride_byte_offset
),
))
)
)
def
initialize_tcgen05_descriptor
(
def
initialize_tcgen05_descriptor
(
...
@@ -774,12 +767,10 @@ def initialize_tcgen05_descriptor(
...
@@ -774,12 +767,10 @@ def initialize_tcgen05_descriptor(
if
not
isinstance
(
descriptor
,
(
BufferLoad
,
tir
.
Buffer
)):
if
not
isinstance
(
descriptor
,
(
BufferLoad
,
tir
.
Buffer
)):
raise
TypeError
(
"Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad."
)
raise
TypeError
(
"Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad."
)
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
(
len
(
descriptor
.
shape
)
!=
1
or
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
(
len
(
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
):
descriptor
.
shape
[
0
]
!=
1
):
raise
ValueError
(
"Descriptor must be a 1D buffer of size 1."
)
raise
ValueError
(
"Descriptor must be a 1D buffer of size 1."
)
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
,
[
0
])
descriptor
,
[
0
])
return
evaluate
(
return
evaluate
(
tir
.
call_intrin
(
tir
.
call_intrin
(
...
@@ -792,7 +783,8 @@ def initialize_tcgen05_descriptor(
...
@@ -792,7 +783,8 @@ def initialize_tcgen05_descriptor(
int
(
base_offset
),
int
(
base_offset
),
tir
.
IntImm
(
"int32"
,
1
if
leading_is_absolute
else
0
),
tir
.
IntImm
(
"int32"
,
1
if
leading_is_absolute
else
0
),
int
(
swizzle_mode
),
int
(
swizzle_mode
),
))
)
)
def
increase_descriptor_offset
(
descriptor
:
PrimExpr
,
offset
:
PrimExpr
)
->
PrimExpr
:
def
increase_descriptor_offset
(
descriptor
:
PrimExpr
,
offset
:
PrimExpr
)
->
PrimExpr
:
...
@@ -809,27 +801,21 @@ def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimEx
...
@@ -809,27 +801,21 @@ def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimEx
if
not
isinstance
(
descriptor
,
(
BufferLoad
,
tir
.
Buffer
)):
if
not
isinstance
(
descriptor
,
(
BufferLoad
,
tir
.
Buffer
)):
raise
TypeError
(
"Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad."
)
raise
TypeError
(
"Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad."
)
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
len
(
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
len
(
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
:
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
:
raise
ValueError
(
"Descriptor must be a 1D buffer of size 1."
)
raise
ValueError
(
"Descriptor must be a 1D buffer of size 1."
)
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
,
[
0
])
descriptor
,
[
0
])
return
evaluate
(
return
evaluate
(
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.increase_descriptor_offset"
),
descriptor
,
offset
))
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.increase_descriptor_offset"
),
descriptor
,
offset
))
def
loop_break
():
def
loop_break
():
"""Break out of the innermost loop.
"""Break out of the innermost loop."""
"""
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.loop_break"
))
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.loop_break"
))
def
cp_async_barrier_noinc
(
barrier_id
:
int
|
PrimExpr
|
tir
.
Call
):
def
cp_async_barrier_noinc
(
barrier_id
:
int
|
PrimExpr
|
tir
.
Call
):
"""Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
"""Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc."""
"""
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.ptx_cp_async_barrier_noinc"
),
barrier_id
)
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.ptx_cp_async_barrier_noinc"
),
barrier_id
)
...
...
tilelang/language/copy.py
View file @
29051439
"""The language interface for tl programs."""
"""The language interface for tl programs."""
from
__future__
import
annotations
from
__future__
import
annotations
from
typing
import
Literal
from
typing
import
Literal
from
tilelang
import
language
as
T
from
tilelang
import
language
as
T
...
@@ -10,11 +11,13 @@ from tilelang.utils.language import (
...
@@ -10,11 +11,13 @@ from tilelang.utils.language import (
from
tvm
import
ir
,
tir
from
tvm
import
ir
,
tir
def
copy
(
src
:
tir
.
Buffer
|
tir
.
BufferLoad
|
tir
.
BufferRegion
,
def
copy
(
dst
:
tir
.
Buffer
|
tir
.
BufferLoad
,
src
:
tir
.
Buffer
|
tir
.
BufferLoad
|
tir
.
BufferRegion
,
coalesced_width
:
int
|
None
=
None
,
dst
:
tir
.
Buffer
|
tir
.
BufferLoad
,
disable_tma
:
bool
=
False
,
coalesced_width
:
int
|
None
=
None
,
eviction_policy
:
Literal
[
"evict_normal"
,
"evict_first"
,
"evict_last"
]
|
None
=
None
):
disable_tma
:
bool
=
False
,
eviction_policy
:
Literal
[
"evict_normal"
,
"evict_first"
,
"evict_last"
]
|
None
=
None
,
):
"""Copy data between memory regions.
"""Copy data between memory regions.
Args:
Args:
...
@@ -65,8 +68,7 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
...
@@ -65,8 +68,7 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
src_extent
=
get_extent
(
src
)
src_extent
=
get_extent
(
src
)
dst_extent
=
get_extent
(
dst
)
dst_extent
=
get_extent
(
dst
)
# Combine the nested if statements into a single if statement as suggested by SIM102
# Combine the nested if statements into a single if statement as suggested by SIM102
if
(
src_extent
is
None
and
dst_extent
is
None
and
isinstance
(
src
,
tir
.
BufferLoad
)
and
if
src_extent
is
None
and
dst_extent
is
None
and
isinstance
(
src
,
tir
.
BufferLoad
)
and
isinstance
(
dst
,
tir
.
BufferLoad
):
isinstance
(
dst
,
tir
.
BufferLoad
)):
# check if the case is like this:
# check if the case is like this:
# copy(buffer_a[i], buffer_b[i]) where both are BufferLoad nodes
# copy(buffer_a[i], buffer_b[i]) where both are BufferLoad nodes
# In this case, lower it to a simple BufferStore: buffer_b[i] = buffer_a[i]
# In this case, lower it to a simple BufferStore: buffer_b[i] = buffer_a[i]
...
@@ -90,19 +92,20 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
...
@@ -90,19 +92,20 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
eviction_policy
=
0
eviction_policy
=
0
else
:
else
:
eviction_policy
=
{
"evict_normal"
:
0
,
"evict_first"
:
1
,
"evict_last"
:
2
}[
eviction_policy
]
eviction_policy
=
{
"evict_normal"
:
0
,
"evict_first"
:
1
,
"evict_last"
:
2
}[
eviction_policy
]
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tileop.copy"
),
src
,
dst
,
coalesced_width
,
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tileop.copy"
),
src
,
dst
,
coalesced_width
,
disable_tma
,
eviction_policy
)
disable_tma
,
eviction_policy
)
def
c2d_im2col
(
def
c2d_im2col
(
img
:
tir
.
Buffer
,
img
:
tir
.
Buffer
,
col
:
tir
.
Buffer
,
col
:
tir
.
Buffer
,
nhw_step
:
tir
.
PrimExpr
,
nhw_step
:
tir
.
PrimExpr
,
c_step
:
tir
.
PrimExpr
,
c_step
:
tir
.
PrimExpr
,
kernel
:
int
,
kernel
:
int
,
stride
:
int
,
stride
:
int
,
dilation
:
int
,
dilation
:
int
,
pad
:
int
,
pad
:
int
,
eviction_policy
:
Literal
[
"evict_normal"
,
"evict_first"
,
"evict_last"
]
|
None
=
None
):
eviction_policy
:
Literal
[
"evict_normal"
,
"evict_first"
,
"evict_last"
]
|
None
=
None
,
):
"""Perform im2col transformation for 2D convolution.
"""Perform im2col transformation for 2D convolution.
Args:
Args:
...
@@ -124,5 +127,16 @@ def c2d_im2col(img: tir.Buffer,
...
@@ -124,5 +127,16 @@ def c2d_im2col(img: tir.Buffer,
eviction_policy
=
{
"evict_normal"
:
0
,
"evict_first"
:
1
,
"evict_last"
:
2
}[
eviction_policy
]
eviction_policy
=
{
"evict_normal"
:
0
,
"evict_first"
:
1
,
"evict_last"
:
2
}[
eviction_policy
]
img_region
=
to_buffer_region
(
img
,
access_type
=
"r"
)
img_region
=
to_buffer_region
(
img
,
access_type
=
"r"
)
col_region
=
to_buffer_region
(
col
,
access_type
=
"w"
)
col_region
=
to_buffer_region
(
col
,
access_type
=
"w"
)
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tileop.c2d_im2col"
),
img_region
,
col_region
,
return
tir
.
call_intrin
(
nhw_step
,
c_step
,
kernel
,
stride
,
dilation
,
pad
,
eviction_policy
)
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tileop.c2d_im2col"
),
img_region
,
col_region
,
nhw_step
,
c_step
,
kernel
,
stride
,
dilation
,
pad
,
eviction_policy
,
)
tilelang/language/customize.py
View file @
29051439
"""The language interface for tl programs."""
"""The language interface for tl programs."""
from
__future__
import
annotations
from
__future__
import
annotations
import
tilelang.language
as
T
import
tilelang.language
as
T
from
tvm.tir
import
PrimExpr
,
Buffer
,
op
from
tvm.tir
import
PrimExpr
,
Buffer
,
op
from
tilelang.utils.language
import
(
bits_product
,
prim_expr_equal
)
from
tilelang.utils.language
import
bits_product
,
prim_expr_equal
from
.atomic
import
atomic_max
,
atomic_min
,
atomic_add
,
atomic_addx2
,
atomic_addx4
,
atomic_load
,
atomic_store
# noqa: F401
from
.atomic
import
atomic_max
,
atomic_min
,
atomic_add
,
atomic_addx2
,
atomic_addx4
,
atomic_load
,
atomic_store
# noqa: F401
...
@@ -46,9 +47,9 @@ def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer:
...
@@ -46,9 +47,9 @@ def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer:
Returns:
Returns:
Buffer: A new buffer view with the specified shape
Buffer: A new buffer view with the specified shape
"""
"""
assert
prim_expr_equal
(
assert
prim_expr_equal
(
bits_product
(
shape
,
src
.
dtype
),
bits_product
(
src
.
shape
,
src
.
dtype
)),
(
bits_product
(
shape
,
src
.
dty
pe
)
,
bits_product
(
src
.
shape
,
src
.
dtype
)
f
"T.reshape/view shape check failed. src
{
src
}
src.
shape
:
{
src
.
sha
pe
}
,
src.dtype:
{
src
.
dtype
}
, target shape:
{
shape
}
, target dtype:
{
src
.
dtype
}
"
)
,
f
"T.reshape/view shape check failed. src
{
src
}
src.shape:
{
src
.
shape
}
, src.dtype:
{
src
.
dtype
}
, target shape:
{
shape
}
, target dtype:
{
src
.
dtype
}
"
)
return
T
.
Tensor
(
shape
,
src
.
dtype
,
src
.
data
)
return
T
.
Tensor
(
shape
,
src
.
dtype
,
src
.
data
)
...
@@ -61,8 +62,7 @@ def view(src: Buffer, shape: list[PrimExpr] | None = None, dtype: str | None = N
...
@@ -61,8 +62,7 @@ def view(src: Buffer, shape: list[PrimExpr] | None = None, dtype: str | None = N
shape
=
src
.
shape
shape
=
src
.
shape
if
dtype
is
None
:
if
dtype
is
None
:
dtype
=
src
.
dtype
dtype
=
src
.
dtype
assert
prim_expr_equal
(
bits_product
(
shape
,
dtype
),
assert
prim_expr_equal
(
bits_product
(
shape
,
dtype
),
bits_product
(
src
.
shape
,
src
.
dtype
)),
"T.reshape/view shape check failed."
bits_product
(
src
.
shape
,
src
.
dtype
)),
"T.reshape/view shape check failed."
return
T
.
Tensor
(
shape
,
dtype
,
src
.
data
)
return
T
.
Tensor
(
shape
,
dtype
,
src
.
data
)
...
...
tilelang/language/experimental/gemm_sp.py
View file @
29051439
"""The language interface for tl programs."""
"""The language interface for tl programs."""
from
__future__
import
annotations
from
__future__
import
annotations
from
tilelang.primitives.gemm.base
import
GemmWarpPolicy
from
tilelang.primitives.gemm.base
import
GemmWarpPolicy
import
tilelang.language
as
T
import
tilelang.language
as
T
...
@@ -11,7 +12,8 @@ from tilelang.utils.language import (
...
@@ -11,7 +12,8 @@ from tilelang.utils.language import (
prim_expr_equal
,
prim_expr_equal
,
)
)
from
tilelang.language.utils
import
(
from
tilelang.language.utils
import
(
buffer_region_to_tile_region
,)
buffer_region_to_tile_region
,
)
def
gemm_sp
(
def
gemm_sp
(
...
@@ -169,18 +171,19 @@ def gemm_sp_v2(
...
@@ -169,18 +171,19 @@ def gemm_sp_v2(
assert
len
(
B_shape
)
>=
2
,
"current only support B as a 2D or higher-order tensor"
assert
len
(
B_shape
)
>=
2
,
"current only support B as a 2D or higher-order tensor"
if
len
(
A_shape
)
>
2
:
if
len
(
A_shape
)
>
2
:
for
i
in
range
(
len
(
A_shape
)
-
2
):
for
i
in
range
(
len
(
A_shape
)
-
2
):
assert
A_shape
[
i
]
==
1
,
\
assert
A_shape
[
i
]
==
1
,
(
"current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
"current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
if
len
(
B_shape
)
>
2
:
if
len
(
B_shape
)
>
2
:
for
i
in
range
(
len
(
B_shape
)
-
2
):
for
i
in
range
(
len
(
B_shape
)
-
2
):
assert
B_shape
[
i
]
==
1
,
\
assert
B_shape
[
i
]
==
1
,
(
"current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
"current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
M
,
N
=
C_shape
M
,
N
=
C_shape
K
=
2
*
(
A_shape
[
-
2
]
if
transpose_A
else
A_shape
[
-
1
])
K
=
2
*
(
A_shape
[
-
2
]
if
transpose_A
else
A_shape
[
-
1
])
K_B
=
B_shape
[
-
1
]
if
transpose_B
else
B_shape
[
-
2
]
K_B
=
B_shape
[
-
1
]
if
transpose_B
else
B_shape
[
-
2
]
assert
prim_expr_equal
(
assert
prim_expr_equal
(
K
,
K_B
),
f
"T.gemm_sp K shape check failed: K_A (wo sparse) =
{
K
}
, K_B =
{
K_B
}
"
K
,
K_B
),
f
"T.gemm_sp K shape check failed: K_A (wo sparse) =
{
K
}
, K_B =
{
K_B
}
"
stride_a
=
A_stride
[
-
2
]
stride_a
=
A_stride
[
-
2
]
stride_b
=
B_stride
[
-
2
]
stride_b
=
B_stride
[
-
2
]
...
...
tilelang/language/fill.py
View file @
29051439
"""The language interface for tl programs."""
"""The language interface for tl programs."""
from
__future__
import
annotations
from
__future__
import
annotations
from
tvm
import
tir
from
tvm
import
tir
from
tilelang.language
import
has_let_value
,
get_let_value
from
tilelang.language
import
has_let_value
,
get_let_value
...
@@ -32,8 +33,7 @@ def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.Prim
...
@@ -32,8 +33,7 @@ def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.Prim
extents
=
[
tir
.
IntImm
(
"int32"
,
1
)
for
_
in
buffer
.
indices
]
extents
=
[
tir
.
IntImm
(
"int32"
,
1
)
for
_
in
buffer
.
indices
]
else
:
else
:
extents
=
[]
extents
=
[]
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tileop.fill"
),
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.tileop.fill"
),
to_buffer_region
(
buffer
,
access_type
=
"w"
,
extents
=
extents
),
value
)
to_buffer_region
(
buffer
,
access_type
=
"w"
,
extents
=
extents
),
value
)
def
clear
(
buffer
:
tir
.
Buffer
|
tir
.
Var
):
def
clear
(
buffer
:
tir
.
Buffer
|
tir
.
Var
):
...
@@ -55,8 +55,7 @@ def clear(buffer: tir.Buffer | tir.Var):
...
@@ -55,8 +55,7 @@ def clear(buffer: tir.Buffer | tir.Var):
elif
isinstance
(
buffer_region
,
tir
.
BufferLoad
):
elif
isinstance
(
buffer_region
,
tir
.
BufferLoad
):
region
=
get_buffer_region_from_load
(
buffer_region
)
region
=
get_buffer_region_from_load
(
buffer_region
)
if
region
is
None
:
if
region
is
None
:
raise
ValueError
(
raise
ValueError
(
f
"Invalid buffer region:
{
buffer_region
}
, type:
{
type
(
buffer_region
)
}
"
)
f
"Invalid buffer region:
{
buffer_region
}
, type:
{
type
(
buffer_region
)
}
"
)
return
fill
(
region
,
0
)
return
fill
(
region
,
0
)
else
:
else
:
raise
ValueError
(
f
"Invalid buffer region:
{
buffer_region
}
, type:
{
type
(
buffer_region
)
}
"
)
raise
ValueError
(
f
"Invalid buffer region:
{
buffer_region
}
, type:
{
type
(
buffer_region
)
}
"
)
...
...
tilelang/language/frame.py
View file @
29051439
"""Override the LetFrame to print a message when entering the frame."""
"""Override the LetFrame to print a message when entering the frame."""
from
__future__
import
annotations
from
__future__
import
annotations
from
tvm.ffi
import
register_object
as
_register_object
from
tvm.ffi
import
register_object
as
_register_object
from
tvm.tir
import
Var
,
PrimExpr
,
BufferLoad
,
BufferRegion
from
tvm.tir
import
Var
,
PrimExpr
,
BufferLoad
,
BufferRegion
...
@@ -29,7 +30,7 @@ class FrameStack:
...
@@ -29,7 +30,7 @@ class FrameStack:
item: The frame object to push onto the stack
item: The frame object to push onto the stack
"""
"""
self
.
_stack
.
append
(
item
)
self
.
_stack
.
append
(
item
)
if
hasattr
(
item
,
'
var
'
)
and
hasattr
(
item
,
'
value
'
):
if
hasattr
(
item
,
"
var
"
)
and
hasattr
(
item
,
"
value
"
):
self
.
_var_value_map
[
item
.
var
]
=
item
.
value
self
.
_var_value_map
[
item
.
var
]
=
item
.
value
def
pop
(
self
):
def
pop
(
self
):
...
@@ -43,7 +44,7 @@ class FrameStack:
...
@@ -43,7 +44,7 @@ class FrameStack:
"""
"""
if
self
.
_stack
:
if
self
.
_stack
:
item
=
self
.
_stack
.
pop
()
item
=
self
.
_stack
.
pop
()
if
hasattr
(
item
,
'
var
'
):
if
hasattr
(
item
,
"
var
"
):
self
.
_var_value_map
.
pop
(
item
.
var
,
None
)
self
.
_var_value_map
.
pop
(
item
.
var
,
None
)
return
item
return
item
raise
IndexError
(
f
"
{
self
.
__class__
.
__name__
}
is empty"
)
raise
IndexError
(
f
"
{
self
.
__class__
.
__name__
}
is empty"
)
...
@@ -129,8 +130,7 @@ class LetFrame(TIRFrame):
...
@@ -129,8 +130,7 @@ class LetFrame(TIRFrame):
is_block_load
=
True
is_block_load
=
True
break
break
if
is_block_load
:
if
is_block_load
:
self
.
value
=
BufferRegion
(
self
.
value
.
buffer
,
self
.
value
=
BufferRegion
(
self
.
value
.
buffer
,
[
Range
(
x
.
base
,
x
.
lanes
)
for
x
in
indices
])
[
Range
(
x
.
base
,
x
.
lanes
)
for
x
in
indices
])
_get_let_stack
().
push
(
self
)
_get_let_stack
().
push
(
self
)
return
self
.
var
return
self
.
var
...
...
tilelang/language/gemm.py
View file @
29051439
"""The language interface for tl programs."""
"""The language interface for tl programs."""
from
__future__
import
annotations
from
__future__
import
annotations
from
tilelang.primitives.gemm.base
import
GemmWarpPolicy
from
tilelang.primitives.gemm.base
import
GemmWarpPolicy
import
tilelang.language
as
T
import
tilelang.language
as
T
...
@@ -11,7 +12,8 @@ from tilelang.utils.language import (
...
@@ -11,7 +12,8 @@ from tilelang.utils.language import (
prim_expr_equal
,
prim_expr_equal
,
)
)
from
tilelang.language.utils
import
(
from
tilelang.language.utils
import
(
buffer_region_to_tile_region
,)
buffer_region_to_tile_region
,
)
from
tilelang.env
import
env
as
_env
from
tilelang.env
import
env
as
_env
...
@@ -68,12 +70,14 @@ def _gemm_impl(
...
@@ -68,12 +70,14 @@ def _gemm_impl(
assert
len
(
B_shape
)
>=
2
,
"current only support B as a 2D or higher-order tensor"
assert
len
(
B_shape
)
>=
2
,
"current only support B as a 2D or higher-order tensor"
if
len
(
A_shape
)
>
2
:
if
len
(
A_shape
)
>
2
:
for
i
in
range
(
len
(
A_shape
)
-
2
):
for
i
in
range
(
len
(
A_shape
)
-
2
):
assert
A_shape
[
i
]
==
1
,
\
assert
A_shape
[
i
]
==
1
,
(
"current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
"current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
if
len
(
B_shape
)
>
2
:
if
len
(
B_shape
)
>
2
:
for
i
in
range
(
len
(
B_shape
)
-
2
):
for
i
in
range
(
len
(
B_shape
)
-
2
):
assert
B_shape
[
i
]
==
1
,
\
assert
B_shape
[
i
]
==
1
,
(
"current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
"current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
)
M
,
N
=
C_shape
M
,
N
=
C_shape
K
=
A_shape
[
-
2
]
if
transpose_A
else
A_shape
[
-
1
]
K
=
A_shape
[
-
2
]
if
transpose_A
else
A_shape
[
-
1
]
...
@@ -96,9 +100,29 @@ def _gemm_impl(
...
@@ -96,9 +100,29 @@ def _gemm_impl(
A_arg
=
buffer_region_to_tile_region
(
A_region
,
"r"
,
[
r
for
r
in
A_shape
])
A_arg
=
buffer_region_to_tile_region
(
A_region
,
"r"
,
[
r
for
r
in
A_shape
])
B_arg
=
buffer_region_to_tile_region
(
B_region
,
"r"
,
[
r
for
r
in
B_shape
])
B_arg
=
buffer_region_to_tile_region
(
B_region
,
"r"
,
[
r
for
r
in
B_shape
])
C_arg
=
buffer_region_to_tile_region
(
C_region
,
"rw"
,
[
r
for
r
in
C_shape
])
C_arg
=
buffer_region_to_tile_region
(
C_region
,
"rw"
,
[
r
for
r
in
C_shape
])
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
op_key
),
A_arg
,
B_arg
,
C_arg
,
transpose_A
,
return
tir
.
call_intrin
(
transpose_B
,
M
,
N
,
K
,
policy
,
clear_accum
,
stride_a
,
stride_b
,
offset_a
,
"handle"
,
offset_b
,
k_pack
,
wg_wait
,
mbar
,
C_coords
[
0
],
C_coords
[
1
])
tir
.
op
.
Op
.
get
(
op_key
),
A_arg
,
B_arg
,
C_arg
,
transpose_A
,
transpose_B
,
M
,
N
,
K
,
policy
,
clear_accum
,
stride_a
,
stride_b
,
offset_a
,
offset_b
,
k_pack
,
wg_wait
,
mbar
,
C_coords
[
0
],
C_coords
[
1
],
)
# Public wrappers
# Public wrappers
...
...
tilelang/language/kernel.py
View file @
29051439
"""The language interface for tl programs."""
"""The language interface for tl programs."""
from
__future__
import
annotations
from
__future__
import
annotations
from
collections
import
deque
from
collections
import
deque
from
tvm
import
tir
from
tvm
import
tir
...
@@ -107,8 +108,7 @@ class KernelLaunchFrame(TIRFrame):
...
@@ -107,8 +108,7 @@ class KernelLaunchFrame(TIRFrame):
_get_current_stack
().
push
(
self
)
_get_current_stack
().
push
(
self
)
last_block_frame
=
self
.
frames
[
-
1
]
last_block_frame
=
self
.
frames
[
-
1
]
assert
isinstance
(
last_block_frame
,
assert
isinstance
(
last_block_frame
,
BlockFrame
),
f
"Last frame must be a block frame, got
{
last_block_frame
}
"
BlockFrame
),
f
"Last frame must be a block frame, got
{
last_block_frame
}
"
maybe_cpu
=
last_block_frame
.
annotations
.
get
(
"tilelang.is_cpu_kernel_frame"
,
False
)
maybe_cpu
=
last_block_frame
.
annotations
.
get
(
"tilelang.is_cpu_kernel_frame"
,
False
)
...
@@ -303,56 +303,48 @@ def Kernel(
...
@@ -303,56 +303,48 @@ def Kernel(
def
get_thread_binding
(
dim
:
int
=
0
)
->
Var
:
def
get_thread_binding
(
dim
:
int
=
0
)
->
Var
:
"""Returns the thread binding for the given dimension.
"""Returns the thread binding for the given dimension."""
"""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_thread_binding
(
dim
)
return
KernelLaunchFrame
.
Current
().
get_thread_binding
(
dim
)
def
get_thread_bindings
()
->
list
[
Var
]:
def
get_thread_bindings
()
->
list
[
Var
]:
"""Returns all three thread bindings.
"""Returns all three thread bindings."""
"""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_thread_bindings
()
return
KernelLaunchFrame
.
Current
().
get_thread_bindings
()
def
get_block_binding
(
dim
:
int
=
0
)
->
Var
:
def
get_block_binding
(
dim
:
int
=
0
)
->
Var
:
"""Returns the block binding for the given dimension.
"""Returns the block binding for the given dimension."""
"""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_block_binding
(
dim
)
return
KernelLaunchFrame
.
Current
().
get_block_binding
(
dim
)
def
get_block_bindings
()
->
list
[
Var
]:
def
get_block_bindings
()
->
list
[
Var
]:
"""Returns all three block bindings.
"""Returns all three block bindings."""
"""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_block_bindings
()
return
KernelLaunchFrame
.
Current
().
get_block_bindings
()
def
get_thread_extent
(
dim
:
int
=
0
)
->
int
:
def
get_thread_extent
(
dim
:
int
=
0
)
->
int
:
"""Returns the thread extent for the given dimension.
"""Returns the thread extent for the given dimension."""
"""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_thread_extent
(
dim
)
return
KernelLaunchFrame
.
Current
().
get_thread_extent
(
dim
)
def
get_thread_extents
()
->
list
[
int
]:
def
get_thread_extents
()
->
list
[
int
]:
"""Returns all three thread extents.
"""Returns all three thread extents."""
"""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_thread_extents
()
return
KernelLaunchFrame
.
Current
().
get_thread_extents
()
def
get_block_extent
(
dim
:
int
=
0
)
->
int
:
def
get_block_extent
(
dim
:
int
=
0
)
->
int
:
"""Returns the block extent for the given dimension.
"""Returns the block extent for the given dimension."""
"""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_block_extent
(
dim
)
return
KernelLaunchFrame
.
Current
().
get_block_extent
(
dim
)
def
get_block_extents
()
->
list
[
int
]:
def
get_block_extents
()
->
list
[
int
]:
"""Returns all three block extents.
"""Returns all three block extents."""
"""
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
assert
KernelLaunchFrame
.
Current
()
is
not
None
,
"KernelLaunchFrame is not initialized"
return
KernelLaunchFrame
.
Current
().
get_block_extents
()
return
KernelLaunchFrame
.
Current
().
get_block_extents
()
tilelang/language/logical.py
View file @
29051439
"""The language interface for tl programs."""
"""The language interface for tl programs."""
from
__future__
import
annotations
from
__future__
import
annotations
from
tilelang
import
language
as
T
from
tilelang
import
language
as
T
...
@@ -36,8 +37,7 @@ def any_of(buffer: T.Tensor | BufferRegion):
...
@@ -36,8 +37,7 @@ def any_of(buffer: T.Tensor | BufferRegion):
)
)
new_region
.
append
(
r
.
min
)
new_region
.
append
(
r
.
min
)
buffer_load
=
BufferLoad
(
buffer
,
new_region
)
buffer_load
=
BufferLoad
(
buffer
,
new_region
)
return
T
.
call_intrin
(
return_type
,
tir
.
op
.
Op
.
get
(
"tl.any_of"
),
T
.
address_of
(
buffer_load
),
return
T
.
call_intrin
(
return_type
,
tir
.
op
.
Op
.
get
(
"tl.any_of"
),
T
.
address_of
(
buffer_load
),
extent
)
extent
)
else
:
else
:
raise
ValueError
(
f
"Invalid buffer type:
{
type
(
buffer
)
}
"
)
raise
ValueError
(
f
"Invalid buffer type:
{
type
(
buffer
)
}
"
)
...
@@ -71,7 +71,6 @@ def all_of(buffer: T.Tensor | BufferRegion):
...
@@ -71,7 +71,6 @@ def all_of(buffer: T.Tensor | BufferRegion):
)
)
new_region
.
append
(
r
.
min
)
new_region
.
append
(
r
.
min
)
buffer_load
=
BufferLoad
(
buffer
,
new_region
)
buffer_load
=
BufferLoad
(
buffer
,
new_region
)
return
T
.
call_intrin
(
return_type
,
tir
.
op
.
Op
.
get
(
"tl.all_of"
),
T
.
address_of
(
buffer_load
),
return
T
.
call_intrin
(
return_type
,
tir
.
op
.
Op
.
get
(
"tl.all_of"
),
T
.
address_of
(
buffer_load
),
extent
)
extent
)
else
:
else
:
raise
ValueError
(
f
"Invalid buffer type:
{
type
(
buffer
)
}
"
)
raise
ValueError
(
f
"Invalid buffer type:
{
type
(
buffer
)
}
"
)
tilelang/language/loop.py
View file @
29051439
"""The language interface for tl programs."""
"""The language interface for tl programs."""
from
__future__
import
annotations
from
__future__
import
annotations
from
typing
import
Any
from
typing
import
Any
from
tvm
import
tir
from
tvm
import
tir
...
@@ -94,11 +95,9 @@ def Pipelined(
...
@@ -94,11 +95,9 @@ def Pipelined(
return
_ffi_api
.
Pipelined
(
start
,
stop
,
num_stages
,
order
,
stage
,
sync
,
group
)
return
_ffi_api
.
Pipelined
(
start
,
stop
,
num_stages
,
order
,
stage
,
sync
,
group
)
def
serial
(
start
:
tir
.
PrimExpr
,
def
serial
(
stop
:
tir
.
PrimExpr
|
None
=
None
,
start
:
tir
.
PrimExpr
,
stop
:
tir
.
PrimExpr
|
None
=
None
,
step
:
tir
.
PrimExpr
|
None
=
None
,
*
,
annotations
:
dict
[
str
,
Any
]
|
None
=
None
step
:
tir
.
PrimExpr
|
None
=
None
,
)
->
frame
.
ForFrame
:
*
,
annotations
:
dict
[
str
,
Any
]
|
None
=
None
)
->
frame
.
ForFrame
:
step_is_one
=
False
step_is_one
=
False
step_is_one
|=
isinstance
(
step
,
int
)
and
step
==
1
step_is_one
|=
isinstance
(
step
,
int
)
and
step
==
1
step_is_one
|=
isinstance
(
step
,
IntImm
)
and
step
.
value
==
1
step_is_one
|=
isinstance
(
step
,
IntImm
)
and
step
.
value
==
1
...
@@ -111,13 +110,15 @@ def serial(start: tir.PrimExpr,
...
@@ -111,13 +110,15 @@ def serial(start: tir.PrimExpr,
return
SerialForWithStep
(
start
,
stop
,
step
,
annotations
=
annotations
)
return
SerialForWithStep
(
start
,
stop
,
step
,
annotations
=
annotations
)
def
unroll
(
start
:
tir
.
PrimExpr
,
def
unroll
(
stop
:
tir
.
PrimExpr
|
None
=
None
,
start
:
tir
.
PrimExpr
,
step
:
tir
.
PrimExpr
|
None
=
None
,
stop
:
tir
.
PrimExpr
|
None
=
None
,
*
,
step
:
tir
.
PrimExpr
|
None
=
None
,
explicit
:
bool
=
False
,
*
,
unroll_factor
:
int
|
None
=
None
,
explicit
:
bool
=
False
,
annotations
:
dict
[
str
,
Any
]
|
None
=
None
)
->
frame
.
ForFrame
:
unroll_factor
:
int
|
None
=
None
,
annotations
:
dict
[
str
,
Any
]
|
None
=
None
,
)
->
frame
.
ForFrame
:
"""The unrolled For statement.
"""The unrolled For statement.
Parameters
Parameters
...
...
tilelang/language/math_intrinsics.py
View file @
29051439
...
@@ -3,7 +3,7 @@ from tvm import tir
...
@@ -3,7 +3,7 @@ from tvm import tir
def
_validate_rounding_mode
(
rounding_mode
):
def
_validate_rounding_mode
(
rounding_mode
):
"""Validate that the rounding mode is one of the supported IEEE modes"""
"""Validate that the rounding mode is one of the supported IEEE modes"""
valid_modes
=
{
'
rn
'
,
'
rz
'
,
'
ru
'
,
'
rd
'
}
valid_modes
=
{
"
rn
"
,
"
rz
"
,
"
ru
"
,
"
rd
"
}
if
isinstance
(
rounding_mode
,
str
)
and
rounding_mode
in
valid_modes
:
if
isinstance
(
rounding_mode
,
str
)
and
rounding_mode
in
valid_modes
:
return
return
raise
ValueError
(
f
"Invalid rounding mode '
{
rounding_mode
}
'. Must be one of:
{
valid_modes
}
"
)
raise
ValueError
(
f
"Invalid rounding mode '
{
rounding_mode
}
'. Must be one of:
{
valid_modes
}
"
)
...
...
tilelang/language/overrides/parser.py
View file @
29051439
"""TVMScript parser overrides tailored for TileLang."""
"""TVMScript parser overrides tailored for TileLang."""
from
functools
import
partial
from
functools
import
partial
from
tvm.script.ir_builder
import
tir
as
T
from
tvm.script.ir_builder
import
tir
as
T
...
@@ -58,8 +59,12 @@ def tilelang_visit_assign(self, node: doc.Assign) -> None: # pylint: disable=un
...
@@ -58,8 +59,12 @@ def tilelang_visit_assign(self, node: doc.Assign) -> None: # pylint: disable=un
lhs
.
ctx
=
load_ctx
lhs
.
ctx
=
load_ctx
lhs_value
=
self
.
eval_expr
(
lhs
)
lhs_value
=
self
.
eval_expr
(
lhs
)
lhs
.
ctx
=
store_ctx
lhs
.
ctx
=
store_ctx
if
(
isinstance
(
lhs_value
,
BufferLoad
)
and
lhs_value
.
buffer
.
scope
()
==
"local.var"
and
if
(
len
(
lhs_value
.
indices
)
==
1
and
lhs_value
.
indices
[
0
]
==
0
):
isinstance
(
lhs_value
,
BufferLoad
)
and
lhs_value
.
buffer
.
scope
()
==
"local.var"
and
len
(
lhs_value
.
indices
)
==
1
and
lhs_value
.
indices
[
0
]
==
0
):
T
.
buffer_store
(
lhs_value
.
buffer
,
rhs
,
indices
=
[
0
])
T
.
buffer_store
(
lhs_value
.
buffer
,
rhs
,
indices
=
[
0
])
continue
continue
...
@@ -106,8 +111,12 @@ def tilelang_visit_aug_assign(self, node: doc.AugAssign) -> None: # pylint: dis
...
@@ -106,8 +111,12 @@ def tilelang_visit_aug_assign(self, node: doc.AugAssign) -> None: # pylint: dis
lhs
.
ctx
=
load_ctx
lhs
.
ctx
=
load_ctx
lhs_value
=
self
.
eval_expr
(
lhs
)
lhs_value
=
self
.
eval_expr
(
lhs
)
lhs
.
ctx
=
store_ctx
lhs
.
ctx
=
store_ctx
if
(
isinstance
(
lhs_value
,
BufferLoad
)
and
lhs_value
.
buffer
.
scope
()
==
"local.var"
and
if
(
len
(
lhs_value
.
indices
)
==
1
and
lhs_value
.
indices
[
0
]
==
0
):
isinstance
(
lhs_value
,
BufferLoad
)
and
lhs_value
.
buffer
.
scope
()
==
"local.var"
and
len
(
lhs_value
.
indices
)
==
1
and
lhs_value
.
indices
[
0
]
==
0
):
T
.
buffer_store
(
lhs_value
.
buffer
,
rhs
,
indices
=
[
0
])
T
.
buffer_store
(
lhs_value
.
buffer
,
rhs
,
indices
=
[
0
])
return
return
...
@@ -131,8 +140,12 @@ def tilelang_visit_ann_assign(self, node: doc.AnnAssign) -> None: # pylint: dis
...
@@ -131,8 +140,12 @@ def tilelang_visit_ann_assign(self, node: doc.AnnAssign) -> None: # pylint: dis
lhs
.
ctx
=
load_ctx
lhs
.
ctx
=
load_ctx
lhs_value
=
self
.
eval_expr
(
lhs
)
lhs_value
=
self
.
eval_expr
(
lhs
)
lhs
.
ctx
=
store_ctx
lhs
.
ctx
=
store_ctx
if
(
isinstance
(
lhs_value
,
BufferLoad
)
and
lhs_value
.
buffer
.
scope
()
==
"local.var"
and
if
(
len
(
lhs_value
.
indices
)
==
1
and
lhs_value
.
indices
[
0
]
==
0
):
isinstance
(
lhs_value
,
BufferLoad
)
and
lhs_value
.
buffer
.
scope
()
==
"local.var"
and
len
(
lhs_value
.
indices
)
==
1
and
lhs_value
.
indices
[
0
]
==
0
):
T
.
buffer_store
(
lhs_value
.
buffer
,
rhs
,
indices
=
[
0
])
T
.
buffer_store
(
lhs_value
.
buffer
,
rhs
,
indices
=
[
0
])
return
return
...
...
tilelang/language/parser/entry.py
View file @
29051439
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
# which is part of the TVM project (https://tvm.apache.org/).
# which is part of the TVM project (https://tvm.apache.org/).
# ruff: noqa
# ruff: noqa
"""The entry point of TVM parser for tir."""
"""The entry point of TVM parser for tir."""
import
inspect
import
inspect
from
typing
import
Callable
,
Optional
,
Union
from
typing
import
Callable
,
Optional
,
Union
...
@@ -29,9 +30,7 @@ from tvm.script.parser._core import parse, scan_macro, utils
...
@@ -29,9 +30,7 @@ from tvm.script.parser._core import parse, scan_macro, utils
from
tvm.script.parser.core.parser
import
Parser
,
ScriptMacro
from
tvm.script.parser.core.parser
import
Parser
,
ScriptMacro
def
prim_func
(
func
:
Optional
[
Callable
]
=
None
,
def
prim_func
(
func
:
Optional
[
Callable
]
=
None
,
private
:
bool
=
False
,
check_well_formed
=
True
)
->
Union
[
PrimFunc
,
Callable
]:
private
:
bool
=
False
,
check_well_formed
=
True
)
->
Union
[
PrimFunc
,
Callable
]:
"""The parsing method for tir prim func, by using `@prim_func` as decorator.
"""The parsing method for tir prim func, by using `@prim_func` as decorator.
Parameters
Parameters
...
@@ -149,8 +148,7 @@ def macro(*args, hygienic: bool = True) -> Callable:
...
@@ -149,8 +148,7 @@ def macro(*args, hygienic: bool = True) -> Callable:
if
len
(
args
)
==
1
and
inspect
.
isfunction
(
args
[
0
]):
if
len
(
args
)
==
1
and
inspect
.
isfunction
(
args
[
0
]):
return
_decorator
(
args
[
0
])
return
_decorator
(
args
[
0
])
raise
ValueError
(
raise
ValueError
(
"Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])"
)
"Invalid use of T.macro. Usage: @T.macro, @T.macro(), @T.macro(hygienic=[True|False])"
)
class
BufferProxy
:
class
BufferProxy
:
...
...
tilelang/language/parser/operation.py
View file @
29051439
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
# This file is modified from the original version,
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
# which is part of the TVM project (https://tvm.apache.org/).
"""The tir expression operation registration"""
"""The tir expression operation registration"""
from
tvm
import
tir
from
tvm
import
tir
from
tvm.ffi.runtime_ctypes
import
DataType
,
DataTypeCode
from
tvm.ffi.runtime_ctypes
import
DataType
,
DataTypeCode
from
tvm.tir
import
IntImm
from
tvm.tir
import
IntImm
...
@@ -55,11 +56,9 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name
...
@@ -55,11 +56,9 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name
return
dtype
[
0
:
index
]
return
dtype
[
0
:
index
]
def
_auto_broadcast
(
a
,
b
,
op
):
def
_auto_broadcast
(
a
,
b
,
op
):
if
isinstance
(
a
,
int
):
if
isinstance
(
a
,
int
):
if
hasattr
(
b
,
"dtype"
):
if
hasattr
(
b
,
"dtype"
):
if
(
DataType
(
b
.
dtype
).
type_code
==
DataTypeCode
.
INT
or
if
DataType
(
b
.
dtype
).
type_code
==
DataTypeCode
.
INT
or
DataType
(
b
.
dtype
).
type_code
==
DataTypeCode
.
UINT
:
DataType
(
b
.
dtype
).
type_code
==
DataTypeCode
.
UINT
):
a
=
IntImm
(
_get_type_str
(
b
.
dtype
),
a
)
a
=
IntImm
(
_get_type_str
(
b
.
dtype
),
a
)
elif
DataType
(
b
.
dtype
).
type_code
==
DataTypeCode
.
FLOAT
:
elif
DataType
(
b
.
dtype
).
type_code
==
DataTypeCode
.
FLOAT
:
a
=
FloatImm
(
_get_type_str
(
b
.
dtype
),
a
)
a
=
FloatImm
(
_get_type_str
(
b
.
dtype
),
a
)
...
@@ -75,8 +74,7 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name
...
@@ -75,8 +74,7 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name
assert
isinstance
(
a
,
tir
.
PrimExpr
),
"Operand should be a PrimExpr."
assert
isinstance
(
a
,
tir
.
PrimExpr
),
"Operand should be a PrimExpr."
if
isinstance
(
b
,
int
):
if
isinstance
(
b
,
int
):
if
(
DataType
(
a
.
dtype
).
type_code
==
DataTypeCode
.
INT
or
if
DataType
(
a
.
dtype
).
type_code
==
DataTypeCode
.
INT
or
DataType
(
a
.
dtype
).
type_code
==
DataTypeCode
.
UINT
:
DataType
(
a
.
dtype
).
type_code
==
DataTypeCode
.
UINT
):
b
=
IntImm
(
_get_type_str
(
a
.
dtype
),
b
)
b
=
IntImm
(
_get_type_str
(
a
.
dtype
),
b
)
elif
DataType
(
a
.
dtype
).
type_code
==
DataTypeCode
.
FLOAT
:
elif
DataType
(
a
.
dtype
).
type_code
==
DataTypeCode
.
FLOAT
:
b
=
FloatImm
(
_get_type_str
(
a
.
dtype
),
b
)
b
=
FloatImm
(
_get_type_str
(
a
.
dtype
),
b
)
...
@@ -85,10 +83,10 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name
...
@@ -85,10 +83,10 @@ def _register_expr_op(ty: type): # pylint: disable=invalid-name
if
DataType
(
a
.
dtype
).
lanes
==
DataType
(
b
.
dtype
).
lanes
:
if
DataType
(
a
.
dtype
).
lanes
==
DataType
(
b
.
dtype
).
lanes
:
return
op
(
a
,
b
)
return
op
(
a
,
b
)
elif
(
DataType
(
a
.
dtype
).
lanes
==
1
and
DataType
(
a
.
dtype
).
lanes
!=
DataType
(
b
.
dtype
).
lanes
)
:
elif
DataType
(
a
.
dtype
).
lanes
==
1
and
DataType
(
a
.
dtype
).
lanes
!=
DataType
(
b
.
dtype
).
lanes
:
broadcast_a
=
tir
.
Broadcast
(
a
,
DataType
(
b
.
dtype
).
lanes
)
broadcast_a
=
tir
.
Broadcast
(
a
,
DataType
(
b
.
dtype
).
lanes
)
return
op
(
broadcast_a
,
b
)
return
op
(
broadcast_a
,
b
)
elif
(
DataType
(
b
.
dtype
).
lanes
==
1
and
DataType
(
a
.
dtype
).
lanes
!=
DataType
(
b
.
dtype
).
lanes
)
:
elif
DataType
(
b
.
dtype
).
lanes
==
1
and
DataType
(
a
.
dtype
).
lanes
!=
DataType
(
b
.
dtype
).
lanes
:
broadcast_b
=
tir
.
Broadcast
(
b
,
DataType
(
a
.
dtype
).
lanes
)
broadcast_b
=
tir
.
Broadcast
(
b
,
DataType
(
a
.
dtype
).
lanes
)
return
op
(
a
,
broadcast_b
)
return
op
(
a
,
broadcast_b
)
else
:
else
:
...
...
tilelang/language/parser/parser.py
View file @
29051439
...
@@ -146,8 +146,7 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -
...
@@ -146,8 +146,7 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -
res
=
value
.
__enter__
()
res
=
value
.
__enter__
()
IRBuilder
.
name
(
var_name
,
res
)
IRBuilder
.
name
(
var_name
,
res
)
return
res
return
res
elif
isinstance
(
value
,
(
Buffer
,
IterVar
))
or
(
isinstance
(
value
,
Var
)
and
elif
isinstance
(
value
,
(
Buffer
,
IterVar
))
or
(
isinstance
(
value
,
Var
)
and
not
self
.
var_table
.
exist
(
value
)):
not
self
.
var_table
.
exist
(
value
)):
IRBuilder
.
name
(
var_name
,
value
)
IRBuilder
.
name
(
var_name
,
value
)
return
value
return
value
else
:
else
:
...
@@ -191,8 +190,7 @@ def visit_for(self: Parser, node: doc.For) -> None:
...
@@ -191,8 +190,7 @@ def visit_for(self: Parser, node: doc.For) -> None:
if
not
isinstance
(
for_frame
,
T
.
frame
.
ForFrame
):
if
not
isinstance
(
for_frame
,
T
.
frame
.
ForFrame
):
self
.
report_error
(
self
.
report_error
(
node
.
iter
,
node
.
iter
,
"Expect the for loop to be one of the following: "
"Expect the for loop to be one of the following: range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding"
,
"range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding"
,
)
)
with
self
.
var_table
.
with_frame
():
with
self
.
var_table
.
with_frame
():
with
for_frame
as
iters
:
with
for_frame
as
iters
:
...
@@ -361,8 +359,7 @@ def visit_with(self: Parser, node: doc.With) -> None:
...
@@ -361,8 +359,7 @@ def visit_with(self: Parser, node: doc.With) -> None:
for
item
in
node
.
items
:
for
item
in
node
.
items
:
frame
=
self
.
eval_expr
(
item
.
context_expr
)
frame
=
self
.
eval_expr
(
item
.
context_expr
)
if
not
isinstance
(
frame
,
Frame
):
if
not
isinstance
(
frame
,
Frame
):
self
.
report_error
(
item
.
context_expr
,
self
.
report_error
(
item
.
context_expr
,
"Invalid context expression in the with-statement."
)
"Invalid context expression in the with-statement."
)
rhs
=
stack
.
enter_context
(
frame
)
rhs
=
stack
.
enter_context
(
frame
)
if
item
.
optional_vars
is
not
None
:
if
item
.
optional_vars
is
not
None
:
self
.
eval_assign
(
target
=
item
.
optional_vars
,
source
=
rhs
,
bind_value
=
bind_with_value
)
self
.
eval_assign
(
target
=
item
.
optional_vars
,
source
=
rhs
,
bind_value
=
bind_with_value
)
...
@@ -505,8 +502,7 @@ def visit_if(self: Parser, node: doc.If) -> None:
...
@@ -505,8 +502,7 @@ def visit_if(self: Parser, node: doc.If) -> None:
with
self
.
var_table
.
with_frame
():
with
self
.
var_table
.
with_frame
():
self
.
visit_body
(
node
.
orelse
)
self
.
visit_body
(
node
.
orelse
)
else
:
else
:
self
.
report_error
(
node
.
test
,
self
.
report_error
(
node
.
test
,
f
"If condition must be a boolean expression, but got
{
predicate
}
"
)
f
"If condition must be a boolean expression, but got
{
predicate
}
"
)
@
dispatch
.
register
(
token
=
"tir"
,
type_name
=
"Assert"
)
@
dispatch
.
register
(
token
=
"tir"
,
type_name
=
"Assert"
)
...
...
tilelang/language/print.py
View file @
29051439
...
@@ -26,9 +26,7 @@ def print_var(var: tir.PrimExpr, msg: str = "") -> tir.PrimExpr:
...
@@ -26,9 +26,7 @@ def print_var(var: tir.PrimExpr, msg: str = "") -> tir.PrimExpr:
@
macro
@
macro
def
print_var_with_condition
(
condition
:
tir
.
PrimExpr
,
def
print_var_with_condition
(
condition
:
tir
.
PrimExpr
,
var
:
tir
.
PrimExpr
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
var
:
tir
.
PrimExpr
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
"""
"""
Conditionally prints a TIR primitive expression (PrimExpr) if a given condition is True.
Conditionally prints a TIR primitive expression (PrimExpr) if a given condition is True.
...
@@ -44,10 +42,7 @@ def print_var_with_condition(condition: tir.PrimExpr,
...
@@ -44,10 +42,7 @@ def print_var_with_condition(condition: tir.PrimExpr,
@
macro
@
macro
def
print_global_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
def
print_global_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
"""
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
Conditionally prints the values of a flattened TIR buffer if the condition is True.
"""
"""
...
@@ -55,17 +50,13 @@ def print_global_buffer_with_condition(condition: tir.PrimExpr,
...
@@ -55,17 +50,13 @@ def print_global_buffer_with_condition(condition: tir.PrimExpr,
# Iterate through the buffer elements and print each one.
# Iterate through the buffer elements and print each one.
for
i
in
serial
(
elems
):
for
i
in
serial
(
elems
):
coords
=
index_to_coordinates
(
i
,
buffer
.
shape
)
coords
=
index_to_coordinates
(
i
,
buffer
.
shape
)
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
buffer
[
coords
])
buffer
[
coords
])
else
:
else
:
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
buffer
[
coords
])
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
buffer
[
coords
])
@
macro
@
macro
def
print_shared_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
def
print_shared_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
"""
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
Conditionally prints the values of a flattened TIR buffer if the condition is True.
...
@@ -81,15 +72,11 @@ def print_shared_buffer_with_condition(condition: tir.PrimExpr,
...
@@ -81,15 +72,11 @@ def print_shared_buffer_with_condition(condition: tir.PrimExpr,
# Iterate through the buffer elements and print each one.
# Iterate through the buffer elements and print each one.
for
i
in
serial
(
elems
):
for
i
in
serial
(
elems
):
coords
=
index_to_coordinates
(
i
,
buffer
.
shape
)
coords
=
index_to_coordinates
(
i
,
buffer
.
shape
)
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
buffer
[
coords
])
buffer
[
coords
])
@
macro
@
macro
def
print_fragment_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
def
print_fragment_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
"""
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
Conditionally prints the values of a flattened TIR buffer if the condition is True.
...
@@ -111,10 +98,7 @@ def print_fragment_buffer_with_condition(condition: tir.PrimExpr,
...
@@ -111,10 +98,7 @@ def print_fragment_buffer_with_condition(condition: tir.PrimExpr,
@
macro
@
macro
def
print_local_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
def
print_local_buffer_with_condition
(
condition
:
tir
.
PrimExpr
,
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
buffer
:
tir
.
Buffer
,
elems
:
int
,
msg
:
str
=
""
)
->
tir
.
PrimExpr
:
"""
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
Conditionally prints the values of a flattened TIR buffer if the condition is True.
...
@@ -130,8 +114,7 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr,
...
@@ -130,8 +114,7 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr,
# Iterate through the buffer elements and print each one.
# Iterate through the buffer elements and print each one.
for
i
in
serial
(
elems
):
for
i
in
serial
(
elems
):
coords
=
index_to_coordinates
(
i
,
buffer
.
shape
)
coords
=
index_to_coordinates
(
i
,
buffer
.
shape
)
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
tir
.
call_extern
(
"handle"
,
"debug_print_buffer_value"
,
msg
,
buffer
.
name
,
i
,
buffer
[
coords
])
buffer
[
coords
])
from
tilelang.utils.target
import
check_cuda_availability
from
tilelang.utils.target
import
check_cuda_availability
...
@@ -201,7 +184,7 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
...
@@ -201,7 +184,7 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
elems
*=
dim
elems
*=
dim
# Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
# Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
condition
=
(
tx
==
main_lane
and
ty
==
0
and
tz
==
0
)
condition
=
tx
==
main_lane
and
ty
==
0
and
tz
==
0
if
not
msg
:
if
not
msg
:
msg
=
f
"buffer<
{
buffer
.
name
}
,
{
buffer
.
dtype
}
>"
msg
=
f
"buffer<
{
buffer
.
name
}
,
{
buffer
.
dtype
}
>"
return
print_fragment_buffer_with_condition
(
condition
,
buffer
,
elems
,
msg
)
return
print_fragment_buffer_with_condition
(
condition
,
buffer
,
elems
,
msg
)
...
@@ -212,7 +195,7 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
...
@@ -212,7 +195,7 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
elems
*=
dim
elems
*=
dim
# Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
# Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
condition
=
(
tx
==
main_lane
and
ty
==
0
and
tz
==
0
)
condition
=
tx
==
main_lane
and
ty
==
0
and
tz
==
0
if
not
msg
:
if
not
msg
:
msg
=
f
"buffer<
{
buffer
.
name
}
,
{
buffer
.
dtype
}
>"
msg
=
f
"buffer<
{
buffer
.
name
}
,
{
buffer
.
dtype
}
>"
return
print_shared_buffer_with_condition
(
condition
,
buffer
,
elems
,
msg
)
return
print_shared_buffer_with_condition
(
condition
,
buffer
,
elems
,
msg
)
...
@@ -234,5 +217,4 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
...
@@ -234,5 +217,4 @@ def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) ->
else
:
else
:
# Unsupported object type.
# Unsupported object type.
raise
ValueError
(
raise
ValueError
(
f
"Unexpected type:
{
type
(
obj
)
}
. Supported types are tir.Buffer and tir.PrimExpr."
)
f
"Unexpected type:
{
type
(
obj
)
}
. Supported types are tir.Buffer and tir.PrimExpr."
)
tilelang/language/proxy.py
View file @
29051439
"""The language interface for tl programs."""
"""The language interface for tl programs."""
from
__future__
import
annotations
from
__future__
import
annotations
from
typing
import
Any
,
SupportsIndex
,
TYPE_CHECKING
,
Generic
,
TypeVar
from
typing
import
Any
,
SupportsIndex
,
TYPE_CHECKING
,
Generic
,
TypeVar
...
@@ -51,11 +52,9 @@ class BufferProxy:
...
@@ -51,11 +52,9 @@ class BufferProxy:
return
self
(
keys
)
return
self
(
keys
)
return
self
(
*
keys
)
# type: ignore[attr-defined] # pylint: disable=no-member
return
self
(
*
keys
)
# type: ignore[attr-defined] # pylint: disable=no-member
def
from_ptr
(
self
,
def
from_ptr
(
pointer_var
:
Var
,
self
,
pointer_var
:
Var
,
shape
:
tuple
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
shape
:
tuple
[
PrimExpr
,
...],
)
->
Buffer
:
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
Buffer
:
"""Create a buffer from a pointer, shape, and data type.
"""Create a buffer from a pointer, shape, and data type.
Args:
Args:
...
@@ -76,6 +75,7 @@ class BaseTensorProxy:
...
@@ -76,6 +75,7 @@ class BaseTensorProxy:
customizable default values for scope, alignment, and offset factors. It implements
customizable default values for scope, alignment, and offset factors. It implements
the core functionality for creating TIR buffers with specific memory configurations.
the core functionality for creating TIR buffers with specific memory configurations.
"""
"""
default_scope
=
"global"
default_scope
=
"global"
default_align
=
0
default_align
=
0
default_offset_factor
=
0
default_offset_factor
=
0
...
@@ -118,11 +118,9 @@ class BaseTensorProxy:
...
@@ -118,11 +118,9 @@ class BaseTensorProxy:
keys
=
(
keys
,)
keys
=
(
keys
,)
return
self
(
*
keys
)
return
self
(
*
keys
)
def
from_ptr
(
self
,
def
from_ptr
(
pointer_var
:
Var
,
self
,
pointer_var
:
Var
,
shape
:
tuple
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
shape
:
tuple
[
PrimExpr
,
...],
)
->
tir
.
Buffer
:
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
tir
.
Buffer
:
"""Create a buffer from a pointer, shape, and data type.
"""Create a buffer from a pointer, shape, and data type.
Args:
Args:
...
@@ -151,19 +149,10 @@ class TensorProxy(BaseTensorProxy):
...
@@ -151,19 +149,10 @@ class TensorProxy(BaseTensorProxy):
strides
.
append
(
s
)
strides
.
append
(
s
)
return
tuple
(
reversed
(
strides
))
return
tuple
(
reversed
(
strides
))
def
__call__
(
self
,
def
__call__
(
self
,
shape
:
tuple
[
Any
]
|
PrimExpr
|
int
,
dtype
:
str
=
"float32"
,
data
=
None
,
scope
=
None
)
->
tir
.
Buffer
:
shape
:
tuple
[
Any
]
|
PrimExpr
|
int
,
dtype
:
str
=
"float32"
,
data
=
None
,
scope
=
None
)
->
tir
.
Buffer
:
if
isinstance
(
shape
,
(
int
,
PrimExpr
)):
if
isinstance
(
shape
,
(
int
,
PrimExpr
)):
shape
=
(
shape
,)
shape
=
(
shape
,)
return
super
().
__call__
(
return
super
().
__call__
(
shape
,
dtype
=
dtype
,
strides
=
TensorProxy
.
_construct_strides
(
shape
),
data
=
data
,
scope
=
scope
)
shape
,
dtype
=
dtype
,
strides
=
TensorProxy
.
_construct_strides
(
shape
),
data
=
data
,
scope
=
scope
)
class
StridedTensorProxy
(
BaseTensorProxy
):
class
StridedTensorProxy
(
BaseTensorProxy
):
...
@@ -172,11 +161,7 @@ class StridedTensorProxy(BaseTensorProxy):
...
@@ -172,11 +161,7 @@ class StridedTensorProxy(BaseTensorProxy):
This class implements the default tensor proxy with global memory scope, with the stride information required.
This class implements the default tensor proxy with global memory scope, with the stride information required.
"""
"""
def
__call__
(
self
,
def
__call__
(
self
,
shape
:
tuple
[
Any
],
strides
:
tuple
[
Any
],
dtype
:
str
=
"float32"
,
scope
=
None
)
->
tir
.
Buffer
:
shape
:
tuple
[
Any
],
strides
:
tuple
[
Any
],
dtype
:
str
=
"float32"
,
scope
=
None
)
->
tir
.
Buffer
:
if
len
(
shape
)
!=
len
(
strides
):
if
len
(
shape
)
!=
len
(
strides
):
raise
ValueError
(
"Invalid shape/strides' dimensions"
)
raise
ValueError
(
"Invalid shape/strides' dimensions"
)
return
super
().
__call__
(
shape
,
dtype
=
dtype
,
strides
=
strides
,
scope
=
scope
)
return
super
().
__call__
(
shape
,
dtype
=
dtype
,
strides
=
strides
,
scope
=
scope
)
...
@@ -188,6 +173,7 @@ class FragmentBufferProxy(BaseTensorProxy):
...
@@ -188,6 +173,7 @@ class FragmentBufferProxy(BaseTensorProxy):
This class represents tensor proxies specifically for local fragment memory,
This class represents tensor proxies specifically for local fragment memory,
typically used in GPU tensor core operations.
typically used in GPU tensor core operations.
"""
"""
default_scope
=
"local.fragment"
default_scope
=
"local.fragment"
...
@@ -197,6 +183,7 @@ class SharedBufferProxy(BaseTensorProxy):
...
@@ -197,6 +183,7 @@ class SharedBufferProxy(BaseTensorProxy):
This class represents tensor proxies for dynamic shared memory,
This class represents tensor proxies for dynamic shared memory,
commonly used in GPU shared memory operations.
commonly used in GPU shared memory operations.
"""
"""
default_scope
=
"shared.dyn"
default_scope
=
"shared.dyn"
...
@@ -206,6 +193,7 @@ class LocalBufferProxy(BaseTensorProxy):
...
@@ -206,6 +193,7 @@ class LocalBufferProxy(BaseTensorProxy):
This class represents tensor proxies for local memory scope,
This class represents tensor proxies for local memory scope,
typically used for temporary computations in GPU kernels.
typically used for temporary computations in GPU kernels.
"""
"""
default_scope
=
"local"
default_scope
=
"local"
...
@@ -216,15 +204,12 @@ Buffer = BufferProxy() # pylint: disable=invalid-name
...
@@ -216,15 +204,12 @@ Buffer = BufferProxy() # pylint: disable=invalid-name
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
class
BaseTensor
:
class
BaseTensor
:
def
__class_getitem__
(
cls
,
key
):
def
__class_getitem__
(
cls
,
key
):
return
cls
return
cls
def
__getitem__
(
self
,
key
)
->
Any
:
def
__getitem__
(
self
,
key
)
->
Any
:
...
...
def
__setitem__
(
self
,
key
,
value
)
->
None
:
def
__setitem__
(
self
,
key
,
value
)
->
None
:
...
...
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -238,36 +223,26 @@ if TYPE_CHECKING:
...
@@ -238,36 +223,26 @@ if TYPE_CHECKING:
offset_factor
=
None
,
offset_factor
=
None
,
buffer_type
=
""
,
buffer_type
=
""
,
axis_separators
=
None
,
axis_separators
=
None
,
):
):
...
...
@
classmethod
@
classmethod
def
from_ptr
(
cls
,
def
from_ptr
(
pointer_var
:
Var
,
cls
,
pointer_var
:
Var
,
shape
:
Sequence
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
shape
:
Sequence
[
PrimExpr
,
...],
)
->
Self
:
...
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
Self
:
...
class
Tensor
(
BaseTensor
):
class
Tensor
(
BaseTensor
):
...
...
class
StridedTensor
(
BaseTensor
):
class
StridedTensor
(
BaseTensor
):
...
...
class
FragmentBuffer
(
BaseTensor
):
class
FragmentBuffer
(
BaseTensor
):
...
...
class
SharedBuffer
(
BaseTensor
):
class
SharedBuffer
(
BaseTensor
):
...
...
class
LocalBuffer
(
BaseTensor
):
class
LocalBuffer
(
BaseTensor
):
...
...
_T
=
TypeVar
(
'
_T
'
)
_T
=
TypeVar
(
"
_T
"
)
class
Ref
(
Generic
[
_T
],
tir
.
Var
):
class
Ref
(
Generic
[
_T
],
tir
.
Var
):
...
...
else
:
else
:
Tensor
=
TensorProxy
()
# pylint: disable=invalid-name
Tensor
=
TensorProxy
()
# pylint: disable=invalid-name
StridedTensor
=
StridedTensorProxy
()
# pylint: disable=invalid-name
StridedTensor
=
StridedTensorProxy
()
# pylint: disable=invalid-name
...
@@ -275,14 +250,10 @@ else:
...
@@ -275,14 +250,10 @@ else:
SharedBuffer
=
SharedBufferProxy
()
# pylint: disable=invalid-name
SharedBuffer
=
SharedBufferProxy
()
# pylint: disable=invalid-name
LocalBuffer
=
LocalBufferProxy
()
# pylint: disable=invalid-name
LocalBuffer
=
LocalBufferProxy
()
# pylint: disable=invalid-name
class
Ref
:
class
Ref
:
...
...
def
ptr
(
dtype
:
str
|
None
=
None
,
def
ptr
(
dtype
:
str
|
None
=
None
,
storage_scope
:
str
=
"global"
,
*
,
is_size_var
:
bool
=
False
)
->
Var
:
storage_scope
:
str
=
"global"
,
*
,
is_size_var
:
bool
=
False
)
->
Var
:
"""Create a TIR var that represents a pointer.
"""Create a TIR var that represents a pointer.
Parameters
Parameters
...
@@ -304,8 +275,5 @@ def ptr(dtype: str | None = None,
...
@@ -304,8 +275,5 @@ def ptr(dtype: str | None = None,
return
handle
(
dtype
=
dtype
,
storage_scope
=
storage_scope
,
is_size_var
=
is_size_var
)
return
handle
(
dtype
=
dtype
,
storage_scope
=
storage_scope
,
is_size_var
=
is_size_var
)
def
make_tensor
(
ptr
:
Var
,
def
make_tensor
(
ptr
:
Var
,
shape
:
tuple
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
tir
.
Buffer
:
shape
:
tuple
[
PrimExpr
,
...],
dtype
:
str
=
"float32"
,
strides
:
tuple
[
PrimExpr
,
...]
=
None
)
->
tir
.
Buffer
:
return
Tensor
.
from_ptr
(
ptr
,
shape
,
dtype
,
strides
)
return
Tensor
.
from_ptr
(
ptr
,
shape
,
dtype
,
strides
)
Prev
1
…
17
18
19
20
21
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment