Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
bbbf4207
Unverified
Commit
bbbf4207
authored
Nov 14, 2025
by
guchaoyang
Committed by
GitHub
Nov 14, 2025
Browse files
Merge branch 'main' into dcu
parents
8f4628e0
5eb30a4f
Changes
286
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1229 additions
and
656 deletions
+1229
-656
tilelang/ir.py
tilelang/ir.py
+14
-14
tilelang/jit/__init__.py
tilelang/jit/__init__.py
+221
-162
tilelang/jit/adapter/torch/metal.py
tilelang/jit/adapter/torch/metal.py
+6
-2
tilelang/jit/adapter/wrapper.py
tilelang/jit/adapter/wrapper.py
+32
-15
tilelang/jit/kernel.py
tilelang/jit/kernel.py
+247
-5
tilelang/language/__init__.py
tilelang/language/__init__.py
+8
-7
tilelang/language/allocate.py
tilelang/language/allocate.py
+58
-4
tilelang/language/ast/ir.py
tilelang/language/ast/ir.py
+3
-0
tilelang/language/atomic.py
tilelang/language/atomic.py
+8
-7
tilelang/language/builtin.py
tilelang/language/builtin.py
+331
-25
tilelang/language/copy.py
tilelang/language/copy.py
+22
-7
tilelang/language/fill.py
tilelang/language/fill.py
+29
-3
tilelang/language/gemm.py
tilelang/language/gemm.py
+65
-332
tilelang/language/loop.py
tilelang/language/loop.py
+111
-0
tilelang/language/parallel.py
tilelang/language/parallel.py
+0
-29
tilelang/language/persistent.py
tilelang/language/persistent.py
+0
-27
tilelang/language/print.py
tilelang/language/print.py
+3
-2
tilelang/language/proxy.py
tilelang/language/proxy.py
+2
-1
tilelang/language/reduce.py
tilelang/language/reduce.py
+68
-12
tilelang/language/symbolics.py
tilelang/language/symbolics.py
+1
-2
No files found.
tilelang/ir.py
View file @
bbbf4207
from
tilelang
import
tvm
as
tvm
from
tilelang
import
tvm
as
tvm
from
tvm.ir.base
import
Node
from
tvm.ir.base
import
Node
from
tvm.runtime
import
Scriptable
from
tvm.runtime
import
Scriptable
import
tvm
.
ffi
import
tvm
_
ffi
from
tvm.target
import
Target
from
tvm.target
import
Target
from
tilelang
import
_ffi_api
from
tilelang
import
_ffi_api
@
tvm
.
ffi
.
register_object
(
"tl.Fill"
)
@
tvm
_
ffi
.
register_object
(
"tl.Fill"
)
class
Fill
(
Node
,
Scriptable
):
class
Fill
(
Node
,
Scriptable
):
...
...
@
tvm
.
ffi
.
register_object
(
"tl.AtomicAdd"
)
@
tvm
_
ffi
.
register_object
(
"tl.AtomicAdd"
)
class
AtomicAdd
(
Node
,
Scriptable
):
class
AtomicAdd
(
Node
,
Scriptable
):
...
...
@
tvm
.
ffi
.
register_object
(
"tl.Copy"
)
@
tvm
_
ffi
.
register_object
(
"tl.Copy"
)
class
Copy
(
Node
,
Scriptable
):
class
Copy
(
Node
,
Scriptable
):
...
...
@
tvm
.
ffi
.
register_object
(
"tl.Conv2DIm2Col"
)
@
tvm
_
ffi
.
register_object
(
"tl.Conv2DIm2Col"
)
class
Conv2DIm2ColOp
(
Node
,
Scriptable
):
class
Conv2DIm2ColOp
(
Node
,
Scriptable
):
...
...
@
tvm
.
ffi
.
register_object
(
"tl.GemmWarpPolicy"
)
@
tvm
_
ffi
.
register_object
(
"tl.GemmWarpPolicy"
)
class
GemmWarpPolicy
(
Node
,
Scriptable
):
class
GemmWarpPolicy
(
Node
,
Scriptable
):
policy_type
:
int
policy_type
:
int
m_warp
:
int
m_warp
:
int
...
@@ -39,41 +39,41 @@ class GemmWarpPolicy(Node, Scriptable):
...
@@ -39,41 +39,41 @@ class GemmWarpPolicy(Node, Scriptable):
return
self
.
m_warp
,
self
.
n_warp
return
self
.
m_warp
,
self
.
n_warp
@
tvm
.
ffi
.
register_object
(
"tl.Gemm"
)
@
tvm
_
ffi
.
register_object
(
"tl.Gemm"
)
class
Gemm
(
Node
,
Scriptable
):
class
Gemm
(
Node
,
Scriptable
):
...
...
@
tvm
.
ffi
.
register_object
(
"tl.GemmSP"
)
@
tvm
_
ffi
.
register_object
(
"tl.GemmSP"
)
class
GemmSP
(
Node
,
Scriptable
):
class
GemmSP
(
Node
,
Scriptable
):
...
...
@
tvm
.
ffi
.
register_object
(
"tl.FinalizeReducerOp"
)
@
tvm
_
ffi
.
register_object
(
"tl.FinalizeReducerOp"
)
class
FinalizeReducerOp
(
Node
,
Scriptable
):
class
FinalizeReducerOp
(
Node
,
Scriptable
):
...
...
@
tvm
.
ffi
.
register_object
(
"tl.ParallelOp"
)
@
tvm
_
ffi
.
register_object
(
"tl.ParallelOp"
)
class
ParallelOp
(
Node
,
Scriptable
):
class
ParallelOp
(
Node
,
Scriptable
):
...
...
@
tvm
.
ffi
.
register_object
(
"tl.ReduceOp"
)
@
tvm
_
ffi
.
register_object
(
"tl.ReduceOp"
)
class
ReduceOp
(
Node
,
Scriptable
):
class
ReduceOp
(
Node
,
Scriptable
):
...
...
@
tvm
.
ffi
.
register_object
(
"tl.CumSumOp"
)
@
tvm
_
ffi
.
register_object
(
"tl.CumSumOp"
)
class
CumSumOp
(
Node
,
Scriptable
):
class
CumSumOp
(
Node
,
Scriptable
):
...
...
@
tvm
.
ffi
.
register_object
(
"tl.RegionOp"
)
@
tvm
_
ffi
.
register_object
(
"tl.RegionOp"
)
class
RegionOp
(
Node
,
Scriptable
):
class
RegionOp
(
Node
,
Scriptable
):
...
...
@
tvm
.
ffi
.
register_object
(
"tl.ReduceType"
)
@
tvm
_
ffi
.
register_object
(
"tl.ReduceType"
)
class
ReduceType
(
Node
,
Scriptable
):
class
ReduceType
(
Node
,
Scriptable
):
...
...
tilelang/jit/__init__.py
View file @
bbbf4207
...
@@ -5,15 +5,25 @@ kernel adapter using TVM.
...
@@ -5,15 +5,25 @@ kernel adapter using TVM.
"""
"""
from
__future__
import
annotations
from
__future__
import
annotations
from
dataclasses
import
dataclass
import
inspect
from
typing
import
(
from
typing
import
(
Any
,
Any
,
Callable
,
Callable
,
Generic
,
TypeVar
,
overload
,
overload
,
Literal
,
Literal
,
)
)
from
collections.abc
import
Iterable
# Python 3.9 compatibility for ParamSpec
try
:
from
typing
import
ParamSpec
except
ImportError
:
# Python < 3.10
from
typing_extensions
import
ParamSpec
from
tilelang
import
tvm
as
tvm
from
tilelang
import
tvm
as
tvm
from
tilelang.language.v2
import
PrimFunc
from
tilelang.jit.adapter.utils
import
is_metal_target
from
tilelang.jit.adapter.utils
import
is_metal_target
from
tvm.tir
import
PrimFunc
from
tvm.target
import
Target
from
tvm.target
import
Target
from
tilelang.jit.kernel
import
JITKernel
from
tilelang.jit.kernel
import
JITKernel
...
@@ -21,14 +31,20 @@ from tilelang.utils.target import determine_target
...
@@ -21,14 +31,20 @@ from tilelang.utils.target import determine_target
from
tilelang.cache
import
cached
from
tilelang.cache
import
cached
from
os
import
path
,
makedirs
from
os
import
path
,
makedirs
from
logging
import
getLogger
from
logging
import
getLogger
import
functools
from
tilelang.jit.param
import
Kernel
from
tilelang.jit.param
import
Kernel
,
_P
,
_RProg
import
concurrent.futures
from
tqdm.auto
import
tqdm
logger
=
getLogger
(
__name__
)
logger
=
getLogger
(
__name__
)
_P
=
ParamSpec
(
'_P'
)
_KP
=
ParamSpec
(
'_KP'
)
_T
=
TypeVar
(
'_T'
)
def
compile
(
def
compile
(
func
:
PrimFunc
=
None
,
func
:
PrimFunc
[
_KP
,
_T
]
=
None
,
out_idx
:
list
[
int
]
|
int
|
None
=
None
,
out_idx
:
list
[
int
]
|
int
|
None
=
None
,
execution_backend
:
Literal
[
"dlpack"
,
"ctypes"
,
"cython"
,
"nvrtc"
]
=
"cython"
,
execution_backend
:
Literal
[
"dlpack"
,
"ctypes"
,
"cython"
,
"nvrtc"
]
=
"cython"
,
target
:
str
|
Target
=
"auto"
,
target
:
str
|
Target
=
"auto"
,
...
@@ -36,7 +52,7 @@ def compile(
...
@@ -36,7 +52,7 @@ def compile(
verbose
:
bool
=
False
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
str
|
None
=
None
,
compile_flags
:
list
[
str
]
|
str
|
None
=
None
,
)
->
JITKernel
:
)
->
JITKernel
[
_KP
,
_T
]
:
"""
"""
Compile the given TileLang PrimFunc with TVM and build a JITKernel.
Compile the given TileLang PrimFunc with TVM and build a JITKernel.
Parameters
Parameters
...
@@ -79,159 +95,208 @@ def compile(
...
@@ -79,159 +95,208 @@ def compile(
)
)
class
_JitImplementation
:
def
par_compile
(
funcs
:
Iterable
[
PrimFunc
[
_KP
,
_T
]],
out_idx
:
list
[
int
]
|
int
|
None
=
None
,
execution_backend
:
Literal
[
"dlpack"
,
"ctypes"
,
"cython"
,
"nvrtc"
]
=
"cython"
,
target
:
str
|
Target
=
"auto"
,
target_host
:
str
|
Target
|
None
=
None
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
str
|
None
=
None
,
num_workers
:
int
=
None
,
ignore_error
:
bool
=
False
)
->
list
[
JITKernel
[
_KP
,
_T
]]:
"""
Compile multiple TileLang PrimFuncs in parallel with TVM and build a JITKernel for each.
Parameters
----------
funcs : Iterable[tvm.tir.PrimFunc]
The TileLang TIR functions to compile and wrap.
out_idx : Union[List[int], int], optional
Index(es) of the output tensors to return (default: None).
execution_backend : Literal["dlpack", "ctypes", "cython", "nvrtc"], optional
Execution backend to use for kernel execution (default: "cython").
target : Union[str, Target], optional
Compilation target, either as a string or a TVM Target object (default: "auto").
target_host : Union[str, Target], optional
Target host for cross-compilation (default: None).
verbose : bool, optional
Whether to enable verbose output (default: False).
pass_configs : dict, optional
Additional keyword arguments to pass to the Compiler PassContext.
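Refer to `tilelang.transform.PassConfigKey` for supported options.
compile_flags : Union[List[str], str], optional
Additional compilation flags forwarded to `compile` for each function (default: None).
num_workers : int, optional
Maximum number of worker threads passed to the `ThreadPoolExecutor`;
None leaves the choice to the executor (default: None).
ignore_error : bool, optional
If True, a function that fails to compile is logged as a warning and its
entry in the result is set to None; if False, the compilation error
propagates to the caller (default: False).
Returns
-------
list[JITKernel]
One entry per input function, in the same order as `funcs`.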
"""
with
concurrent
.
futures
.
ThreadPoolExecutor
(
num_workers
,
'tl-par-comp'
)
as
executor
:
futures
=
[]
future_map
=
{}
for
i
,
func
in
enumerate
(
funcs
):
future
=
executor
.
submit
(
compile
,
func
=
func
,
out_idx
=
out_idx
,
execution_backend
=
execution_backend
,
target
=
target
,
target_host
=
target_host
,
verbose
=
verbose
,
pass_configs
=
pass_configs
,
compile_flags
=
compile_flags
,
)
future_map
[
future
]
=
i
futures
.
append
(
future
)
results
=
[...
for
_
in
futures
]
for
future
in
tqdm
(
concurrent
.
futures
.
as_completed
(
futures
),
total
=
len
(
futures
),
desc
=
"Parallel Compiling"
,
):
idx
=
future_map
[
future
]
if
ignore_error
:
try
:
results
[
idx
]
=
future
.
result
()
except
Exception
as
e
:
logger
.
warning
(
f
"Error compiling function at index
{
idx
}
:
{
e
}
"
)
results
[
idx
]
=
None
else
:
results
[
idx
]
=
future
.
result
()
return
results
return
results
@
dataclass
class
JITImpl
(
Generic
[
_P
,
_KP
,
_T
]):
func
:
Callable
[
_P
,
_T
]
|
PrimFunc
[
_KP
,
_T
]
out_idx
:
list
[
int
]
|
int
|
None
out_idx
:
list
[
int
]
|
int
|
None
execution_backend
:
Literal
[
"dlpack"
,
"ctypes"
,
"cython"
]
target
:
str
|
Target
target
:
str
|
Target
target_host
:
str
|
Target
target_host
:
str
|
Target
execution_backend
:
Literal
[
"dlpack"
,
"ctypes"
,
"cython"
]
verbose
:
bool
verbose
:
bool
pass_configs
:
dict
[
str
,
Any
]
|
None
pass_configs
:
dict
[
str
,
Any
]
|
None
debug_root_path
:
str
|
None
debug_root_path
:
str
|
None
compile_flags
:
list
[
str
]
|
str
|
None
compile_flags
:
list
[
str
]
|
str
|
None
func_source
:
str
signature
:
inspect
.
Signature
def
__init__
(
self
,
def
__post_init__
(
self
):
out_idx
:
Any
=
None
,
target
:
str
|
Target
=
"auto"
,
target_host
:
str
|
Target
=
None
,
execution_backend
:
Literal
[
"dlpack"
,
"ctypes"
,
"cython"
]
=
"cython"
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
debug_root_path
:
str
|
None
=
None
,
compile_flags
:
list
[
str
]
|
str
|
None
=
None
):
"""
Initializes the JIT compiler decorator.
Parameters
----------
out_idx : Any, optional
Index(es) of the output tensors to return from the compiled kernel
(default: None, meaning all outputs are returned or determined by the kernel itself).
target : Union[str, Target], optional
Compilation target for TVM. Can be a string (e.g., "cuda", "llvm")
or a TVM Target object. If "auto", the target is determined automatically
(default: "auto").
target_host : Union[str, Target], optional
Target host for cross-compilation, similar to `target` (default: None).
execution_backend : Literal["dlpack", "ctypes", "cython"], optional
The backend used for kernel execution and argument passing.
"dlpack" is generally preferred for zero-copy tensor passing with compatible frameworks.
"ctypes" uses standard C types. "cython" uses Cython for potentially faster execution.
(default: "cython").
verbose : bool, optional
If True, enables verbose logging during compilation (default: False).
pass_configs : Optional[Dict[str, Any]], optional
A dictionary of configurations for TVM's pass context. These can fine-tune
the compilation process. Examples include "tir.disable_vectorize"
(default: None).
debug_root_path : Optional[str], optional
If provided, the compiled kernel's source code will be saved to a file
in this directory. This is useful for debugging the generated code.
If None, no debug information is saved (default: None).
If a relative path is given, it's made absolute relative to the project root
or current working directory.
compile_flags : Optional[Union[List[str], str]], optional
Additional compilation flags to pass to the compiler.
If None, no additional compilation flags are passed (default: None).
"""
self
.
out_idx
=
out_idx
self
.
execution_backend
=
execution_backend
self
.
target
=
target
self
.
target_host
=
target_host
self
.
verbose
=
verbose
self
.
pass_configs
=
pass_configs
self
.
compile_flags
=
compile_flags
# Corrected debug_root_path handling
self
.
debug_root_path
=
debug_root_path
if
self
.
debug_root_path
is
not
None
and
not
path
.
isabs
(
self
.
debug_root_path
):
if
self
.
debug_root_path
is
not
None
and
not
path
.
isabs
(
self
.
debug_root_path
):
try
:
try
:
base_path
=
path
.
dirname
(
path
.
dirname
(
path
.
dirname
(
__file__
)))
base_path
=
path
.
dirname
(
path
.
dirname
(
path
.
dirname
(
__file__
)))
self
.
debug_root_path
=
path
.
join
(
base_path
,
self
.
debug_root_path
)
self
.
debug_root_path
=
path
.
join
(
base_path
,
self
.
debug_root_path
)
except
NameError
:
except
NameError
:
self
.
debug_root_path
=
path
.
abspath
(
self
.
debug_root_path
)
self
.
debug_root_path
=
path
.
abspath
(
self
.
debug_root_path
)
self
.
_kernel_cache
:
dict
[
tuple
,
Kernel
]
=
{}
self
.
_kernel_cache
:
dict
[
tuple
,
Kernel
]
=
{}
# This tells the type checker what the *wrapper* function will return.
def
get_tir
(
self
,
*
args
:
_P
.
args
,
**
kwargs
:
_P
.
kwargs
)
->
PrimFunc
[
_KP
,
_T
]:
# this is for linting, please do not remove it.
program_result_source
=
self
.
func
@
overload
if
isinstance
(
program_result_source
,
PrimFunc
):
def
__call__
(
self
,
func
:
Callable
[
_P
,
_RProg
])
->
Callable
[
_P
,
tuple
[
_RProg
,
Kernel
]]:
program_result
=
program_result_source
...
elif
callable
(
program_result_source
):
program_result
=
program_result_source
(
*
args
,
**
kwargs
)
@
overload
else
:
def
__call__
(
self
,
func
:
Callable
[
_P
,
_RProg
])
->
Callable
[
_P
,
Kernel
]:
raise
ValueError
(
f
"Invalid function type:
{
type
(
program_result_source
)
}
"
)
...
return
program_result
# Actual implementation of __call__
def
par_compile
(
self
,
def
__call__
(
configs
:
Iterable
[
dict
[
str
,
Any
]
|
tuple
[
str
,
Any
]],
self
,
num_workers
:
int
=
None
,
func
:
Callable
[
_P
,
_RProg
]
# func is Union[Callable[_P, _RProg], PrimFunc] in original
ignore_error
:
bool
=
False
)
->
list
[
JITKernel
[
_KP
,
_T
]]:
)
->
Callable
[
_P
,
Any
]:
configs
=
list
(
configs
)
funcs
=
[]
@
functools
.
wraps
(
func
)
for
cfg
in
tqdm
(
configs
,
desc
=
'Elaborating'
):
def
wrapper
(
*
args
:
_P
.
args
,
**
kwargs
:
_P
.
kwargs
)
->
Any
:
if
isinstance
(
cfg
,
tuple
):
# Separate out the tuning parameters from the user's kwargs
funcs
.
append
(
self
.
get_tir
(
*
cfg
))
tune_params
=
kwargs
.
pop
(
'__tune_params'
,
{})
elif
isinstance
(
cfg
,
dict
):
# Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache
funcs
.
append
(
self
.
get_tir
(
**
cfg
))
return_compile_arguments
=
kwargs
.
pop
(
'__return_compile_arguments'
,
False
)
else
:
if
return_compile_arguments
:
raise
ValueError
(
f
"Invalid config type:
{
type
(
cfg
)
}
, expected tuple or dict."
)
compile_args
=
{
return
par_compile
(
'out_idx'
:
self
.
out_idx
,
funcs
,
'execution_backend'
:
self
.
execution_backend
,
out_idx
=
self
.
out_idx
,
'target'
:
self
.
target
,
execution_backend
=
self
.
execution_backend
,
'target_host'
:
self
.
target_host
,
target
=
self
.
target
,
'verbose'
:
self
.
verbose
,
target_host
=
self
.
target_host
,
'pass_configs'
:
self
.
pass_configs
,
verbose
=
self
.
verbose
,
'compile_flags'
:
self
.
compile_flags
,
pass_configs
=
self
.
pass_configs
,
}
compile_flags
=
self
.
compile_flags
,
return
compile_args
num_workers
=
num_workers
,
ignore_error
=
ignore_error
)
key_args_tuple
=
args
key_kwargs_tuple
=
tuple
(
sorted
(
kwargs
.
items
()))
def
compile
(
self
,
*
args
:
_P
.
args
,
**
kwargs
:
_P
.
kwargs
)
->
JITKernel
[
_KP
,
_T
]:
tuned_key_kwargs_tuple
=
tuple
(
sorted
(
tune_params
.
items
()))
func
=
self
.
get_tir
(
*
args
,
**
kwargs
)
key
=
(
key_args_tuple
,
key_kwargs_tuple
,
tuned_key_kwargs_tuple
)
kernel_result
=
compile
(
func
,
if
key
not
in
self
.
_kernel_cache
:
out_idx
=
self
.
out_idx
,
# Ensure 'func' (the original user function) is used correctly
execution_backend
=
self
.
execution_backend
,
program_result_source
=
func
target
=
self
.
target
,
if
isinstance
(
program_result_source
,
PrimFunc
):
target_host
=
self
.
target_host
,
program_result
=
program_result_source
verbose
=
self
.
verbose
,
elif
callable
(
program_result_source
):
pass_configs
=
self
.
pass_configs
,
program_result
=
program_result_source
(
*
args
,
**
kwargs
,
**
tune_params
)
compile_flags
=
self
.
compile_flags
,
else
:
)
raise
ValueError
(
f
"Invalid function type:
{
type
(
program_result_source
)
}
"
)
if
self
.
debug_root_path
:
kernel_result
=
compile
(
if
isinstance
(
self
.
func
,
PrimFunc
):
program_result
,
func_name
=
self
.
func
.
attrs
[
'global_symbol'
]
out_idx
=
self
.
out_idx
,
else
:
execution_backend
=
self
.
execution_backend
,
func_name
=
getattr
(
self
.
func
,
'__name__'
,
'jit_kernel'
)
target
=
self
.
target
,
kernel_file
=
f
'tilelang_jit_kernel_
{
func_name
}
.c'
target_host
=
self
.
target_host
,
program_file
=
f
'tilelang_jit_program_
{
func_name
}
.py'
verbose
=
self
.
verbose
,
makedirs
(
self
.
debug_root_path
,
exist_ok
=
True
)
pass_configs
=
self
.
pass_configs
,
with
open
(
path
.
join
(
self
.
debug_root_path
,
kernel_file
),
'w'
)
as
f
:
compile_flags
=
self
.
compile_flags
,
print
(
kernel_result
.
get_kernel_source
(),
file
=
f
)
)
with
open
(
path
.
join
(
self
.
debug_root_path
,
program_file
),
'w'
)
as
f
:
print
(
func
.
script
(),
file
=
f
)
if
self
.
debug_root_path
:
func_name
=
getattr
(
func
,
'__name__'
,
'jit_kernel'
)
# Use func for name
return
kernel_result
kernel_file
=
f
'tilelang_jit_kernel_
{
func_name
}
.c'
program_file
=
f
'tilelang_jit_program_
{
func_name
}
.py'
def
__call__
(
self
,
*
args
:
_P
.
args
,
**
kwargs
:
_P
.
kwargs
)
->
JITKernel
[
_KP
,
_T
]:
makedirs
(
self
.
debug_root_path
,
exist_ok
=
True
)
# Separate out the tuning parameters from the user's kwargs
with
open
(
path
.
join
(
self
.
debug_root_path
,
kernel_file
),
'w'
)
as
f
:
tune_params
=
kwargs
.
pop
(
'__tune_params'
,
{})
print
(
kernel_result
.
get_kernel_source
(),
file
=
f
)
# Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache
with
open
(
path
.
join
(
self
.
debug_root_path
,
program_file
),
'w'
)
as
f
:
return_compile_arguments
=
kwargs
.
pop
(
'__return_compile_arguments'
,
False
)
print
(
program_result
.
script
(),
file
=
f
)
if
return_compile_arguments
:
compile_args
=
{
self
.
_kernel_cache
[
key
]
=
kernel_result
'out_idx'
:
self
.
out_idx
,
'execution_backend'
:
self
.
execution_backend
,
return
self
.
_kernel_cache
[
key
]
'target'
:
self
.
target
,
'target_host'
:
self
.
target_host
,
return
wrapper
'verbose'
:
self
.
verbose
,
'pass_configs'
:
self
.
pass_configs
,
'compile_flags'
:
self
.
compile_flags
,
}
return
compile_args
key_args_tuple
=
args
key_kwargs_tuple
=
tuple
(
sorted
(
kwargs
.
items
()))
tuned_key_kwargs_tuple
=
tuple
(
sorted
(
tune_params
.
items
()))
key
=
(
key_args_tuple
,
key_kwargs_tuple
,
tuned_key_kwargs_tuple
)
if
key
not
in
self
.
_kernel_cache
:
self
.
_kernel_cache
[
key
]
=
self
.
compile
(
*
args
,
**
kwargs
,
**
tune_params
)
return
self
.
_kernel_cache
[
key
]
@
overload
def
jit
(
func
:
Callable
[
_P
,
PrimFunc
[
_KP
,
_T
]])
->
JITImpl
[
_P
,
_KP
,
_T
]:
...
@
overload
def
jit
(
*
,
# Indicates subsequent arguments are keyword-only
out_idx
:
Any
=
None
,
target
:
str
|
Target
=
"auto"
,
target_host
:
str
|
Target
=
None
,
execution_backend
:
Literal
[
"dlpack"
,
"ctypes"
,
"cython"
,
"nvrtc"
]
=
"cython"
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
debug_root_path
:
str
|
None
=
None
,
compile_flags
:
list
[
str
]
|
str
|
None
=
None
)
->
Callable
[[
Callable
[
_P
,
PrimFunc
[
_KP
,
_T
]]],
JITImpl
[
_P
,
_KP
,
_T
]]:
...
def
jit
(
# This is the new public interface
def
jit
(
# This is the new public interface
func
:
Callable
[
_P
,
_
RProg
]
|
PrimFunc
|
None
=
None
,
func
:
Callable
[
_P
,
_
T
]
|
PrimFunc
|
None
=
None
,
*
,
# Indicates subsequent arguments are keyword-only
*
,
# Indicates subsequent arguments are keyword-only
out_idx
:
Any
=
None
,
out_idx
:
Any
=
None
,
target
:
str
|
Target
=
"auto"
,
target
:
str
|
Target
=
"auto"
,
...
@@ -275,32 +340,26 @@ def jit( # This is the new public interface
...
@@ -275,32 +340,26 @@ def jit( # This is the new public interface
if
isinstance
(
compile_flags
,
str
):
if
isinstance
(
compile_flags
,
str
):
compile_flags
=
[
compile_flags
]
compile_flags
=
[
compile_flags
]
if
callable
(
func
):
def
decorator
(
func
:
Callable
[
_P
,
_T
])
->
JITImpl
[
_P
,
_T
]:
# Case 1: Used as @jit (func_or_out_idx is the function, others are defaults)
if
isinstance
(
func
,
PrimFunc
):
# Create a default _JitImplementation instance and apply it to the function.
orig_func
=
func
.
orig_func
default_decorator
=
_JitImplementation
(
else
:
out_idx
=
out_idx
,
# Explicitly None for the default case
orig_func
=
func
target
=
target
,
return
JITImpl
(
target_host
=
target_host
,
func
,
out_idx
=
out_idx
,
execution_backend
=
execution_backend
,
execution_backend
=
execution_backend
,
verbose
=
verbose
,
pass_configs
=
pass_configs
,
debug_root_path
=
debug_root_path
,
compile_flags
=
compile_flags
)
return
default_decorator
(
func
)
elif
isinstance
(
func
,
PrimFunc
):
raise
ValueError
(
"Use tilelang.jit to decorate prim_func is not supported yet."
)
else
:
# Case 2: Used as @jit(...) to configure, or func_or_out_idx is meant as out_idx.
# Create a _JitImplementation instance with the provided/defaulted arguments.
# This instance is a decorator that will be applied to the function later.
configured_decorator
=
_JitImplementation
(
out_idx
=
out_idx
,
# Pass along; could be an actual out_idx or None
target
=
target
,
target
=
target
,
target_host
=
target_host
,
target_host
=
target_host
,
execution_backend
=
execution_backend
,
verbose
=
verbose
,
verbose
=
verbose
,
pass_configs
=
pass_configs
,
pass_configs
=
pass_configs
,
debug_root_path
=
debug_root_path
,
debug_root_path
=
debug_root_path
,
compile_flags
=
compile_flags
)
compile_flags
=
compile_flags
,
return
configured_decorator
func_source
=
inspect
.
getsource
(
orig_func
),
signature
=
inspect
.
signature
(
orig_func
),
)
if
func
is
not
None
:
return
decorator
(
func
)
else
:
return
decorator
tilelang/jit/adapter/torch/metal.py
View file @
bbbf4207
...
@@ -27,7 +27,11 @@ class MetalKernelAdapter(BaseKernelAdapter):
...
@@ -27,7 +27,11 @@ class MetalKernelAdapter(BaseKernelAdapter):
# compile_flags: Optional[List[str]] = None
# compile_flags: Optional[List[str]] = None
):
):
self
.
kernel_global_source
=
kernel_global_source
self
.
kernel_global_source
=
kernel_global_source
self
.
kernel_name
=
func_or_mod
.
__name__
+
'_kernel'
if
isinstance
(
func_or_mod
,
tir
.
PrimFunc
):
func_name
=
func_or_mod
.
attrs
[
'global_symbol'
]
else
:
func_name
=
func_or_mod
.
__name__
self
.
kernel_name
=
func_name
+
'_kernel'
self
.
verbose
=
verbose
self
.
verbose
=
verbose
self
.
block_info
=
[
1
,
1
,
1
]
self
.
block_info
=
[
1
,
1
,
1
]
...
@@ -43,7 +47,7 @@ class MetalKernelAdapter(BaseKernelAdapter):
...
@@ -43,7 +47,7 @@ class MetalKernelAdapter(BaseKernelAdapter):
self
.
grid_info
[
"xyz"
.
index
(
tag
[
-
1
])]
=
extent
self
.
grid_info
[
"xyz"
.
index
(
tag
[
-
1
])]
=
extent
break
break
else
:
else
:
raise
AssertionError
(
f
'no kernel with name
{
func_
or_mod
.
__
name
__
}
'
)
raise
AssertionError
(
f
'no kernel with name
{
func_name
}
'
)
# print(self.block_info, self.grid_info)
# print(self.block_info, self.grid_info)
super
().
__init__
(
func_or_mod
,
result_idx
=
result_idx
,
params
=
params
)
super
().
__init__
(
func_or_mod
,
result_idx
=
result_idx
,
params
=
params
)
...
...
tilelang/jit/adapter/wrapper.py
View file @
bbbf4207
...
@@ -257,6 +257,12 @@ class TLCUDASourceWrapper:
...
@@ -257,6 +257,12 @@ class TLCUDASourceWrapper:
def
_pythonic_expr
(
self
,
expr
:
tvm
.
tir
.
PrimExpr
)
->
str
:
def
_pythonic_expr
(
self
,
expr
:
tvm
.
tir
.
PrimExpr
)
->
str
:
return
pythonic_expr
(
expr
,
self
.
_TYPE_MAP
)
return
pythonic_expr
(
expr
,
self
.
_TYPE_MAP
)
def
_lookup_type
(
self
,
dtype
:
str
|
Any
)
->
str
:
key
=
dtype
if
isinstance
(
dtype
,
str
)
else
str
(
dtype
)
result
=
self
.
_TYPE_MAP
.
get
(
key
)
assert
result
is
not
None
,
f
"Unsupported dtype
{
dtype
}
"
return
result
def
is_tma_descriptor_arg
(
self
,
arg_name
:
str
)
->
bool
:
def
is_tma_descriptor_arg
(
self
,
arg_name
:
str
)
->
bool
:
return
arg_name
in
self
.
prim_func
.
buffer_map
return
arg_name
in
self
.
prim_func
.
buffer_map
...
@@ -274,10 +280,10 @@ class TLCUDASourceWrapper:
...
@@ -274,10 +280,10 @@ class TLCUDASourceWrapper:
buffer
=
self
.
prim_func
.
buffer_map
[
param
]
buffer
=
self
.
prim_func
.
buffer_map
[
param
]
function_args
.
append
({
function_args
.
append
({
"name"
:
buffer
.
data
.
name
,
"name"
:
buffer
.
data
.
name
,
"type"
:
self
.
_
TYPE_MAP
[
buffer
.
dtype
]
+
"* __restrict__"
,
"type"
:
self
.
_
lookup_type
(
buffer
.
dtype
)
+
"* __restrict__"
,
})
})
elif
isinstance
(
param
,
tvm
.
tir
.
Var
):
elif
isinstance
(
param
,
tvm
.
tir
.
Var
):
function_args
.
append
({
"name"
:
param
.
name
,
"type"
:
self
.
_
TYPE_MAP
[
param
.
dtype
]
})
function_args
.
append
({
"name"
:
param
.
name
,
"type"
:
self
.
_
lookup_type
(
param
.
dtype
)
})
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Parameter
{
param
}
is not in the buffer map of the primary function."
)
f
"Parameter
{
param
}
is not in the buffer map of the primary function."
)
...
@@ -717,6 +723,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
...
@@ -717,6 +723,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
"float16"
:
"ctypes.c_uint16"
,
"float16"
:
"ctypes.c_uint16"
,
"bfloat16"
:
"ctypes.c_uint16"
,
"bfloat16"
:
"ctypes.c_uint16"
,
"float8_e4m3"
:
"ctypes.c_uint8"
,
"float8_e4m3"
:
"ctypes.c_uint8"
,
"float8_e4m3fn"
:
"ctypes.c_uint8"
,
"float8_e5m2"
:
"ctypes.c_uint8"
,
"float8_e5m2"
:
"ctypes.c_uint8"
,
"float64"
:
"ctypes.c_double"
,
"float64"
:
"ctypes.c_double"
,
"int64"
:
"ctypes.c_int64"
,
"int64"
:
"ctypes.c_int64"
,
...
@@ -753,7 +760,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
...
@@ -753,7 +760,7 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
"type"
:
"ctypes.c_void_p"
,
"type"
:
"ctypes.c_void_p"
,
})
})
elif
isinstance
(
param
,
tvm
.
tir
.
Var
):
elif
isinstance
(
param
,
tvm
.
tir
.
Var
):
function_args
.
append
({
"name"
:
param
.
name
,
"type"
:
self
.
_
TYPE_MAP
[
param
.
dtype
]
})
function_args
.
append
({
"name"
:
param
.
name
,
"type"
:
self
.
_
lookup_type
(
param
.
dtype
)
})
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Parameter
{
param
}
is not in the buffer map of the primary function."
)
f
"Parameter
{
param
}
is not in the buffer map of the primary function."
)
...
@@ -923,6 +930,7 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
...
@@ -923,6 +930,7 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
"float16"
:
"half_t"
,
"float16"
:
"half_t"
,
"bfloat16"
:
"bfloat16_t"
,
"bfloat16"
:
"bfloat16_t"
,
"float8_e4m3"
:
"fp8_e4_t"
,
"float8_e4m3"
:
"fp8_e4_t"
,
"float8_e4m3fn"
:
"fp8_e4_t"
,
"float8_e5m2"
:
"fp8_e5_t"
,
"float8_e5m2"
:
"fp8_e5_t"
,
"float8_e4m3fnuz"
:
"fp8_e4_t"
,
"float8_e4m3fnuz"
:
"fp8_e4_t"
,
"e4m3fnuz_float8"
:
"fp8_e4_t"
,
"e4m3fnuz_float8"
:
"fp8_e4_t"
,
...
@@ -969,16 +977,19 @@ class TLCPUSourceWrapper:
...
@@ -969,16 +977,19 @@ class TLCPUSourceWrapper:
"float32"
:
"float"
,
"float32"
:
"float"
,
"float16"
:
"half"
,
"float16"
:
"half"
,
"int32"
:
"int32_t"
,
"int32"
:
"int32_t"
,
"int8"
:
"int8_t"
,
"uint8"
:
"uint8_t"
,
"int16"
:
"int16_t"
,
"uint16"
:
"uint16_t"
,
"int64"
:
"int64_t"
,
"uint64"
:
"uint64_t"
,
"float64"
:
"double"
,
"bool"
:
"bool"
,
"uchar"
:
"uchar"
,
}
}
INIT_FUNC
=
textwrap
.
dedent
(
'''
# Use common init with error buffer and get_last_error for CPU backend as well
#ifdef __cplusplus
INIT_FUNC
=
PREDEF_INIT_FUNC
.
format
(
""
)
extern "C"
#endif
int32_t init() {
return 0;
}
'''
)
CALL_PREFIX
=
textwrap
.
dedent
(
"""
CALL_PREFIX
=
textwrap
.
dedent
(
"""
#ifdef __cplusplus
#ifdef __cplusplus
...
@@ -1014,6 +1025,12 @@ class TLCPUSourceWrapper:
...
@@ -1014,6 +1025,12 @@ class TLCPUSourceWrapper:
self
.
libpath
:
str
|
None
=
None
self
.
libpath
:
str
|
None
=
None
self
.
lib_code
:
str
|
None
=
self
.
update_lib_code
(
source
)
self
.
lib_code
:
str
|
None
=
self
.
update_lib_code
(
source
)
def
_lookup_type
(
self
,
dtype
:
str
|
Any
)
->
str
:
key
=
dtype
if
isinstance
(
dtype
,
str
)
else
str
(
dtype
)
result
=
self
.
_TYPE_MAP
.
get
(
key
)
assert
result
is
not
None
,
f
"Unsupported dtype
{
dtype
}
"
return
result
def
create_call_func
(
self
,
code
,
function_informations
):
def
create_call_func
(
self
,
code
,
function_informations
):
# Extract the set of dynamic symbolic names used in the primary function
# Extract the set of dynamic symbolic names used in the primary function
dynamic_symbolic_set
=
self
.
get_dynamic_symbolic_set
(
self
.
prim_func
)
dynamic_symbolic_set
=
self
.
get_dynamic_symbolic_set
(
self
.
prim_func
)
...
@@ -1025,10 +1042,10 @@ class TLCPUSourceWrapper:
...
@@ -1025,10 +1042,10 @@ class TLCPUSourceWrapper:
buffer
=
self
.
prim_func
.
buffer_map
[
param
]
buffer
=
self
.
prim_func
.
buffer_map
[
param
]
function_args
.
append
({
function_args
.
append
({
"name"
:
buffer
.
name
,
"name"
:
buffer
.
name
,
"type"
:
self
.
_
TYPE_MAP
[
buffer
.
dtype
]
+
"*"
,
"type"
:
self
.
_
lookup_type
(
buffer
.
dtype
)
+
"*"
,
})
})
elif
isinstance
(
param
,
tvm
.
tir
.
Var
):
elif
isinstance
(
param
,
tvm
.
tir
.
Var
):
function_args
.
append
({
"name"
:
param
.
name
,
"type"
:
self
.
_
TYPE_MAP
[
param
.
dtype
]
})
function_args
.
append
({
"name"
:
param
.
name
,
"type"
:
self
.
_
lookup_type
(
param
.
dtype
)
})
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Parameter
{
param
}
is not in the buffer map of the primary function."
)
f
"Parameter
{
param
}
is not in the buffer map of the primary function."
)
...
@@ -1093,8 +1110,8 @@ class TLCPUSourceWrapper:
...
@@ -1093,8 +1110,8 @@ class TLCPUSourceWrapper:
return
dynamic_symbolic_set
return
dynamic_symbolic_set
def
get_cpu_init_func
(
self
):
def
get_cpu_init_func
(
self
):
init_funcs
=
self
.
INIT_FUNC
# Provide init() and get_last_error() for CPU backend
return
init_funcs
return
self
.
INIT_FUNC
def
update_lib_code
(
self
,
code
:
str
):
def
update_lib_code
(
self
,
code
:
str
):
# Update the library code with the given code string
# Update the library code with the given code string
...
...
tilelang/jit/kernel.py
View file @
bbbf4207
from
__future__
import
annotations
from
__future__
import
annotations
from
typing
import
Any
,
Callable
,
Literal
from
typing
import
Any
,
Callable
,
Generic
,
Literal
,
TypeVar
# Python 3.9 compatibility for ParamSpec
from
tilelang.jit.adapter.utils
import
is_metal_target
try
:
from
typing
import
ParamSpec
except
ImportError
:
# Python < 3.10
from
typing_extensions
import
ParamSpec
from
tilelang.jit.adapter.utils
import
is_metal_target
,
is_cuda_target
from
tvm.target
import
Target
from
tvm.target
import
Target
from
tvm.tir
import
PrimFunc
from
tvm.tir
import
PrimFunc
...
@@ -13,12 +18,17 @@ from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, Cython
...
@@ -13,12 +18,17 @@ from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, Cython
NVRTCKernelAdapter
,
TorchDLPackKernelAdapter
,
MetalKernelAdapter
)
NVRTCKernelAdapter
,
TorchDLPackKernelAdapter
,
MetalKernelAdapter
)
from
tilelang.profiler
import
Profiler
,
TensorSupplyType
from
tilelang.profiler
import
Profiler
,
TensorSupplyType
from
tilelang.utils.target
import
determine_target
from
tilelang.utils.target
import
determine_target
from
tilelang.contrib
import
nvcc
as
tl_nvcc
import
logging
import
logging
import
os
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
_P
=
ParamSpec
(
'_P'
)
_T
=
TypeVar
(
'_T'
)
class
JITKernel
:
class
JITKernel
(
Generic
[
_P
,
_T
])
:
"""
"""
A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.
A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.
...
@@ -170,7 +180,7 @@ class JITKernel:
...
@@ -170,7 +180,7 @@ class JITKernel:
instance
.
torch_function
=
instance
.
adapter
.
func
instance
.
torch_function
=
instance
.
adapter
.
func
return
instance
return
instance
def
__call__
(
self
,
*
args
:
Any
,
**
kwds
:
Any
)
->
Any
:
def
__call__
(
self
,
*
args
:
_P
.
args
,
**
kwds
:
_P
.
kwargs
)
->
_T
:
"""
"""
Invokes the compiled function with the given arguments.
Invokes the compiled function with the given arguments.
...
@@ -404,6 +414,110 @@ class JITKernel:
...
@@ -404,6 +414,110 @@ class JITKernel:
def
run_once
(
self
,
func
:
Callable
|
None
=
None
)
->
None
:
def
run_once
(
self
,
func
:
Callable
|
None
=
None
)
->
None
:
return
self
.
get_profiler
().
run_once
(
func
)
return
self
.
get_profiler
().
run_once
(
func
)
def
show_source
(
self
,
which
:
Literal
[
"kernel"
,
"host"
,
"both"
]
=
"kernel"
)
->
None
:
"""
Print generated source code to stdout.
Parameters
----------
which : Literal["kernel", "host", "both"], optional
Select which source to print. Defaults to "kernel".
Examples
--------
>>> jit_kernel.show_source() # print kernel source
>>> jit_kernel.show_source("host") # print host source
>>> jit_kernel.show_source("both") # print both sources
"""
try
:
if
which
==
"kernel"
:
src
=
self
.
get_kernel_source
()
print
(
src
)
elif
which
==
"host"
:
src
=
self
.
get_host_source
()
# Host is generally C/C++
print
(
src
)
elif
which
==
"both"
:
print
(
"===== Kernel Source ====="
)
ksrc
=
self
.
get_kernel_source
()
print
(
ksrc
)
print
(
"===== Host Source ====="
)
hsrc
=
self
.
get_host_source
()
print
(
hsrc
)
else
:
raise
ValueError
(
f
"Unknown option for 'which':
{
which
}
"
)
except
Exception
as
e
:
logger
.
error
(
f
"Failed to show source code:
{
e
}
"
)
def
export_sources
(
self
,
kernel_path
:
str
|
None
=
None
,
host_path
:
str
|
None
=
None
)
->
None
:
"""
Export generated source code to files.
Parameters
----------
kernel_path : Optional[str]
Destination file path to write the kernel source. If None, skips writing kernel code.
host_path : Optional[str]
Destination file path to write the host source. If None, skips writing host code.
Examples
--------
>>> jit_kernel.export_sources(kernel_path="/tmp/kernel.cu")
>>> jit_kernel.export_sources(host_path="/tmp/host.cc")
>>> jit_kernel.export_sources(
... kernel_path="/tmp/kernel.cu",
... host_path="/tmp/host.cc",
... )
"""
if
kernel_path
is
None
and
host_path
is
None
:
raise
ValueError
(
"At least one of kernel_path or host_path must be provided."
)
try
:
if
kernel_path
is
not
None
:
dir_path
=
os
.
path
.
dirname
(
kernel_path
)
if
dir_path
:
os
.
makedirs
(
dir_path
,
exist_ok
=
True
)
with
open
(
kernel_path
,
'w'
)
as
f
:
f
.
write
(
self
.
get_kernel_source
())
if
host_path
is
not
None
:
dir_path
=
os
.
path
.
dirname
(
host_path
)
if
dir_path
:
os
.
makedirs
(
dir_path
,
exist_ok
=
True
)
with
open
(
host_path
,
'w'
)
as
f
:
f
.
write
(
self
.
get_host_source
())
except
Exception
as
e
:
logger
.
error
(
f
"Failed to export sources:
{
e
}
"
)
# Backward compatibility alias (deprecated)
def
print_source_code
(
self
,
which
:
Literal
[
"kernel"
,
"host"
,
"both"
]
=
"kernel"
,
file
:
str
|
None
=
None
)
->
None
:
"""
Deprecated: use show_source() or export_sources() instead.
Parameters
----------
which : Literal["kernel", "host", "both"], optional
Kept for backward compatibility with printing behavior.
file : Optional[str]
If provided, behaves like export_sources(kernel_path=file).
Examples
--------
>>> # New API (preferred)
>>> jit_kernel.show_source("both")
>>> jit_kernel.export_sources(kernel_path="/tmp/kernel.cu")
>>> # Old API (still works but deprecated)
>>> jit_kernel.print_source_code(file="/tmp/kernel.cu")
"""
logger
.
warning
(
"print_source_code is deprecated; use show_source() or export_sources() instead."
)
if
file
is
not
None
:
# Historical behavior wrote only kernel source when file provided
self
.
export_sources
(
kernel_path
=
file
)
else
:
self
.
show_source
(
which
=
which
)
def
update_tuner_result
(
self
,
latency
:
float
,
config
:
dict
[
str
,
Any
],
def
update_tuner_result
(
self
,
latency
:
float
,
config
:
dict
[
str
,
Any
],
ref_latency
:
float
)
->
JITKernel
:
ref_latency
:
float
)
->
JITKernel
:
"""
"""
...
@@ -483,3 +597,131 @@ class JITKernel:
...
@@ -483,3 +597,131 @@ class JITKernel:
# Export the compiled kernel function to a shared library file.
# Export the compiled kernel function to a shared library file.
self
.
rt_module
.
export_library
(
kernel_file
)
self
.
rt_module
.
export_library
(
kernel_file
)
def
_get_ptx
(
self
,
verbose
:
bool
|
None
=
None
)
->
str
:
"""
Compile and return PTX for the current kernel (CUDA only).
Parameters
----------
verbose : Optional[bool]
Whether to enable verbose NVRTC logs. Defaults to self.verbose.
Returns
-------
str
The compiled PTX text.
"""
if
not
is_cuda_target
(
self
.
target
):
raise
ValueError
(
"PTX is only available for CUDA targets."
)
# Prefer NVCC for PTX generation via contrib helper
code
=
self
.
get_kernel_source
()
if
verbose
is
None
:
verbose
=
self
.
verbose
# Ensure target is set so nvcc picks correct arch via Target.current()
with
self
.
target
:
return
tl_nvcc
.
get_ptx_from_source
(
code
,
compile_flags
=
self
.
compile_flags
,
verbose
=
verbose
)
def
show_ptx
(
self
)
->
None
:
"""
Print compiled PTX for the kernel (CUDA only).
Examples
--------
>>> jit_kernel.show_ptx()
"""
try
:
ptx
=
self
.
_get_ptx
()
print
(
ptx
)
except
Exception
as
e
:
logger
.
error
(
f
"Failed to show PTX:
{
e
}
"
)
def
export_ptx
(
self
,
path
:
str
)
->
None
:
"""
Export compiled PTX to a file (CUDA only).
Parameters
----------
path : str
Destination file path to write PTX.
Examples
--------
>>> jit_kernel.export_ptx("/tmp/kernel.ptx")
"""
if
not
path
:
raise
ValueError
(
"path must be provided to export PTX"
)
try
:
ptx
=
self
.
_get_ptx
()
dir_path
=
os
.
path
.
dirname
(
path
)
if
dir_path
:
os
.
makedirs
(
dir_path
,
exist_ok
=
True
)
with
open
(
path
,
"w"
)
as
f
:
f
.
write
(
ptx
)
logger
.
info
(
f
"PTX saved to
{
os
.
path
.
abspath
(
path
)
}
"
)
except
Exception
as
e
:
logger
.
error
(
f
"Failed to export PTX:
{
e
}
"
)
def
_get_sass
(
self
,
verbose
:
bool
|
None
=
None
)
->
str
:
"""
Compile and return SASS for the current kernel (CUDA only).
Parameters
----------
verbose : Optional[bool]
Whether to enable verbose tool logs. Defaults to self.verbose.
Returns
-------
str
The disassembled SASS text.
"""
if
not
is_cuda_target
(
self
.
target
):
raise
ValueError
(
"SASS is only available for CUDA targets."
)
code
=
self
.
get_kernel_source
()
if
verbose
is
None
:
verbose
=
self
.
verbose
with
self
.
target
:
return
tl_nvcc
.
get_sass_from_source
(
code
,
compile_flags
=
self
.
compile_flags
,
verbose
=
verbose
)
def
show_sass
(
self
)
->
None
:
"""
Print disassembled SASS for the kernel (CUDA only).
Examples
--------
>>> jit_kernel.show_sass()
"""
try
:
sass
=
self
.
_get_sass
()
print
(
sass
)
except
Exception
as
e
:
logger
.
error
(
f
"Failed to show SASS:
{
e
}
"
)
def
export_sass
(
self
,
path
:
str
)
->
None
:
"""
Export disassembled SASS to a file (CUDA only).
Parameters
----------
path : str
Destination file path to write SASS.
Examples
--------
>>> jit_kernel.export_sass("/tmp/kernel.sass")
"""
if
not
path
:
raise
ValueError
(
"path must be provided to export SASS"
)
try
:
sass
=
self
.
_get_sass
()
dir_path
=
os
.
path
.
dirname
(
path
)
if
dir_path
:
os
.
makedirs
(
dir_path
,
exist_ok
=
True
)
with
open
(
path
,
"w"
)
as
f
:
f
.
write
(
sass
)
logger
.
info
(
f
"SASS saved to
{
os
.
path
.
abspath
(
path
)
}
"
)
except
Exception
as
e
:
logger
.
error
(
f
"Failed to export SASS:
{
e
}
"
)
tilelang/language/__init__.py
View file @
bbbf4207
...
@@ -8,9 +8,9 @@ from __future__ import annotations
...
@@ -8,9 +8,9 @@ from __future__ import annotations
# upstream tir script is fully compatible
# upstream tir script is fully compatible
from
tvm.script.parser.tir
import
*
from
tvm.script.parser.tir
import
*
from
.
import
overrides
as
_overrides
# noqa: F401
from
.
import
overrides
as
_overrides
# noqa: F401
from
.tir
import
(
prim_func
,
# noqa: F401
# from .tir import
prim_func,
macro,
# noqa: F401
)
from
.v2
import
*
# noqa: F401
from
.tir.ir
import
*
# noqa: F401
from
.tir.ir
import
*
# noqa: F401
from
tilelang.layout
import
Layout
,
Fragment
# noqa: F401
from
tilelang.layout
import
Layout
,
Fragment
# noqa: F401
from
.proxy
import
(
from
.proxy
import
(
...
@@ -23,9 +23,7 @@ from .proxy import (
...
@@ -23,9 +23,7 @@ from .proxy import (
SharedBuffer
,
# noqa: F401
SharedBuffer
,
# noqa: F401
LocalBuffer
,
# noqa: F401
LocalBuffer
,
# noqa: F401
)
)
from
.parallel
import
Parallel
# noqa: F401
from
.loop
import
serial
,
Parallel
,
Persistent
,
Pipelined
# noqa: F401
from
.pipeline
import
Pipelined
# noqa: F401
from
.persistent
import
Persistent
# noqa: F401
from
.frame
import
has_let_value
,
get_let_value
# noqa: F401
from
.frame
import
has_let_value
,
get_let_value
# noqa: F401
from
.math_intrinsics
import
*
# noqa: F401
from
.math_intrinsics
import
*
# noqa: F401
from
.kernel
import
(
from
.kernel
import
(
...
@@ -46,9 +44,12 @@ from .allocate import (
...
@@ -46,9 +44,12 @@ from .allocate import (
alloc_tmem
,
# noqa: F401
alloc_tmem
,
# noqa: F401
alloc_reducer
,
# noqa: F401
alloc_reducer
,
# noqa: F401
alloc_descriptor
,
# noqa: F401
alloc_descriptor
,
# noqa: F401
alloc_wgmma_desc
,
# noqa: F401
alloc_tcgen05_smem_desc
,
# noqa: F401
alloc_tcgen05_instr_desc
,
# noqa: F401
)
)
from
.copy
import
copy
,
c2d_im2col
# noqa: F401
from
.copy
import
copy
,
c2d_im2col
# noqa: F401
from
.gemm
import
GemmWarpPolicy
,
gemm
,
gemm_v2
# noqa: F401
from
.gemm
import
GemmWarpPolicy
,
gemm
,
gemm_v1
,
gemm_v2
# noqa: F401
from
.experimental.gemm_sp
import
gemm_sp
# noqa: F401
from
.experimental.gemm_sp
import
gemm_sp
# noqa: F401
from
.fill
import
fill
,
clear
# noqa: F401
from
.fill
import
fill
,
clear
# noqa: F401
from
.reduce
import
(
from
.reduce
import
(
...
...
tilelang/language/allocate.py
View file @
bbbf4207
...
@@ -15,10 +15,13 @@ with the appropriate memory scope.
...
@@ -15,10 +15,13 @@ with the appropriate memory scope.
"""
"""
from
__future__
import
annotations
from
__future__
import
annotations
from
typing
import
overload
,
Literal
from
tilelang
import
tvm
as
tvm
from
tilelang
import
tvm
as
tvm
from
tvm.script
import
tir
as
T
from
tvm.script
import
tir
as
T
from
tvm.tir
import
PrimExpr
from
tvm.tir
import
PrimExpr
from
tvm.script.parser.tir
import
block_attr
from
tvm.script.parser.tir
import
block_attr
from
tvm.tir.buffer
import
Buffer
from
tvm.tir.expr
import
FloatImm
,
IntImm
def
alloc_shared
(
shape
,
dtype
,
scope
=
"shared.dyn"
):
def
alloc_shared
(
shape
,
dtype
,
scope
=
"shared.dyn"
):
...
@@ -67,6 +70,19 @@ def alloc_fragment(shape, dtype, scope="local.fragment"):
...
@@ -67,6 +70,19 @@ def alloc_fragment(shape, dtype, scope="local.fragment"):
return
T
.
alloc_buffer
(
shape
,
dtype
,
scope
=
scope
)
return
T
.
alloc_buffer
(
shape
,
dtype
,
scope
=
scope
)
@
overload
def
alloc_var
(
dtype
:
str
,
init
:
PrimExpr
|
int
|
float
,
scope
:
str
=
'local.var'
)
->
Buffer
:
...
@
overload
def
alloc_var
(
dtype
:
str
,
scope
:
str
=
'local.var'
,
*
,
init
:
PrimExpr
|
int
|
float
|
None
=
None
)
->
Buffer
:
...
def
alloc_var
(
dtype
,
*
args
,
scope
=
"local.var"
,
init
:
PrimExpr
|
None
=
None
):
def
alloc_var
(
dtype
,
*
args
,
scope
=
"local.var"
,
init
:
PrimExpr
|
None
=
None
):
"""Allocate a single-element variable buffer.
"""Allocate a single-element variable buffer.
...
@@ -82,7 +98,12 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
...
@@ -82,7 +98,12 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
init (PrimExpr, optional): The optional initializer value. When provided,
init (PrimExpr, optional): The optional initializer value. When provided,
the generated code will initialize the variable with this value instead
the generated code will initialize the variable with this value instead
of defaulting to zero.
of defaulting to zero.
Examples:
a = T.alloc_var('int32', 1) # var with init 1
a = T.alloc_var('int32', 'local.var') # var with local.var scope
a = T.alloc_var('int32', 1, 'local.var') # var with init 1 and local.var scope
a = T.alloc_var('int32', 'local.var', init=1) # var with init 1 and local.var scope
a = T.alloc_var('int32', init=1) # var with init 1 and local.var scope
Returns:
Returns:
T.Buffer: A TVM buffer object allocated as a single-element variable
T.Buffer: A TVM buffer object allocated as a single-element variable
"""
"""
...
@@ -113,7 +134,10 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
...
@@ -113,7 +134,10 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
buffer
=
T
.
alloc_buffer
([
1
],
dtype
,
scope
=
parsed_scope
)
buffer
=
T
.
alloc_buffer
([
1
],
dtype
,
scope
=
parsed_scope
)
if
parsed_init
is
not
None
:
if
parsed_init
is
not
None
:
block_attr
({
"tl.local_var_init"
:
{
buffer
.
data
:
parsed_init
}})
if
isinstance
(
parsed_init
,
(
int
,
float
,
IntImm
,
FloatImm
)):
block_attr
({
"tl.local_var_init"
:
{
buffer
.
data
:
parsed_init
}})
else
:
T
.
buffer_store
(
buffer
,
parsed_init
,
0
)
return
buffer
return
buffer
...
@@ -194,10 +218,40 @@ def alloc_reducer(shape, dtype, op="sum", replication=None):
...
@@ -194,10 +218,40 @@ def alloc_reducer(shape, dtype, op="sum", replication=None):
return
reducer
return
reducer
def
alloc_descriptor
(
dtype
=
"uint64"
,
scope
=
"local.descriptor"
):
DescKind
=
Literal
[
"wgmma"
,
"tcgen05_smem"
,
"tcgen05_instr"
]
"""Allocate a descriptor buffer for wgmma and utcmma.
def
alloc_descriptor
(
kind
:
DescKind
=
"wgmma"
,
dtype
:
str
=
"uint64"
,
):
"""Allocate a descriptor buffer for WGMMA and TCGEN5.MMA.
Args:
kind: The descriptor kind, one of "wgmma", "tcgen05" ("utcmma" as alias).
Returns:
Returns:
T.Buffer: A TVM buffer object allocated as a descriptor
T.Buffer: A TVM buffer object allocated as a descriptor
"""
"""
scope
=
"local.descriptor."
+
kind
# Buffer naming via `name` is not supported by this TVM builder signature;
# keep parameter for forward-compat, but do not pass it.
return
T
.
alloc_buffer
([
1
],
dtype
,
scope
=
scope
)
return
T
.
alloc_buffer
([
1
],
dtype
,
scope
=
scope
)
def
alloc_wgmma_desc
(
dtype
:
str
=
"uint64"
):
return
alloc_descriptor
(
"wgmma"
,
dtype
=
dtype
)
def
alloc_tcgen05_smem_desc
(
dtype
:
str
=
"uint64"
):
return
alloc_descriptor
(
"tcgen05_smem"
,
dtype
=
dtype
)
def
alloc_tcgen05_instruction_desc
(
dtype
:
str
=
"uint32"
):
return
alloc_descriptor
(
"tcgen05_instr"
,
dtype
=
dtype
)
# Alias: short name consistent with imports
def
alloc_tcgen05_instr_desc
(
dtype
:
str
=
"uint32"
):
return
alloc_tcgen05_instruction_desc
(
dtype
)
tilelang/language/ast/ir.py
View file @
bbbf4207
...
@@ -1894,6 +1894,8 @@ ptx_mma = _dtype_forward(_tir_op.ptx_mma)
...
@@ -1894,6 +1894,8 @@ ptx_mma = _dtype_forward(_tir_op.ptx_mma)
ptx_mma_sp
=
_dtype_forward
(
_tir_op
.
ptx_mma_sp
)
ptx_mma_sp
=
_dtype_forward
(
_tir_op
.
ptx_mma_sp
)
ptx_wgmma_ss
=
_dtype_forward
(
_tir_op
.
ptx_wgmma_ss
)
ptx_wgmma_ss
=
_dtype_forward
(
_tir_op
.
ptx_wgmma_ss
)
ptx_wgmma_rs
=
_dtype_forward
(
_tir_op
.
ptx_wgmma_rs
)
ptx_wgmma_rs
=
_dtype_forward
(
_tir_op
.
ptx_wgmma_rs
)
ptx_tcgen05_mma_ss
=
_dtype_forward
(
_tir_op
.
ptx_tcgen05_mma_ss
)
ptx_tcgen05_mma_ts
=
_dtype_forward
(
_tir_op
.
ptx_tcgen05_mma_ts
)
ptx_ldmatrix
=
_dtype_forward
(
_tir_op
.
ptx_ldmatrix
)
ptx_ldmatrix
=
_dtype_forward
(
_tir_op
.
ptx_ldmatrix
)
ptx_cp_async
=
_dtype_forward
(
_tir_op
.
ptx_cp_async
)
ptx_cp_async
=
_dtype_forward
(
_tir_op
.
ptx_cp_async
)
ptx_cp_async_bulk
=
_dtype_forward
(
_tir_op
.
ptx_cp_async_bulk
)
ptx_cp_async_bulk
=
_dtype_forward
(
_tir_op
.
ptx_cp_async_bulk
)
...
@@ -2145,6 +2147,7 @@ __all__ = [
...
@@ -2145,6 +2147,7 @@ __all__ = [
"ptx_mma_sp"
,
"ptx_mma_sp"
,
"ptx_wgmma_ss"
,
"ptx_wgmma_ss"
,
"ptx_wgmma_rs"
,
"ptx_wgmma_rs"
,
"ptx_tcgen05_mma_ss"
,
"ptx_ldmatrix"
,
"ptx_ldmatrix"
,
"ptx_cp_async"
,
"ptx_cp_async"
,
"ptx_cp_async_bulk"
,
"ptx_cp_async_bulk"
,
...
...
tilelang/language/atomic.py
View file @
bbbf4207
...
@@ -6,8 +6,8 @@ from __future__ import annotations
...
@@ -6,8 +6,8 @@ from __future__ import annotations
import
tilelang.language
as
T
import
tilelang.language
as
T
from
tvm
import
ir
,
tir
from
tvm
import
ir
,
tir
from
tvm.tir
import
PrimExpr
,
Buffer
,
BufferRegion
,
Var
,
op
from
tvm.tir
import
PrimExpr
,
Buffer
,
BufferRegion
,
Var
,
op
from
tilelang.language.utils
import
buffer_to_tile_region
,
buffer_region_to_tile_region
,
buffer_load_to_tile_region
from
tilelang.language.utils
import
buffer_region_to_tile_region
,
buffer_load_to_tile_region
from
tilelang.utils.language
import
get_buffer_region_from_load
from
tilelang.utils.language
import
get_buffer_region_from_load
,
legalize_pairwise_extents
_MEMORY_ORDER_ID_MAP
=
{
_MEMORY_ORDER_ID_MAP
=
{
"relaxed"
:
0
,
"relaxed"
:
0
,
...
@@ -201,13 +201,14 @@ def atomic_add(dst: Buffer,
...
@@ -201,13 +201,14 @@ def atomic_add(dst: Buffer,
assert
src_extent
or
dst_extent
,
"Can't deduce atomicadd extents from args"
assert
src_extent
or
dst_extent
,
"Can't deduce atomicadd extents from args"
src_extent
=
list
(
src_extent
)
if
src_extent
else
[
1
]
*
len
(
dst_extent
)
src_extent
=
list
(
src_extent
)
if
src_extent
else
[
1
]
*
len
(
dst_extent
)
dst_extent
=
list
(
dst_extent
)
if
dst_extent
else
[
1
]
*
len
(
src_extent
)
dst_extent
=
list
(
dst_extent
)
if
dst_extent
else
[
1
]
*
len
(
src_extent
)
extent
=
max
(
src_extent
,
dst_extent
)
src_extent
,
dst_extent
=
legalize_pairwise_extents
(
src_extent
,
dst_extent
)
def
_to_region
(
data
,
access_type
):
def
_to_region
(
data
,
access_type
,
extent
):
if
isinstance
(
data
,
tir
.
Var
)
and
T
.
has_let_value
(
data
):
if
isinstance
(
data
,
tir
.
Var
)
and
T
.
has_let_value
(
data
):
data
=
T
.
get_let_value
(
data
)
data
=
T
.
get_let_value
(
data
)
if
isinstance
(
data
,
tir
.
Buffer
):
if
isinstance
(
data
,
tir
.
Buffer
):
return
buffer_to_tile_region
(
data
,
access_type
)
zeros
=
[
tir
.
IntImm
(
"int32"
,
0
)
for
_
in
extent
]
return
buffer_load_to_tile_region
(
tir
.
BufferLoad
(
data
,
zeros
),
access_type
,
extent
)
elif
isinstance
(
data
,
tir
.
BufferRegion
):
elif
isinstance
(
data
,
tir
.
BufferRegion
):
return
buffer_region_to_tile_region
(
data
,
access_type
,
extent
)
return
buffer_region_to_tile_region
(
data
,
access_type
,
extent
)
elif
isinstance
(
data
,
tir
.
BufferLoad
):
elif
isinstance
(
data
,
tir
.
BufferLoad
):
...
@@ -218,8 +219,8 @@ def atomic_add(dst: Buffer,
...
@@ -218,8 +219,8 @@ def atomic_add(dst: Buffer,
else
:
else
:
return
buffer_load_to_tile_region
(
data
,
access_type
,
extent
)
return
buffer_load_to_tile_region
(
data
,
access_type
,
extent
)
value
=
_to_region
(
value
,
"r"
)
value
=
_to_region
(
value
,
"r"
,
src_extent
)
dst
=
_to_region
(
dst
,
"w"
)
dst
=
_to_region
(
dst
,
"w"
,
dst_extent
)
# Note: tile-region-based atomic operations don't support return_prev yet
# Note: tile-region-based atomic operations don't support return_prev yet
# This would need to be implemented in the tile runtime
# This would need to be implemented in the tile runtime
...
...
tilelang/language/builtin.py
View file @
bbbf4207
...
@@ -5,9 +5,10 @@ from tilelang import tvm as tvm
...
@@ -5,9 +5,10 @@ from tilelang import tvm as tvm
from
tilelang.language
import
ptx_arrive_barrier
,
evaluate
from
tilelang.language
import
ptx_arrive_barrier
,
evaluate
from
tilelang.language.kernel
import
get_thread_bindings
,
get_block_extents
from
tilelang.language.kernel
import
get_thread_bindings
,
get_block_extents
from
tilelang.utils.target
import
check_hip_availability
from
tilelang.utils.target
import
check_hip_availability
from
tvm
import
tir
from
tvm
import
DataType
,
tir
from
tvm.runtime
import
convert
from
typing
import
Any
from
typing
import
Any
from
tvm.tir
import
PrimExpr
,
Var
,
Call
,
Buffer
,
Buffer
Load
from
tvm.tir
import
PrimExpr
,
Var
,
Call
,
Buffer
Load
,
Buffer
Region
_IS_HIP_AVAILABLE
=
check_hip_availability
()
_IS_HIP_AVAILABLE
=
check_hip_availability
()
...
@@ -429,6 +430,168 @@ def shuffle_elect(thread_extent: int) -> PrimExpr:
...
@@ -429,6 +430,168 @@ def shuffle_elect(thread_extent: int) -> PrimExpr:
return
tir
.
call_intrin
(
"bool"
,
tir
.
op
.
Op
.
get
(
"tl.tl_shuffle_elect"
),
thread_extent
)
return
tir
.
call_intrin
(
"bool"
,
tir
.
op
.
Op
.
get
(
"tl.tl_shuffle_elect"
),
thread_extent
)
def
warpgroup_fence_operand
(
buffer_or_ptr
:
tir
.
Buffer
|
PrimExpr
,
offset
:
int
|
PrimExpr
=
0
,
num_regs
:
int
|
PrimExpr
|
None
=
None
,
dtype
:
str
|
None
=
None
):
"""Insert a warpgroup fence for the destination accumulator registers.
This prevents NVCC from sinking uses of accumulator fragments past the corresponding
WGMMA operations by issuing an empty inline assembly barrier on every register.
Args:
buffer_or_ptr: Buffer | BufferLoad | BufferRegion | PrimExpr
A buffer representing the accumulator fragment, a buffer load/region
that identifies a starting element within the fragment, or a pointer expression
(e.g., tvm_access_ptr/address_of/typed Var).
offset: int | PrimExpr
Element offset from the start of the accumulator fragment.
num_regs: int | PrimExpr | None
Number of 32-bit registers to fence. If None and a Buffer is provided, it will be
derived from the buffer shape and dtype.
dtype: str | None
Data type string of the accumulator elements. When passing a buffer or
buffer-derived expression, dtype is inferred. It is required only when
passing a raw pointer expression that cannot be inferred.
Returns:
tir.Call: A handle to the warpgroup fence operation.
"""
if
isinstance
(
buffer_or_ptr
,
BufferLoad
):
# Treat BufferLoad as a request to fence starting from the loaded element's address
buf
=
buffer_or_ptr
.
buffer
data_ptr
=
buf
.
data
inferred_dtype
=
buf
.
dtype
if
dtype
is
not
None
and
dtype
!=
inferred_dtype
:
raise
ValueError
(
f
"dtype mismatch: provided
{
dtype
}
, buffer uses
{
inferred_dtype
}
."
)
dtype
=
inferred_dtype
# Compute element offset from indices using strides if present, otherwise row-major
if
len
(
buf
.
strides
)
==
len
(
buf
.
shape
)
and
len
(
buf
.
strides
)
>
0
:
elem_off
=
0
for
idx
,
stride
in
zip
(
buffer_or_ptr
.
indices
,
buf
.
strides
):
elem_off
=
elem_off
+
idx
*
stride
else
:
elem_off
=
0
stride_acc
=
1
for
idx
,
dim
in
zip
(
reversed
(
buffer_or_ptr
.
indices
),
reversed
(
buf
.
shape
)):
elem_off
=
elem_off
+
idx
*
stride_acc
stride_acc
=
stride_acc
*
dim
# Combine with user-provided offset
offset
=
elem_off
+
convert
(
offset
)
if
num_regs
is
None
:
raise
ValueError
(
"num_regs must be provided when passing a BufferLoad."
)
return
evaluate
(
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.warpgroup_fence_operand"
),
dtype
,
data_ptr
,
convert
(
offset
),
convert
(
num_regs
),
))
if
isinstance
(
buffer_or_ptr
,
tir
.
Buffer
):
data_ptr
=
buffer_or_ptr
.
data
inferred_dtype
=
buffer_or_ptr
.
dtype
if
dtype
is
not
None
and
dtype
!=
inferred_dtype
:
raise
ValueError
(
f
"dtype mismatch: provided
{
dtype
}
, buffer uses
{
inferred_dtype
}
."
)
dtype
=
inferred_dtype
if
num_regs
is
None
:
total_elems
=
1
for
dim
in
buffer_or_ptr
.
shape
:
if
isinstance
(
dim
,
tir
.
IntImm
):
total_elems
*=
int
(
dim
)
else
:
raise
ValueError
(
"warpgroup_fence_operand requires num_regs when buffer shape is symbolic."
)
bits_per_elem
=
DataType
(
dtype
).
bits
num_regs
=
(
total_elems
*
bits_per_elem
+
31
)
//
32
elif
isinstance
(
buffer_or_ptr
,
BufferRegion
):
buf
=
buffer_or_ptr
.
buffer
data_ptr
=
buf
.
data
inferred_dtype
=
buf
.
dtype
if
dtype
is
not
None
and
dtype
!=
inferred_dtype
:
raise
ValueError
(
f
"dtype mismatch: provided
{
dtype
}
, buffer uses
{
inferred_dtype
}
."
)
dtype
=
inferred_dtype
# Compute element offset from region min using strides if present, otherwise row-major
if
len
(
buf
.
strides
)
==
len
(
buf
.
shape
)
and
len
(
buf
.
strides
)
>
0
:
elem_off
=
0
for
r
,
stride
in
zip
(
buffer_or_ptr
.
region
,
buf
.
strides
):
elem_off
=
elem_off
+
r
.
min
*
stride
else
:
elem_off
=
0
stride_acc
=
1
for
r
,
dim
in
zip
(
reversed
(
buffer_or_ptr
.
region
),
reversed
(
buf
.
shape
)):
elem_off
=
elem_off
+
r
.
min
*
stride_acc
stride_acc
=
stride_acc
*
dim
# Combine with user-provided offset
offset
=
elem_off
+
convert
(
offset
)
# Try derive num_regs from region extents if fully static; otherwise require user input
if
num_regs
is
None
:
total_elems
=
1
static
=
True
for
r
in
buffer_or_ptr
.
region
:
if
isinstance
(
r
.
extent
,
tir
.
IntImm
):
total_elems
*=
int
(
r
.
extent
)
else
:
static
=
False
break
if
static
:
bits_per_elem
=
DataType
(
dtype
).
bits
num_regs
=
(
total_elems
*
bits_per_elem
+
31
)
//
32
else
:
raise
ValueError
(
"warpgroup_fence_operand requires num_regs when BufferRegion extent is symbolic."
)
return
evaluate
(
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.warpgroup_fence_operand"
),
dtype
,
data_ptr
,
convert
(
offset
),
convert
(
num_regs
),
))
else
:
data_ptr
=
buffer_or_ptr
# Try to infer dtype from common pointer expressions when not provided
if
dtype
is
None
:
inferred
=
None
# Case 1: Pointer from Buffer.access_ptr -> tir.builtin.tvm_access_ptr
if
isinstance
(
data_ptr
,
Call
)
and
data_ptr
.
op
.
same_as
(
tir
.
builtin
.
tvm_access_ptr
()):
# args[0] is a type annotation call; its dtype carries the element dtype
inferred
=
str
(
data_ptr
.
args
[
0
].
dtype
)
# Case 2: Pointer from tir.address_of(BufferLoad(...))
elif
isinstance
(
data_ptr
,
Call
)
and
data_ptr
.
op
.
same_as
(
tir
.
builtin
.
address_of
()):
# args[0] should be a BufferLoad; its dtype is the element dtype
inferred
=
str
(
data_ptr
.
args
[
0
].
dtype
)
# Case 3: Typed pointer Var with PrimType element (typed TIR)
elif
hasattr
(
data_ptr
,
"type_annotation"
)
and
data_ptr
.
type_annotation
is
not
None
:
try
:
elem_ty
=
getattr
(
data_ptr
.
type_annotation
,
"element_type"
,
None
)
if
elem_ty
is
not
None
and
hasattr
(
elem_ty
,
"dtype"
):
inferred
=
str
(
elem_ty
.
dtype
)
except
Exception
:
inferred
=
None
if
inferred
is
None
:
raise
ValueError
(
"dtype must be provided when passing a pointer expression and cannot be inferred."
)
dtype
=
inferred
if
num_regs
is
None
:
raise
ValueError
(
"num_regs must be provided when passing a pointer expression."
)
return
evaluate
(
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.warpgroup_fence_operand"
),
dtype
,
data_ptr
,
convert
(
offset
),
convert
(
num_regs
),
))
def
wait_wgmma
(
id
:
int
):
def
wait_wgmma
(
id
:
int
):
"""Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.
"""Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.
...
@@ -537,38 +700,70 @@ def sync_grid():
...
@@ -537,38 +700,70 @@ def sync_grid():
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.sync_grid"
))
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.sync_grid"
))
def
initialize_descriptor
(
descriptor
:
Buffer
,
def
initialize_wgmma_descriptor
(
start_address
:
PrimExpr
,
descriptor
:
tir
.
Buffer
,
layout_type_
:
int
=
0
,
start_address
:
PrimExpr
,
leading_byte_offset
:
int
=
0
,
layout_type_
:
int
=
0
,
stride_byte_offset
:
int
=
0
)
->
PrimExpr
:
leading_byte_offset
:
int
=
0
,
"""
stride_byte_offset
:
int
=
0
,
Initialize a memory descriptor with the given parameters.
)
->
PrimExpr
:
"""Initialize a WGMMA/UTCMMA shared-memory descriptor."""
Parameters:
if
not
isinstance
(
descriptor
,
(
BufferLoad
,
tir
.
Buffer
)):
descriptor (Buffer): The memory descriptor to initialize.
raise
TypeError
(
"Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad."
)
start_address (PrimExpr): The starting address of the memory region.
layout_type_ (int, optional): Layout type identifier. Defaults to 0.
leading_byte_offset (int, optional): Leading byte offset. Defaults to 0.
stride_byte_offset (int, optional): Stride byte offset. Defaults to 0.
Returns:
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
(
len
(
descriptor
.
shape
)
!=
1
or
PrimExpr: A handle representing the initialized descriptor.
descriptor
.
shape
[
0
]
!=
1
):
"""
raise
ValueError
(
"Descriptor must be a 1D buffer of size 1."
)
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
,
[
0
])
return
evaluate
(
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.initialize_wgmma_descriptor"
),
descriptor
,
start_address
,
layout_type_
,
int
(
leading_byte_offset
),
int
(
stride_byte_offset
),
))
def
initialize_tcgen05_descriptor
(
descriptor
:
tir
.
Buffer
,
start_address
:
PrimExpr
,
leading_byte_offset
:
int
,
stride_byte_offset
:
int
,
base_offset
:
int
=
0
,
leading_is_absolute
:
bool
=
False
,
swizzle_mode
:
int
=
0
,
)
->
PrimExpr
:
"""Initialize a TCGEN05 shared-memory descriptor."""
if
not
isinstance
(
descriptor
,
(
BufferLoad
,
Buffer
)):
if
not
isinstance
(
descriptor
,
(
BufferLoad
,
tir
.
Buffer
)):
raise
TypeError
(
"Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad."
)
raise
TypeError
(
"Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad."
)
if
isinstance
(
descriptor
,
Buffer
)
and
len
(
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
:
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
(
len
(
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
):
raise
ValueError
(
"Descriptor must be a 1D buffer of size 1."
)
raise
ValueError
(
"Descriptor must be a 1D buffer of size 1."
)
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
,
[
0
])
descriptor
,
[
0
])
return
evaluate
(
return
evaluate
(
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.initialize_descriptor"
),
descriptor
,
tir
.
call_intrin
(
start_address
,
layout_type_
,
int
(
leading_byte_offset
),
"handle"
,
int
(
stride_byte_offset
)))
tir
.
op
.
Op
.
get
(
"tl.initialize_tcgen05_descriptor"
),
descriptor
,
start_address
,
int
(
leading_byte_offset
),
int
(
stride_byte_offset
),
int
(
base_offset
),
tir
.
IntImm
(
"int32"
,
1
if
leading_is_absolute
else
0
),
int
(
swizzle_mode
),
))
def
increase_descriptor_offset
(
descriptor
:
PrimExpr
,
offset
:
PrimExpr
)
->
PrimExpr
:
def
increase_descriptor_offset
(
descriptor
:
PrimExpr
,
offset
:
PrimExpr
)
->
PrimExpr
:
...
@@ -582,10 +777,11 @@ def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimEx
...
@@ -582,10 +777,11 @@ def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimEx
Returns:
Returns:
PrimExpr: A handle representing the modified descriptor.
PrimExpr: A handle representing the modified descriptor.
"""
"""
if
not
isinstance
(
descriptor
,
(
BufferLoad
,
Buffer
)):
if
not
isinstance
(
descriptor
,
(
BufferLoad
,
tir
.
Buffer
)):
raise
TypeError
(
"Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad."
)
raise
TypeError
(
"Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad."
)
if
isinstance
(
descriptor
,
Buffer
)
and
len
(
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
:
if
isinstance
(
descriptor
,
tir
.
Buffer
)
and
len
(
descriptor
.
shape
)
!=
1
or
descriptor
.
shape
[
0
]
!=
1
:
raise
ValueError
(
"Descriptor must be a 1D buffer of size 1."
)
raise
ValueError
(
"Descriptor must be a 1D buffer of size 1."
)
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
descriptor
=
descriptor
if
isinstance
(
descriptor
,
BufferLoad
)
else
tir
.
BufferLoad
(
...
@@ -606,3 +802,113 @@ def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call):
...
@@ -606,3 +802,113 @@ def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call):
"""Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
"""Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
"""
"""
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.ptx_cp_async_barrier_noinc"
),
barrier_id
)
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.ptx_cp_async_barrier_noinc"
),
barrier_id
)
def tcgen05_mma_arrive(mbar_ptr):
    """Signal UMMA (TCGEN05) barrier arrival for a shared-memory mbarrier pointer.

    Parameters
    ----------
    mbar_ptr : PrimExpr
        Pointer to the mbarrier object in shared memory (e.g., Barrier*).

    Returns
    -------
    call : PrimExpr
        The ``void``-typed call to the ``tl.tcgen05_mma_arrive`` intrinsic.
    """
    # Resolve the registered op first so the call site reads as (dtype, op, args).
    arrive_op = tir.op.Op.get("tl.tcgen05_mma_arrive")
    return tir.call_intrin("void", arrive_op, mbar_ptr)
def ptx_mma_sm70(
    shape,
    A_layout,
    B_layout,
    A_dtype,
    B_dtype,
    C_dtype,
    multiplicand_a,
    a_index,
    multiplicand_b,
    b_index,
    accumulator,
    c_index,
):
    """TVM intrinsic for ptx tensor core mma instructions on SM70 (Volta).

    This intrinsic provides SM70-specific MMA operations that support m16n16k4 shape
    with FP16 inputs and FP16/FP32 accumulation.

    All arguments are forwarded positionally, in declaration order, to the
    ``tl.ptx_mma_sm70`` op. Unlike the example previously shown here, there is
    no leading result-dtype argument: ``shape`` is the first parameter.

    Parameters
    ----------
    shape : str
        The shape of mma fragment (e.g., "m16n16k4").

    A_layout : str
        The layout of multiplicand fragment A ("row" or "col").

    B_layout : str
        The layout of multiplicand fragment B ("row" or "col").

    A_dtype : str
        The data type of multiplicand fragment A (typically "fp16").

    B_dtype : str
        The data type of multiplicand fragment B (typically "fp16").

    C_dtype : str
        The data type of accumulator fragment C ("fp16" or "fp32").

    multiplicand_a : Var
        The multiplicand fragment A variable.

    a_index : Expr
        The index of multiplicand fragment A.

    multiplicand_b : Var
        The multiplicand fragment B variable.

    b_index : Expr
        The index of multiplicand fragment B.

    accumulator : Var
        The accumulator fragment C variable.

    c_index : Expr
        The index of accumulator fragment C.

    Returns
    -------
    call : PrimExpr
        The call expression.

    Examples
    --------
    >>> T.ptx_mma_sm70(
    ...     "m16n16k4",
    ...     "row",
    ...     "col",
    ...     "fp16",
    ...     "fp16",
    ...     "fp16",
    ...     A_local.data,
    ...     0,
    ...     B_local.data,
    ...     0,
    ...     C_local.data,
    ...     0,
    ... )
    """
    return tir.call_intrin(
        "handle",
        tir.op.Op.get("tl.ptx_mma_sm70"),
        shape,
        A_layout,
        B_layout,
        A_dtype,
        B_dtype,
        C_dtype,
        multiplicand_a,
        a_index,
        multiplicand_b,
        b_index,
        accumulator,
        c_index,
    )
tilelang/language/copy.py
View file @
bbbf4207
...
@@ -3,9 +3,12 @@ from __future__ import annotations
...
@@ -3,9 +3,12 @@ from __future__ import annotations
from
typing
import
Literal
from
typing
import
Literal
from
tilelang
import
language
as
T
from
tilelang
import
language
as
T
from
tilelang.utils.language
import
get_buffer_region_from_load
from
tilelang.utils.language
import
(
get_buffer_region_from_load
,
legalize_pairwise_extents
,
)
from
tvm
import
ir
,
tir
from
tvm
import
ir
,
tir
from
tilelang.language.utils
import
buffer_to_tile_region
,
buffer_region_to_tile_region
,
buffer_load_to_tile_region
from
tilelang.language.utils
import
buffer_region_to_tile_region
,
buffer_load_to_tile_region
def
copy
(
src
:
tir
.
Buffer
|
tir
.
BufferLoad
|
tir
.
BufferRegion
,
def
copy
(
src
:
tir
.
Buffer
|
tir
.
BufferLoad
|
tir
.
BufferRegion
,
...
@@ -55,15 +58,26 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
...
@@ -55,15 +58,26 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
return
tir
.
BufferStore
(
dst
.
buffer
,
src
,
dst
.
indices
)
return
tir
.
BufferStore
(
dst
.
buffer
,
src
,
dst
.
indices
)
assert
src_extent
or
dst_extent
,
"Can't deduce copy extents from args"
assert
src_extent
or
dst_extent
,
"Can't deduce copy extents from args"
# Treat missing extent as length-matched ones to enable broadcasting logic.
src_extent
=
list
(
src_extent
)
if
src_extent
else
[
1
]
*
len
(
dst_extent
)
src_extent
=
list
(
src_extent
)
if
src_extent
else
[
1
]
*
len
(
dst_extent
)
dst_extent
=
list
(
dst_extent
)
if
dst_extent
else
[
1
]
*
len
(
src_extent
)
dst_extent
=
list
(
dst_extent
)
if
dst_extent
else
[
1
]
*
len
(
src_extent
)
extent
=
max
(
src_extent
,
dst_extent
)
def
_to_region
(
data
,
access_type
):
# Align and broadcast extents from the right (tail) side independently
# for src and dst, so we can pass them unchanged into _to_region.
# Rules per-dim from the right:
# - equal -> keep both
# - one is 1 -> set that side to the other side's dim
# - otherwise -> error
src_extent
,
dst_extent
=
legalize_pairwise_extents
(
src_extent
,
dst_extent
)
def
_to_region
(
data
,
access_type
,
extent
):
if
isinstance
(
data
,
tir
.
Var
)
and
T
.
has_let_value
(
data
):
if
isinstance
(
data
,
tir
.
Var
)
and
T
.
has_let_value
(
data
):
data
=
T
.
get_let_value
(
data
)
data
=
T
.
get_let_value
(
data
)
if
isinstance
(
data
,
tir
.
Buffer
):
if
isinstance
(
data
,
tir
.
Buffer
):
return
buffer_to_tile_region
(
data
,
access_type
)
# Restrict a raw buffer to the computed copy extent by creating
# a BufferLoad at origin and passing the extents explicitly.
zeros
=
[
tir
.
IntImm
(
"int32"
,
0
)
for
_
in
extent
]
return
buffer_load_to_tile_region
(
tir
.
BufferLoad
(
data
,
zeros
),
access_type
,
extent
)
elif
isinstance
(
data
,
tir
.
BufferRegion
):
elif
isinstance
(
data
,
tir
.
BufferRegion
):
return
buffer_region_to_tile_region
(
data
,
access_type
,
extent
)
return
buffer_region_to_tile_region
(
data
,
access_type
,
extent
)
elif
isinstance
(
data
,
tir
.
BufferLoad
):
elif
isinstance
(
data
,
tir
.
BufferLoad
):
...
@@ -74,8 +88,9 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
...
@@ -74,8 +88,9 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
else
:
else
:
return
buffer_load_to_tile_region
(
data
,
access_type
,
extent
)
return
buffer_load_to_tile_region
(
data
,
access_type
,
extent
)
src
=
_to_region
(
src
,
"r"
)
# Use legalized extents for src and dst respectively.
dst
=
_to_region
(
dst
,
"w"
)
src
=
_to_region
(
src
,
"r"
,
src_extent
)
dst
=
_to_region
(
dst
,
"w"
,
dst_extent
)
if
coalesced_width
is
None
:
if
coalesced_width
is
None
:
coalesced_width
=
-
1
# PrimExpr can not be None
coalesced_width
=
-
1
# PrimExpr can not be None
...
...
tilelang/language/fill.py
View file @
bbbf4207
...
@@ -4,9 +4,14 @@ from __future__ import annotations
...
@@ -4,9 +4,14 @@ from __future__ import annotations
from
tvm
import
tir
from
tvm
import
tir
from
tilelang.language
import
has_let_value
,
get_let_value
from
tilelang.language
import
has_let_value
,
get_let_value
from
tilelang.utils.language
import
get_buffer_region_from_load
from
tilelang.utils.language
import
get_buffer_region_from_load
from
tilelang.language.utils
import
(
buffer_to_tile_region
,
buffer_region_to_tile_region
,
buffer_load_to_tile_region
,
)
def
fill
(
buffer
:
tir
.
Buffer
|
tir
.
BufferRegion
,
value
:
tir
.
PrimExpr
):
def
fill
(
buffer
:
tir
.
Buffer
|
tir
.
BufferRegion
|
tir
.
BufferLoad
,
value
:
tir
.
PrimExpr
):
"""Fill a buffer or buffer region with a specified value.
"""Fill a buffer or buffer region with a specified value.
Args:
Args:
...
@@ -16,9 +21,30 @@ def fill(buffer: tir.Buffer | tir.BufferRegion, value: tir.PrimExpr):
...
@@ -16,9 +21,30 @@ def fill(buffer: tir.Buffer | tir.BufferRegion, value: tir.PrimExpr):
Returns:
Returns:
A TVM intrinsic call that performs the fill operation
A TVM intrinsic call that performs the fill operation
"""
"""
# Normalize Var with let value to its underlying object
if
isinstance
(
buffer
,
tir
.
Var
)
and
has_let_value
(
buffer
):
buffer
=
get_let_value
(
buffer
)
# Convert to a tl.region descriptor (PrimExpr) with write access
region_call
=
None
if
isinstance
(
buffer
,
tir
.
Buffer
):
if
isinstance
(
buffer
,
tir
.
Buffer
):
buffer
=
buffer
.
access_ptr
(
"w"
)
# Get write pointer if input is a Buffer
region_call
=
buffer_to_tile_region
(
buffer
,
"w"
)
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.fill"
),
buffer
,
value
)
elif
isinstance
(
buffer
,
tir
.
BufferRegion
):
extents
=
[
r
.
extent
for
r
in
buffer
.
region
]
region_call
=
buffer_region_to_tile_region
(
buffer
,
"w"
,
extents
)
elif
isinstance
(
buffer
,
tir
.
BufferLoad
):
region
=
get_buffer_region_from_load
(
buffer
)
if
region
is
not
None
:
extents
=
[
r
.
extent
for
r
in
region
.
region
]
region_call
=
buffer_region_to_tile_region
(
region
,
"w"
,
extents
)
else
:
# Fallback: treat element access as 1-extent per dim
region_call
=
buffer_load_to_tile_region
(
buffer
,
"w"
,
[
1
]
*
len
(
buffer
.
indices
))
else
:
# As-is fallback (rare): pass through for downstream handling
region_call
=
buffer
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.fill"
),
region_call
,
value
)
def
clear
(
buffer
:
tir
.
Buffer
|
tir
.
Var
):
def
clear
(
buffer
:
tir
.
Buffer
|
tir
.
Var
):
...
...
tilelang/language/gemm.py
View file @
bbbf4207
...
@@ -4,10 +4,19 @@ from __future__ import annotations
...
@@ -4,10 +4,19 @@ from __future__ import annotations
from
tilelang.primitives.gemm.base
import
GemmWarpPolicy
from
tilelang.primitives.gemm.base
import
GemmWarpPolicy
import
tilelang.language
as
T
import
tilelang.language
as
T
from
tvm
import
tir
from
tvm
import
tir
from
tilelang.utils.language
import
get_buffer_region_from_load
from
tilelang.utils.language
import
(
to_buffer_region
,
retrieve_shape
,
def
gemm
(
retrieve_stride
,
retrieve_ptr
,
retrieve_offset
,
prim_expr_equal
,
)
from
tilelang.env
import
env
as
_env
def
_gemm_impl
(
op_key
:
str
,
A
:
tir
.
Buffer
|
tir
.
Var
,
A
:
tir
.
Buffer
|
tir
.
Var
,
B
:
tir
.
Buffer
|
tir
.
Var
,
B
:
tir
.
Buffer
|
tir
.
Var
,
C
:
tir
.
Buffer
|
tir
.
Var
,
C
:
tir
.
Buffer
|
tir
.
Var
,
...
@@ -19,30 +28,9 @@ def gemm(
...
@@ -19,30 +28,9 @@ def gemm(
wg_wait
:
int
=
0
,
wg_wait
:
int
=
0
,
mbar
:
tir
.
Buffer
|
None
=
None
,
mbar
:
tir
.
Buffer
|
None
=
None
,
):
):
"""Perform a General Matrix Multiplication (GEMM) operation.
"""Shared GEMM implementation.
This function computes C = A @ B where A and B can optionally be transposed.
The operation supports various warp policies and accumulation modes.
Args:
Returns a call_intrin handle for the given op key.
A (Union[tir.Buffer, tir.Var]): First input matrix
B (Union[tir.Buffer, tir.Var]): Second input matrix
C (Union[tir.Buffer, tir.Var]): Output matrix for results
transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.
wg_wait (int, optional): Warp group wait count. Defaults to 0.
On hopper it is equivalent to `wgmma.wait_group.sync.aligned <wg_wait>` if wg_wait is not -1
On sm100, `wg_wait` can only be 0 or -1. `mbarrier_wait(TCGEN5MMA barrier)` will be appended if wg_wait is 0.
mbar (tir.Buffer, optional): mbarrier for TCGEN5MMA synchronization
Returns:
tir.Call: A handle to the GEMM operation
Raises:
AssertionError: If the K dimensions of matrices A and B don't match
"""
"""
def
legalize_arguments
(
arg
:
tir
.
Buffer
|
tir
.
Var
):
def
legalize_arguments
(
arg
:
tir
.
Buffer
|
tir
.
Var
):
...
@@ -63,52 +51,10 @@ def gemm(
...
@@ -63,52 +51,10 @@ def gemm(
C
=
legalize_arguments
(
C
)
C
=
legalize_arguments
(
C
)
mbar
=
legalize_arguments
(
mbar
)
if
mbar
is
not
None
else
None
mbar
=
legalize_arguments
(
mbar
)
if
mbar
is
not
None
else
None
def
retrieve_shape
(
object
:
tir
.
Buffer
|
tir
.
BufferRegion
)
->
list
[
int
]:
# Normalize A/B/C to BufferRegion to pass into tl.gemm
if
isinstance
(
object
,
tir
.
Buffer
):
A
=
to_buffer_region
(
A
)
return
object
.
shape
B
=
to_buffer_region
(
B
)
elif
isinstance
(
object
,
tir
.
BufferRegion
):
C
=
to_buffer_region
(
C
)
region
=
object
.
region
shape
=
[]
for
r
in
region
:
shape
.
append
(
r
.
extent
)
return
shape
elif
isinstance
(
object
,
tir
.
BufferLoad
):
region
=
get_buffer_region_from_load
(
object
).
region
shape
=
[]
for
r
in
region
:
shape
.
append
(
r
.
extent
)
return
shape
else
:
raise
ValueError
(
f
"Unsupported retrieve_shape argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
def
retrieve_stride
(
object
:
tir
.
Buffer
|
tir
.
BufferRegion
)
->
list
[
int
]:
if
isinstance
(
object
,
tir
.
Buffer
):
strides
=
[]
stride
=
1
for
s
in
reversed
(
object
.
shape
):
strides
.
insert
(
0
,
stride
)
stride
*=
s
return
strides
elif
isinstance
(
object
,
tir
.
BufferRegion
):
buffer
,
_
=
object
.
buffer
,
object
.
region
strides
=
[]
stride
=
1
for
s
in
reversed
(
buffer
.
shape
):
strides
.
insert
(
0
,
stride
)
stride
*=
s
return
strides
elif
isinstance
(
object
,
tir
.
BufferLoad
):
buffer
=
object
.
buffer
strides
=
[]
stride
=
1
for
s
in
reversed
(
buffer
.
shape
):
strides
.
insert
(
0
,
stride
)
stride
*=
s
return
strides
else
:
raise
ValueError
(
f
"Unsupported retrieve_stride argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
A_shape
=
retrieve_shape
(
A
)
A_shape
=
retrieve_shape
(
A
)
B_shape
=
retrieve_shape
(
B
)
B_shape
=
retrieve_shape
(
B
)
...
@@ -132,68 +78,11 @@ def gemm(
...
@@ -132,68 +78,11 @@ def gemm(
M
,
N
=
C_shape
M
,
N
=
C_shape
K
=
A_shape
[
-
2
]
if
transpose_A
else
A_shape
[
-
1
]
K
=
A_shape
[
-
2
]
if
transpose_A
else
A_shape
[
-
1
]
K_B
=
B_shape
[
-
1
]
if
transpose_B
else
B_shape
[
-
2
]
K_B
=
B_shape
[
-
1
]
if
transpose_B
else
B_shape
[
-
2
]
assert
K
==
K_B
,
f
"T.gemm K shape check failed: K_A =
{
K
}
, K_B =
{
K_B
}
"
assert
prim_expr_equal
(
K
,
K_B
)
,
f
"T.gemm K shape check failed: K_A =
{
K
}
, K_B =
{
K_B
}
"
stride_a
=
A_stride
[
-
2
]
stride_a
=
A_stride
[
-
2
]
stride_b
=
B_stride
[
-
2
]
stride_b
=
B_stride
[
-
2
]
def
retrieve_ptr
(
object
:
tir
.
Buffer
|
tir
.
BufferRegion
,
access_type
:
str
=
"r"
)
->
tir
.
PrimExpr
:
if
isinstance
(
object
,
tir
.
Buffer
):
return
object
.
access_ptr
(
access_type
)
elif
isinstance
(
object
,
tir
.
BufferRegion
):
buffer
,
region
=
object
.
buffer
,
object
.
region
indices
=
[]
for
r
in
region
:
indices
.
append
(
r
.
min
)
strides
=
[]
stride
=
1
for
s
in
reversed
(
buffer
.
shape
):
strides
.
insert
(
0
,
stride
)
stride
*=
s
offset
=
0
# not offset the last two dimension
for
i
in
range
(
len
(
indices
)
-
2
):
offset
+=
indices
[
i
]
*
strides
[
i
]
return
buffer
.
access_ptr
(
access_mask
=
access_type
,
offset
=
offset
)
elif
isinstance
(
object
,
tir
.
BufferLoad
):
buffer
=
object
.
buffer
region
=
get_buffer_region_from_load
(
object
).
region
indices
=
[]
for
r
in
region
:
indices
.
append
(
r
.
min
)
strides
=
[]
stride
=
1
for
s
in
reversed
(
buffer
.
shape
):
strides
.
insert
(
0
,
stride
)
stride
*=
s
offset
=
0
for
i
in
range
(
len
(
indices
)
-
2
):
offset
+=
indices
[
i
]
*
strides
[
i
]
return
buffer
.
access_ptr
(
access_mask
=
access_type
,
offset
=
offset
)
else
:
raise
ValueError
(
f
"Unsupported retrieve_ptr argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
def
retrieve_offset
(
object
:
tir
.
Buffer
|
tir
.
BufferRegion
)
->
tir
.
PrimExpr
:
"""Retrieve the offset of the buffer or buffer region."""
if
isinstance
(
object
,
tir
.
Buffer
):
return
[
0
]
*
len
(
object
.
shape
)
elif
isinstance
(
object
,
tir
.
BufferRegion
):
_
,
region
=
object
.
buffer
,
object
.
region
indices
=
[]
for
r
in
region
:
indices
.
append
(
r
.
min
)
return
indices
elif
isinstance
(
object
,
tir
.
BufferLoad
):
region
=
get_buffer_region_from_load
(
object
).
region
indices
=
[]
for
r
in
region
:
indices
.
append
(
r
.
min
)
return
indices
else
:
raise
ValueError
(
f
"Unsupported retrieve_offset argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
A_offset
=
retrieve_offset
(
A
)
A_offset
=
retrieve_offset
(
A
)
B_offset
=
retrieve_offset
(
B
)
B_offset
=
retrieve_offset
(
B
)
assert
A_offset
[
-
2
]
==
0
,
"The offset of the first dimension of A must be 0"
assert
A_offset
[
-
2
]
==
0
,
"The offset of the first dimension of A must be 0"
...
@@ -201,18 +90,15 @@ def gemm(
...
@@ -201,18 +90,15 @@ def gemm(
offset_a
=
A_offset
[
-
1
]
offset_a
=
A_offset
[
-
1
]
offset_b
=
B_offset
[
-
1
]
offset_b
=
B_offset
[
-
1
]
Aptr
=
retrieve_ptr
(
A
,
"r"
)
Bptr
=
retrieve_ptr
(
B
,
"r"
)
Cptr
=
retrieve_ptr
(
C
,
"rw"
)
mbarptr
=
retrieve_ptr
(
mbar
,
"rw"
)
if
mbar
is
not
None
else
tir
.
const
(
0
,
"uint32"
)
mbarptr
=
retrieve_ptr
(
mbar
,
"rw"
)
if
mbar
is
not
None
else
tir
.
const
(
0
,
"uint32"
)
C_coords
=
[
r
.
min
for
r
in
C
.
region
]
if
isinstance
(
C
,
tir
.
BufferRegion
)
else
[
0
,
0
]
C_coords
=
[
r
.
min
for
r
in
C
.
region
]
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.gemm"
),
Aptr
,
Bptr
,
Cptr
,
transpose_
A
,
return
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
op_key
),
A
,
B
,
C
,
transpose_A
,
transpose_
B
,
M
,
N
,
transpose_B
,
M
,
N
,
K
,
policy
,
clear_accum
,
stride_a
,
stride_b
,
offset_a
,
K
,
policy
,
clear_accum
,
stride_a
,
stride_b
,
offset_a
,
offset_b
,
k_pack
,
offset_b
,
k_pack
,
wg_wait
,
mbarptr
,
C_coords
[
0
],
C_coords
[
1
])
wg_wait
,
mbarptr
,
C_coords
[
0
],
C_coords
[
1
])
#
experimental currently, for fast compilation
#
Public wrappers
def
gemm_v
2
(
def
gemm_v
1
(
A
:
tir
.
Buffer
|
tir
.
Var
,
A
:
tir
.
Buffer
|
tir
.
Var
,
B
:
tir
.
Buffer
|
tir
.
Var
,
B
:
tir
.
Buffer
|
tir
.
Var
,
C
:
tir
.
Buffer
|
tir
.
Var
,
C
:
tir
.
Buffer
|
tir
.
Var
,
...
@@ -222,205 +108,52 @@ def gemm_v2(
...
@@ -222,205 +108,52 @@ def gemm_v2(
clear_accum
:
bool
=
False
,
clear_accum
:
bool
=
False
,
k_pack
:
int
=
1
,
k_pack
:
int
=
1
,
wg_wait
:
int
=
0
,
wg_wait
:
int
=
0
,
mbar
:
tir
.
Buffer
|
None
=
None
,
):
):
"""Perform a General Matrix Multiplication (GEMM) operation.
"""GEMM v1: use op tl.gemm."""
return
_gemm_impl
(
This function computes C = A @ B where A and B can optionally be transposed.
"tl.gemm"
,
The operation supports various warp policies and accumulation modes.
A
,
B
,
Args:
C
,
A (Union[tir.Buffer, tir.Var]): First input matrix
transpose_A
,
B (Union[tir.Buffer, tir.Var]): Second input matrix
transpose_B
,
C (Union[tir.Buffer, tir.Var]): Output matrix for results
policy
,
transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
clear_accum
,
transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
k_pack
,
policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
wg_wait
,
clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
mbar
,
k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.
)
wg_wait (int, optional): Warp group wait count. Defaults to 0.
Returns:
tir.Call: A handle to the GEMM operation
Raises:
AssertionError: If the K dimensions of matrices A and B don't match
"""
def
legalize_arguments
(
arg
:
tir
.
Buffer
|
tir
.
Var
):
"""Convert let-bound variables to their corresponding buffers.
Args:
arg (Union[tir.Buffer, tir.Var]): Input argument to legalize
Returns:
Union[tir.Buffer, tir.Var]: The legalized argument
"""
if
isinstance
(
arg
,
tir
.
Var
)
and
T
.
has_let_value
(
arg
):
return
T
.
get_let_value
(
arg
).
buffer
return
arg
A
=
legalize_arguments
(
A
)
B
=
legalize_arguments
(
B
)
C
=
legalize_arguments
(
C
)
def
retrieve_shape
(
object
:
tir
.
Buffer
|
tir
.
BufferRegion
)
->
list
[
int
]:
if
isinstance
(
object
,
tir
.
Buffer
):
return
object
.
shape
elif
isinstance
(
object
,
tir
.
BufferRegion
):
region
=
object
.
region
shape
=
[]
for
r
in
region
:
shape
.
append
(
r
.
extent
)
return
shape
elif
isinstance
(
object
,
tir
.
BufferLoad
):
region
=
get_buffer_region_from_load
(
object
).
region
shape
=
[]
for
r
in
region
:
shape
.
append
(
r
.
extent
)
return
shape
else
:
raise
ValueError
(
f
"Unsupported retrieve_shape argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
def
retrieve_stride
(
object
:
tir
.
Buffer
|
tir
.
BufferRegion
)
->
list
[
int
]:
if
isinstance
(
object
,
tir
.
Buffer
):
strides
=
[]
stride
=
1
for
s
in
reversed
(
object
.
shape
):
strides
.
insert
(
0
,
stride
)
stride
*=
s
return
strides
elif
isinstance
(
object
,
tir
.
BufferRegion
):
buffer
,
_
=
object
.
buffer
,
object
.
region
strides
=
[]
stride
=
1
for
s
in
reversed
(
buffer
.
shape
):
strides
.
insert
(
0
,
stride
)
stride
*=
s
return
strides
elif
isinstance
(
object
,
tir
.
BufferLoad
):
buffer
=
object
.
buffer
strides
=
[]
stride
=
1
for
s
in
reversed
(
buffer
.
shape
):
strides
.
insert
(
0
,
stride
)
stride
*=
s
return
strides
else
:
raise
ValueError
(
f
"Unsupported retrieve_stride argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
A_shape
=
retrieve_shape
(
A
)
B_shape
=
retrieve_shape
(
B
)
C_shape
=
retrieve_shape
(
C
)
A_stride
=
retrieve_stride
(
A
)
B_stride
=
retrieve_stride
(
B
)
assert
len
(
C_shape
)
==
2
,
"current only support C as a 2D tensor"
assert
len
(
A_shape
)
>=
2
,
"current only support A as a 2D or higher-order tensor"
assert
len
(
B_shape
)
>=
2
,
"current only support B as a 2D or higher-order tensor"
if
len
(
A_shape
)
>
2
:
for
i
in
range
(
len
(
A_shape
)
-
2
):
assert
A_shape
[
i
]
==
1
,
\
"current only support A as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
if
len
(
B_shape
)
>
2
:
for
i
in
range
(
len
(
B_shape
)
-
2
):
assert
B_shape
[
i
]
==
1
,
\
"current only support B as a 2D or higher-order tensor with the last two dimensions being the matrix dimensions"
M
,
N
=
C_shape
K
=
A_shape
[
-
2
]
if
transpose_A
else
A_shape
[
-
1
]
K_B
=
B_shape
[
-
1
]
if
transpose_B
else
B_shape
[
-
2
]
assert
K
==
K_B
,
f
"T.gemm K shape check failed: K_A =
{
K
}
, K_B =
{
K_B
}
"
stride_a
=
A_stride
[
-
2
]
stride_b
=
B_stride
[
-
2
]
def
retrieve_ptr
(
object
:
tir
.
Buffer
|
tir
.
BufferRegion
,
access_type
:
str
=
"r"
)
->
tir
.
PrimExpr
:
if
isinstance
(
object
,
tir
.
Buffer
):
return
object
.
access_ptr
(
access_type
)
elif
isinstance
(
object
,
tir
.
BufferRegion
):
buffer
,
region
=
object
.
buffer
,
object
.
region
indices
=
[]
for
r
in
region
:
indices
.
append
(
r
.
min
)
strides
=
[]
stride
=
1
for
s
in
reversed
(
buffer
.
shape
):
strides
.
insert
(
0
,
stride
)
stride
*=
s
offset
=
0
# not offset the last two dimension
for
i
in
range
(
len
(
indices
)
-
2
):
offset
+=
indices
[
i
]
*
strides
[
i
]
return
buffer
.
access_ptr
(
access_mask
=
access_type
,
offset
=
offset
)
elif
isinstance
(
object
,
tir
.
BufferLoad
):
buffer
=
object
.
buffer
region
=
get_buffer_region_from_load
(
object
).
region
indices
=
[]
for
r
in
region
:
indices
.
append
(
r
.
min
)
strides
=
[]
stride
=
1
for
s
in
reversed
(
buffer
.
shape
):
strides
.
insert
(
0
,
stride
)
stride
*=
s
offset
=
0
for
i
in
range
(
len
(
indices
)
-
2
):
offset
+=
indices
[
i
]
*
strides
[
i
]
return
buffer
.
access_ptr
(
access_mask
=
access_type
,
offset
=
offset
)
else
:
raise
ValueError
(
f
"Unsupported retrieve_ptr argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
def
retrieve_offset
(
object
:
tir
.
Buffer
|
tir
.
BufferRegion
)
->
tir
.
PrimExpr
:
"""Retrieve the offset of the buffer or buffer region."""
if
isinstance
(
object
,
tir
.
Buffer
):
return
[
0
]
*
len
(
object
.
shape
)
elif
isinstance
(
object
,
tir
.
BufferRegion
):
_
,
region
=
object
.
buffer
,
object
.
region
indices
=
[]
for
r
in
region
:
indices
.
append
(
r
.
min
)
return
indices
elif
isinstance
(
object
,
tir
.
BufferLoad
):
region
=
get_buffer_region_from_load
(
object
).
region
indices
=
[]
for
r
in
region
:
indices
.
append
(
r
.
min
)
return
indices
else
:
raise
ValueError
(
f
"Unsupported retrieve_offset argument type:
{
type
(
object
)
}
for buffer
{
object
}
"
)
A_offset
=
retrieve_offset
(
A
)
B_offset
=
retrieve_offset
(
B
)
assert
A_offset
[
-
2
]
==
0
,
"The offset of the first dimension of A must be 0"
assert
B_offset
[
-
2
]
==
0
,
"The offset of the first dimension of B must be 0"
offset_a
=
A_offset
[
-
1
]
offset_b
=
B_offset
[
-
1
]
Aptr
=
retrieve_ptr
(
A
,
"r"
)
# experimental currently, for fast compilation
Bptr
=
retrieve_ptr
(
B
,
"r"
)
def
gemm_v2
(
Cptr
=
retrieve_ptr
(
C
,
"rw"
)
A
:
tir
.
Buffer
|
tir
.
Var
,
return
tir
.
call_intrin
(
B
:
tir
.
Buffer
|
tir
.
Var
,
"handle"
,
C
:
tir
.
Buffer
|
tir
.
Var
,
tir
.
op
.
Op
.
get
(
"tl.gemm_py"
),
transpose_A
:
bool
=
False
,
Aptr
,
transpose_B
:
bool
=
False
,
Bptr
,
policy
:
GemmWarpPolicy
=
GemmWarpPolicy
.
Square
,
Cptr
,
clear_accum
:
bool
=
False
,
k_pack
:
int
=
1
,
wg_wait
:
int
=
0
,
mbar
:
tir
.
Buffer
|
None
=
None
,
):
"""GEMM v2: use op tl.gemm_py."""
return
_gemm_impl
(
"tl.gemm_py"
,
A
,
B
,
C
,
transpose_A
,
transpose_A
,
transpose_B
,
transpose_B
,
M
,
N
,
K
,
policy
,
policy
,
clear_accum
,
clear_accum
,
stride_a
,
stride_b
,
offset_a
,
offset_b
,
k_pack
,
k_pack
,
wg_wait
,
wg_wait
,
mbar
,
)
)
# Default to v2; allow forcing v1 via environment variable
gemm
=
gemm_v1
if
_env
.
use_gemm_v1
()
else
gemm_v2
tilelang/language/
pipeline
.py
→
tilelang/language/
loop
.py
View file @
bbbf4207
"""The language interface for tl programs."""
"""The language interface for tl programs."""
from
__future__
import
annotations
from
__future__
import
annotations
from
typing
import
Any
from
tvm
import
tir
from
tvm
import
tir
from
tvm.tir
import
IntImm
from
tvm.tir
import
IntImm
import
tvm.script.ir_builder.tir
as
tb_tir
from
.v2.builder
import
SerialForWithStep
from
tilelang
import
_ffi_api
from
tilelang
import
_ffi_api
def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None):
    """Construct a nested parallel for loop over the given extents.

    Useful for writing element-wise tensor expressions.

    Parameters
    ----------
    extents : PrimExpr
        The extents of the iteration, one per loop level.

    coalesced_width : Optional[int]
        The coalesced width of the parallel loop. When given, it is attached
        to the loop as the ``"coalesced_width"`` annotation; when ``None`` no
        annotation is added.

    Returns
    -------
    res : frame.ForFrame
        The ForFrame.
    """
    # Only emit the annotation key when the caller actually supplied a width.
    loop_annotations: dict[str, Any] = ({} if coalesced_width is None else {
        "coalesced_width": coalesced_width
    })
    return _ffi_api.Parallel(extents, loop_annotations)  # type: ignore[attr-defined] # pylint: disable=no-member
def Persistent(
    domain: list[tir.PrimExpr],
    wave_size: tir.PrimExpr,
    index: tir.PrimExpr,
    group_size: tir.PrimExpr | None = 8,
):
    """Construct a persistent for loop.

    All arguments are forwarded unchanged to ``_ffi_api.Persistent``.

    Parameters
    ----------
    domain : List[tir.PrimExpr]
        The expressions describing the loop domain.
        NOTE(review): the previous wording was "list of dominators" -- confirm
        the intended meaning against the FFI implementation.

    wave_size : tir.PrimExpr
        The wave size.

    index : tir.PrimExpr
        The tile index in one wave.

    group_size : tir.PrimExpr
        The group size. Defaults to 8.

    Returns
    -------
    The object produced by ``_ffi_api.Persistent``.
    """
    ffi_args = (domain, wave_size, index, group_size)
    return _ffi_api.Persistent(*ffi_args)
def
Pipelined
(
def
Pipelined
(
start
:
tir
.
PrimExpr
,
start
:
tir
.
PrimExpr
,
stop
:
tir
.
PrimExpr
=
None
,
stop
:
tir
.
PrimExpr
=
None
,
...
@@ -44,3 +92,20 @@ def Pipelined(
...
@@ -44,3 +92,20 @@ def Pipelined(
group
=
[]
group
=
[]
# type: ignore[attr-defined] # pylint: disable=no-member
# type: ignore[attr-defined] # pylint: disable=no-member
return
_ffi_api
.
Pipelined
(
start
,
stop
,
num_stages
,
order
,
stage
,
sync
,
group
)
return
_ffi_api
.
Pipelined
(
start
,
stop
,
num_stages
,
order
,
stage
,
sync
,
group
)
def serial(start: tir.PrimExpr,
           stop: tir.PrimExpr | None = None,
           step: tir.PrimExpr | None = None,
           *,
           annotations: dict[str, Any] | None = None):
    """Create a serial for loop, optionally with a non-unit step.

    Parameters
    ----------
    start : PrimExpr
        The loop start. When ``stop`` is omitted and a non-unit ``step`` is
        given, this value is used as the stop instead and the loop starts at 0.

    stop : Optional[PrimExpr]
        The loop stop.

    step : Optional[PrimExpr]
        The loop step. ``None``, the Python int ``1`` and an ``IntImm`` whose
        value is 1 all select the plain ``tb_tir.serial`` builder; any other
        value selects ``SerialForWithStep``.

    annotations : Optional[dict[str, Any]]
        Extra annotations passed through to the underlying loop builder.

    Returns
    -------
    The frame returned by ``tb_tir.serial`` or ``SerialForWithStep``.
    """
    unit_step = (isinstance(step, int) and step == 1) or (isinstance(step, IntImm) and
                                                          step.value == 1)
    if step is None or unit_step:
        # Unit stride: the stock builder already handles the start/stop forms.
        return tb_tir.serial(start, stop, annotations=annotations)

    if stop is None:
        # Single-bound form: the lone argument is the stop. Start from zero,
        # as an IntImm of the same dtype when the bound carries a dtype.
        stop = start
        start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0
    return SerialForWithStep(start, stop, step, annotations=annotations)
tilelang/language/parallel.py
deleted
100644 → 0
View file @
8f4628e0
"""The language interface for tl programs."""
from
__future__
import
annotations
from
typing
import
Any
from
tvm
import
tir
from
tilelang
import
_ffi_api
def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None):
    """Construct a nested parallel for loop over the given extents.

    Useful for writing element-wise tensor expressions.

    Parameters
    ----------
    extents : PrimExpr
        The extents of the iteration, one per loop level.

    coalesced_width : Optional[int]
        The coalesced width of the parallel loop. When given, it is attached
        to the loop as the ``"coalesced_width"`` annotation; when ``None`` no
        annotation is added.

    Returns
    -------
    res : frame.ForFrame
        The ForFrame.
    """
    # Only emit the annotation key when the caller actually supplied a width.
    loop_annotations: dict[str, Any] = ({} if coalesced_width is None else {
        "coalesced_width": coalesced_width
    })
    return _ffi_api.Parallel(extents, loop_annotations)  # type: ignore[attr-defined] # pylint: disable=no-member
tilelang/language/persistent.py
deleted
100644 → 0
View file @
8f4628e0
"""The language interface for tl programs."""
from
__future__
import
annotations
from
tvm
import
tir
from
tilelang
import
_ffi_api
def Persistent(
    domain: list[tir.PrimExpr],
    wave_size: tir.PrimExpr,
    index: tir.PrimExpr,
    group_size: tir.PrimExpr | None = 8,
):
    """Construct a persistent for loop.

    All arguments are forwarded unchanged to ``_ffi_api.Persistent``.

    Parameters
    ----------
    domain : List[tir.PrimExpr]
        The expressions describing the loop domain.
        NOTE(review): the previous wording was "list of dominators" -- confirm
        the intended meaning against the FFI implementation.

    wave_size : tir.PrimExpr
        The wave size.

    index : tir.PrimExpr
        The tile index in one wave.

    group_size : tir.PrimExpr
        The group size. Defaults to 8.

    Returns
    -------
    The object produced by ``_ffi_api.Persistent``.
    """
    ffi_args = (domain, wave_size, index, group_size)
    return _ffi_api.Persistent(*ffi_args)
tilelang/language/print.py
View file @
bbbf4207
...
@@ -5,6 +5,7 @@ It includes functionality to print variables, print values in buffers, condition
...
@@ -5,6 +5,7 @@ It includes functionality to print variables, print values in buffers, condition
from
tvm
import
tir
from
tvm
import
tir
from
typing
import
Any
from
typing
import
Any
import
tilelang.language
as
T
from
tilelang.language.kernel
import
get_thread_bindings
from
tilelang.language.kernel
import
get_thread_bindings
from
tilelang.language
import
copy
,
macro
,
serial
,
alloc_shared
from
tilelang.language
import
copy
,
macro
,
serial
,
alloc_shared
from
tilelang.language.utils
import
index_to_coordinates
from
tilelang.language.utils
import
index_to_coordinates
...
@@ -148,10 +149,10 @@ def device_assert(condition: tir.PrimExpr, msg: str = ""):
...
@@ -148,10 +149,10 @@ def device_assert(condition: tir.PrimExpr, msg: str = ""):
"""
"""
if
_IS_CUDA_AVAILABLE
:
if
_IS_CUDA_AVAILABLE
:
if
msg
==
""
:
if
msg
==
""
:
tir
.
call_
exter
n
(
"void"
,
"
device_assert"
,
condition
)
T
.
call_
intri
n
(
"void"
,
tir
.
op
.
Op
.
get
(
"tl.
device_assert"
)
,
condition
)
else
:
else
:
warnings
.
warn
(
"Non-empty msg may slightly slow down the kernel"
,
stacklevel
=
2
)
warnings
.
warn
(
"Non-empty msg may slightly slow down the kernel"
,
stacklevel
=
2
)
tir
.
call_
exter
n
(
"void"
,
"
device_assert_with_msg"
,
condition
,
msg
)
T
.
call_
intri
n
(
"void"
,
tir
.
op
.
Op
.
get
(
"tl.
device_assert_with_msg"
)
,
condition
,
msg
)
def
print
(
obj
:
Any
,
msg
:
str
=
""
,
warp_group_id
:
int
=
0
,
warp_id
:
int
=
0
)
->
tir
.
PrimExpr
:
def
print
(
obj
:
Any
,
msg
:
str
=
""
,
warp_group_id
:
int
=
0
,
warp_id
:
int
=
0
)
->
tir
.
PrimExpr
:
...
...
tilelang/language/proxy.py
View file @
bbbf4207
"""The language interface for tl programs."""
"""The language interface for tl programs."""
from
__future__
import
annotations
from
__future__
import
annotations
from
typing
import
Any
,
Sequence
,
SupportsIndex
,
TYPE_CHECKING
from
typing
import
Any
,
SupportsIndex
,
TYPE_CHECKING
from
collections.abc
import
Sequence
from
typing_extensions
import
Self
from
typing_extensions
import
Self
from
tvm
import
tir
from
tvm
import
tir
...
...
tilelang/language/reduce.py
View file @
bbbf4207
...
@@ -2,7 +2,10 @@
...
@@ -2,7 +2,10 @@
from
__future__
import
annotations
from
__future__
import
annotations
from
tvm
import
tir
from
tvm
import
tir
from
tilelang.language
import
copy
,
macro
,
alloc_shared
from
tilelang.language
import
copy
,
macro
,
alloc_shared
,
alloc_fragment
from
tilelang.language.utils
import
buffer_to_tile_region
from
tilelang.utils.language
import
is_shared
,
is_fragment
from
tvm.script.ir_builder
import
IRBuilder
def
_legalize_dim
(
buffer
:
tir
.
Buffer
,
dim
:
int
):
def
_legalize_dim
(
buffer
:
tir
.
Buffer
,
dim
:
int
):
...
@@ -34,17 +37,70 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
...
@@ -34,17 +37,70 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
raise
ValueError
(
raise
ValueError
(
f
"Invalid reduce output shape, buffer shape is
{
buffer
.
shape
}
, dim is
{
dim
}
, "
f
"Invalid reduce output shape, buffer shape is
{
buffer
.
shape
}
, dim is
{
dim
}
, "
f
"output shape is
{
out
.
shape
}
, expected shapes are
{
expected_shapes_str
}
"
)
f
"output shape is
{
out
.
shape
}
, expected shapes are
{
expected_shapes_str
}
"
)
buffer
=
buffer
.
access_ptr
(
"r"
)
out
=
out
.
access_ptr
(
"w"
)
@
macro
return
tir
.
call_intrin
(
def
reduce_macro
(
buffer
:
tir
.
Buffer
,
out
:
tir
.
Buffer
,
reduce_type
:
str
,
dim
:
int
,
clear
:
bool
):
"handle"
,
if
is_shared
(
buffer
)
and
is_shared
(
out
):
tir
.
op
.
Op
.
get
(
"tl.reduce"
),
red_frag_in
=
alloc_fragment
(
buffer
.
shape
,
buffer
.
dtype
)
buffer
,
red_frag_out
=
alloc_fragment
(
out
.
shape
,
out
.
dtype
)
out
,
reduce_type
,
# rename buffers
dim
,
IRBuilder
.
name
(
buffer
.
name
+
"_frag"
,
red_frag_in
)
clear
,
IRBuilder
.
name
(
out
.
name
+
"_frag"
,
red_frag_out
)
)
copy
(
buffer
,
red_frag_in
)
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.reduce"
),
buffer_to_tile_region
(
red_frag_in
,
"r"
),
buffer_to_tile_region
(
red_frag_out
,
"w"
),
reduce_type
,
dim
,
clear
,
)
copy
(
red_frag_out
,
out
)
elif
is_shared
(
buffer
)
and
is_fragment
(
out
):
red_frag_in
=
alloc_fragment
(
buffer
.
shape
,
buffer
.
dtype
)
IRBuilder
.
name
(
buffer
.
name
+
"_frag"
,
red_frag_in
)
copy
(
buffer
,
red_frag_in
)
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.reduce"
),
buffer_to_tile_region
(
red_frag_in
,
"r"
),
buffer_to_tile_region
(
out
,
"w"
),
reduce_type
,
dim
,
clear
,
)
elif
is_fragment
(
buffer
)
and
is_shared
(
out
):
red_frag_out
=
alloc_fragment
(
out
.
shape
,
out
.
dtype
)
IRBuilder
.
name
(
out
.
name
+
"_frag"
,
red_frag_out
)
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.reduce"
),
buffer_to_tile_region
(
buffer
,
"r"
),
buffer_to_tile_region
(
red_frag_out
,
"w"
),
reduce_type
,
dim
,
clear
,
)
copy
(
red_frag_out
,
out
)
elif
is_fragment
(
buffer
)
and
is_fragment
(
out
):
tir
.
call_intrin
(
"handle"
,
tir
.
op
.
Op
.
get
(
"tl.reduce"
),
buffer_to_tile_region
(
buffer
,
"r"
),
buffer_to_tile_region
(
out
,
"w"
),
reduce_type
,
dim
,
clear
,
)
else
:
raise
ValueError
(
f
"Invalid buffer scopes:
{
buffer
.
scope
()
}
and
{
out
.
scope
()
}
"
)
return
reduce_macro
(
buffer
,
out
,
reduce_type
,
dim
,
clear
)
def
reduce_max
(
buffer
:
tir
.
Buffer
,
out
:
tir
.
Buffer
,
dim
:
int
=
-
1
,
clear
:
bool
=
True
):
def
reduce_max
(
buffer
:
tir
.
Buffer
,
out
:
tir
.
Buffer
,
dim
:
int
=
-
1
,
clear
:
bool
=
True
):
...
...
tilelang/language/symbolics.py
View file @
bbbf4207
...
@@ -7,7 +7,6 @@ from tilelang.utils import deprecated
...
@@ -7,7 +7,6 @@ from tilelang.utils import deprecated
__all__
=
[
"dynamic"
,
"symbolic"
]
__all__
=
[
"dynamic"
,
"symbolic"
]
@
deprecated
(
"T.dynamic(...)"
,
"tir.Var(...)"
,
"v0.1.9"
)
def
dynamic
(
name
:
str
,
dtype
:
str
=
"int32"
):
def
dynamic
(
name
:
str
,
dtype
:
str
=
"int32"
):
"""
"""
Create a TIR dynamic symbolic variable.
Create a TIR dynamic symbolic variable.
...
@@ -22,7 +21,7 @@ def dynamic(name: str, dtype: str = "int32"):
...
@@ -22,7 +21,7 @@ def dynamic(name: str, dtype: str = "int32"):
return
tir
.
Var
(
name
,
dtype
)
return
tir
.
Var
(
name
,
dtype
)
@
deprecated
(
"T.symbolic(...)"
,
"T.dynamic(...)"
)
@
deprecated
(
"T.symbolic(...)"
,
"T.dynamic(...)"
,
"v0.1.9"
)
def
symbolic
(
name
:
str
,
dtype
:
str
=
"int32"
):
def
symbolic
(
name
:
str
,
dtype
:
str
=
"int32"
):
"""Deprecated alias for `T.dynamic`."""
"""Deprecated alias for `T.dynamic`."""
return
tir
.
Var
(
name
,
dtype
)
return
tir
.
Var
(
name
,
dtype
)
Prev
1
…
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment