Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
29051439
Unverified
Commit
29051439
authored
Dec 12, 2025
by
Lei Wang
Committed by
GitHub
Dec 12, 2025
Browse files
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
parent
e84b24bc
Changes
426
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
562 additions
and
560 deletions
+562
-560
tilelang/jit/__init__.py
tilelang/jit/__init__.py
+76
-84
tilelang/jit/adapter/base.py
tilelang/jit/adapter/base.py
+3
-7
tilelang/jit/adapter/ctypes/adapter.py
tilelang/jit/adapter/ctypes/adapter.py
+32
-31
tilelang/jit/adapter/cython/adapter.py
tilelang/jit/adapter/cython/adapter.py
+38
-41
tilelang/jit/adapter/libgen.py
tilelang/jit/adapter/libgen.py
+8
-11
tilelang/jit/adapter/nvrtc/__init__.py
tilelang/jit/adapter/nvrtc/__init__.py
+7
-7
tilelang/jit/adapter/nvrtc/adapter.py
tilelang/jit/adapter/nvrtc/adapter.py
+27
-25
tilelang/jit/adapter/nvrtc/libgen.py
tilelang/jit/adapter/nvrtc/libgen.py
+9
-11
tilelang/jit/adapter/nvrtc/wrapper.py
tilelang/jit/adapter/nvrtc/wrapper.py
+82
-64
tilelang/jit/adapter/torch/__init__.py
tilelang/jit/adapter/torch/__init__.py
+1
-1
tilelang/jit/adapter/torch/metal.py
tilelang/jit/adapter/torch/metal.py
+4
-8
tilelang/jit/adapter/tvm_ffi.py
tilelang/jit/adapter/tvm_ffi.py
+40
-40
tilelang/jit/adapter/utils.py
tilelang/jit/adapter/utils.py
+46
-66
tilelang/jit/adapter/wrapper.py
tilelang/jit/adapter/wrapper.py
+129
-97
tilelang/jit/execution_backend.py
tilelang/jit/execution_backend.py
+5
-2
tilelang/jit/kernel.py
tilelang/jit/kernel.py
+33
-38
tilelang/language/__init__.py
tilelang/language/__init__.py
+5
-1
tilelang/language/allocate.py
tilelang/language/allocate.py
+15
-26
tilelang/language/annotations.py
tilelang/language/annotations.py
+1
-0
tilelang/language/ast/__init__.py
tilelang/language/ast/__init__.py
+1
-0
No files found.
Too many changes to show.
To preserve performance only
426 of 426+
files are displayed.
Plain diff
Email patch
tilelang/jit/__init__.py
View file @
29051439
...
...
@@ -3,6 +3,7 @@ This module provides an auto-tuning infrastructure for TileLang (tl) programs.
It includes functionality to JIT-compile TileLang programs into a runnable
kernel adapter using TVM.
"""
from
__future__
import
annotations
from
dataclasses
import
dataclass
...
...
@@ -39,17 +40,16 @@ from tqdm.auto import tqdm
logger
=
getLogger
(
__name__
)
_P
=
ParamSpec
(
'
_P
'
)
_KP
=
ParamSpec
(
'
_KP
'
)
_T
=
TypeVar
(
'
_T
'
)
_Ret
=
TypeVar
(
'
_Ret
'
)
_P
=
ParamSpec
(
"
_P
"
)
_KP
=
ParamSpec
(
"
_KP
"
)
_T
=
TypeVar
(
"
_T
"
)
_Ret
=
TypeVar
(
"
_Ret
"
)
def
compile
(
func
:
PrimFunc
[
_KP
,
_T
]
=
None
,
out_idx
:
list
[
int
]
|
int
|
None
=
None
,
execution_backend
:
Literal
[
"auto"
,
"dlpack"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
"torch"
]
=
"auto"
,
execution_backend
:
Literal
[
"auto"
,
"dlpack"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
"torch"
]
=
"auto"
,
target
:
str
|
Target
=
"auto"
,
target_host
:
str
|
Target
|
None
=
None
,
verbose
:
bool
=
False
,
...
...
@@ -83,11 +83,9 @@ def compile(
if
isinstance
(
compile_flags
,
str
):
compile_flags
=
[
compile_flags
]
if
hasattr
(
func
,
'
out_idx_override
'
):
if
hasattr
(
func
,
"
out_idx_override
"
):
if
func
.
out_idx_override
is
not
None
and
out_idx
is
not
None
:
raise
ValueError
(
"Out index conflict: out_idx is specified and prim_func have returned `T.empty` tensors"
)
raise
ValueError
(
"Out index conflict: out_idx is specified and prim_func have returned `T.empty` tensors"
)
out_idx
=
func
.
out_idx_override
or
out_idx
# This path is not a performance critical path, so we can afford to convert the target.
...
...
@@ -96,6 +94,7 @@ def compile(
# Resolve execution backend (handles aliases, auto, validation per target)
requested_backend
=
execution_backend
from
tilelang.jit.execution_backend
import
resolve_execution_backend
,
allowed_backends_for_target
execution_backend
=
resolve_execution_backend
(
requested_backend
,
target
)
if
verbose
:
allowed_now
=
allowed_backends_for_target
(
target
,
include_unavailable
=
False
)
...
...
@@ -119,17 +118,18 @@ def compile(
)
def
par_compile
(
funcs
:
Iterable
[
PrimFunc
[
_KP
,
_T
]],
out_idx
:
list
[
int
]
|
int
|
None
=
None
,
execution_backend
:
Literal
[
"auto"
,
"dlpack"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
"torch"
]
=
"auto"
,
target
:
str
|
Target
=
"auto"
,
target_host
:
str
|
Target
|
None
=
None
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
str
|
None
=
None
,
num_workers
:
int
=
None
,
ignore_error
:
bool
=
False
)
->
list
[
JITKernel
[
_KP
,
_T
]]:
def
par_compile
(
funcs
:
Iterable
[
PrimFunc
[
_KP
,
_T
]],
out_idx
:
list
[
int
]
|
int
|
None
=
None
,
execution_backend
:
Literal
[
"auto"
,
"dlpack"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
"torch"
]
=
"auto"
,
target
:
str
|
Target
=
"auto"
,
target_host
:
str
|
Target
|
None
=
None
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
str
|
None
=
None
,
num_workers
:
int
=
None
,
ignore_error
:
bool
=
False
,
)
->
list
[
JITKernel
[
_KP
,
_T
]]:
"""
Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels.
Parameters
...
...
@@ -151,7 +151,7 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
Additional keyword arguments to pass to the Compiler PassContext.
Refer to `tilelang.transform.PassConfigKey` for supported options.
"""
with
concurrent
.
futures
.
ThreadPoolExecutor
(
num_workers
,
'
tl-par-comp
'
)
as
executor
:
with
concurrent
.
futures
.
ThreadPoolExecutor
(
num_workers
,
"
tl-par-comp
"
)
as
executor
:
futures
=
[]
future_map
=
{}
for
i
,
func
in
enumerate
(
funcs
):
...
...
@@ -170,9 +170,9 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
futures
.
append
(
future
)
results
=
[...
for
_
in
futures
]
for
future
in
tqdm
(
concurrent
.
futures
.
as_completed
(
futures
),
total
=
len
(
futures
),
desc
=
"Parallel Compiling"
,
concurrent
.
futures
.
as_completed
(
futures
),
total
=
len
(
futures
),
desc
=
"Parallel Compiling"
,
):
idx
=
future_map
[
future
]
if
ignore_error
:
...
...
@@ -189,7 +189,7 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
@
dataclass
class
JITImpl
(
Generic
[
_P
,
_KP
,
_T
,
_Ret
]):
'''
"""
Detailed Just-In-Time wrapper for TileLang programs.
This dataclass encapsulates the configuration and runtime helpers used by the
...
...
@@ -256,7 +256,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
PrimFunc and the resulting set is compiled in parallel via the
module-level `par_compile` helper. Returns a list of JITKernel objects
in the same order as the provided configs.
'''
"""
out_idx
:
list
[
int
]
|
int
|
None
execution_backend
:
Literal
[
"auto"
,
"dlpack"
,
"tvm_ffi"
,
"ctypes"
,
"cython"
,
"nvrtc"
,
"torch"
]
...
...
@@ -302,10 +302,9 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
assert
isinstance
(
tir
,
PrimFunc
),
f
"target function must be a PrimFunc but got
{
type
(
tir
)
}
"
return
tir
def
par_compile
(
self
,
configs
:
Iterable
[
dict
[
str
,
Any
]
|
tuple
[
str
,
Any
]],
num_workers
:
int
=
None
,
ignore_error
:
bool
=
False
)
->
list
[
JITKernel
[
_KP
,
_T
]]:
def
par_compile
(
self
,
configs
:
Iterable
[
dict
[
str
,
Any
]
|
tuple
[
str
,
Any
]],
num_workers
:
int
=
None
,
ignore_error
:
bool
=
False
)
->
list
[
JITKernel
[
_KP
,
_T
]]:
"""
Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels.
Parameters
...
...
@@ -328,7 +327,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
"""
configs
=
list
(
configs
)
funcs
=
[]
for
cfg
in
tqdm
(
configs
,
desc
=
'
Elaborating
'
):
for
cfg
in
tqdm
(
configs
,
desc
=
"
Elaborating
"
):
if
isinstance
(
cfg
,
tuple
):
funcs
.
append
(
self
.
get_tir
(
*
cfg
))
elif
isinstance
(
cfg
,
dict
):
...
...
@@ -345,7 +344,8 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
pass_configs
=
self
.
pass_configs
,
compile_flags
=
self
.
compile_flags
,
num_workers
=
num_workers
,
ignore_error
=
ignore_error
)
ignore_error
=
ignore_error
,
)
def
compile
(
self
,
*
args
:
_P
.
args
,
**
kwargs
:
_P
.
kwargs
)
->
_Ret
:
func
=
self
.
get_tir
(
*
args
,
**
kwargs
)
...
...
@@ -362,25 +362,25 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
if
self
.
debug_root_path
:
if
isinstance
(
self
.
func
,
PrimFunc
):
func_name
=
self
.
func
.
attrs
[
'
global_symbol
'
]
func_name
=
self
.
func
.
attrs
[
"
global_symbol
"
]
else
:
func_name
=
getattr
(
self
.
func
,
'
__name__
'
,
'
jit_kernel
'
)
kernel_file
=
f
'
tilelang_jit_kernel_
{
func_name
}
.c
'
program_file
=
f
'
tilelang_jit_program_
{
func_name
}
.py
'
func_name
=
getattr
(
self
.
func
,
"
__name__
"
,
"
jit_kernel
"
)
kernel_file
=
f
"
tilelang_jit_kernel_
{
func_name
}
.c
"
program_file
=
f
"
tilelang_jit_program_
{
func_name
}
.py
"
makedirs
(
self
.
debug_root_path
,
exist_ok
=
True
)
with
open
(
path
.
join
(
self
.
debug_root_path
,
kernel_file
),
'w'
)
as
f
:
with
open
(
path
.
join
(
self
.
debug_root_path
,
kernel_file
),
"w"
)
as
f
:
print
(
kernel_result
.
get_kernel_source
(),
file
=
f
)
with
open
(
path
.
join
(
self
.
debug_root_path
,
program_file
),
'w'
)
as
f
:
with
open
(
path
.
join
(
self
.
debug_root_path
,
program_file
),
"w"
)
as
f
:
print
(
func
.
script
(),
file
=
f
)
return
kernel_result
def
parse_cache_key
(
self
,
*
args
:
_P
.
args
,
**
kwargs
:
_P
.
kwargs
):
if
isinstance
(
self
.
func
,
PrimFuncCreater
):
tune_params
=
kwargs
.
pop
(
'
__tune_params
'
,
{})
tune_params
=
kwargs
.
pop
(
"
__tune_params
"
,
{})
return
self
.
func
.
func_annot
.
parse_key
(
*
args
,
**
kwargs
,
**
tune_params
)
else
:
tune_params
=
kwargs
.
pop
(
'
__tune_params
'
,
{})
tune_params
=
kwargs
.
pop
(
"
__tune_params
"
,
{})
key_args_tuple
=
args
key_kwargs_tuple
=
tuple
(
sorted
(
kwargs
.
items
()))
tuned_key_kwargs_tuple
=
tuple
(
sorted
(
tune_params
.
items
()))
...
...
@@ -389,34 +389,31 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
def
convert_kernel_args
(
self
,
*
args
:
_P
.
args
,
**
kwargs
:
_P
.
kwargs
):
if
isinstance
(
self
.
func
,
PrimFuncCreater
):
tune_params
=
kwargs
.
pop
(
'
__tune_params
'
,
{})
tune_params
=
kwargs
.
pop
(
"
__tune_params
"
,
{})
return
self
.
func
.
func_annot
.
convert_to_kernel_args
(
*
args
,
**
kwargs
,
**
tune_params
)
else
:
raise
NotImplementedError
(
"convert_arg_to_kernel_args is only implemented for PrimFuncCreater."
)
raise
NotImplementedError
(
"convert_arg_to_kernel_args is only implemented for PrimFuncCreater."
)
def
__call__
(
self
,
*
args
:
_P
.
args
,
**
kwargs
:
_P
.
kwargs
)
->
_Ret
:
# Separate out the tuning parameters from the user's kwargs
# Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache
return_compile_arguments
=
kwargs
.
pop
(
'
__return_compile_arguments
'
,
False
)
return_compile_arguments
=
kwargs
.
pop
(
"
__return_compile_arguments
"
,
False
)
if
return_compile_arguments
:
logger
.
warning
(
"`__return_compile_arguments` is deprecated and will be removed in future versions."
)
logger
.
warning
(
"`__return_compile_arguments` is deprecated and will be removed in future versions."
)
compile_args
=
{
'
out_idx
'
:
self
.
out_idx
,
'
execution_backend
'
:
self
.
execution_backend
,
'
target
'
:
self
.
target
,
'
target_host
'
:
self
.
target_host
,
'
verbose
'
:
self
.
verbose
,
'
pass_configs
'
:
self
.
pass_configs
,
'
compile_flags
'
:
self
.
compile_flags
,
"
out_idx
"
:
self
.
out_idx
,
"
execution_backend
"
:
self
.
execution_backend
,
"
target
"
:
self
.
target
,
"
target_host
"
:
self
.
target_host
,
"
verbose
"
:
self
.
verbose
,
"
pass_configs
"
:
self
.
pass_configs
,
"
compile_flags
"
:
self
.
compile_flags
,
}
return
compile_args
key
=
self
.
parse_cache_key
(
*
args
,
**
kwargs
)
tune_params
=
kwargs
.
pop
(
'
__tune_params
'
,
{})
tune_params
=
kwargs
.
pop
(
"
__tune_params
"
,
{})
kernel
=
self
.
_kernel_cache
.
get
(
key
,
None
)
if
kernel
is
None
:
...
...
@@ -434,8 +431,7 @@ ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvr
@
overload
def
jit
(
func
:
Callable
[
_P
,
PrimFunc
[
_KP
,
_T
]])
->
JITImpl
[
_P
,
_KP
,
_T
,
JITKernel
[
_KP
,
_T
]]:
...
def
jit
(
func
:
Callable
[
_P
,
PrimFunc
[
_KP
,
_T
]])
->
JITImpl
[
_P
,
_KP
,
_T
,
JITKernel
[
_KP
,
_T
]]:
...
@
overload
...
...
@@ -448,22 +444,22 @@ def jit(
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
debug_root_path
:
str
|
None
=
None
,
compile_flags
:
list
[
str
]
|
str
|
None
=
None
)
->
Callable
[[
Callable
[
_P
,
PrimFunc
[
_KP
,
_T
]]],
JITImpl
[
_P
,
_KP
,
_T
,
JITKernel
[
_KP
,
_T
]]]:
...
compile_flags
:
list
[
str
]
|
str
|
None
=
None
,
)
->
Callable
[[
Callable
[
_P
,
PrimFunc
[
_KP
,
_T
]]],
JITImpl
[
_P
,
_KP
,
_T
,
JITKernel
[
_KP
,
_T
]]]:
...
def
jit
(
# This is the new public interface
func
:
Callable
[
_P
,
_T
]
|
PrimFunc
|
None
=
None
,
*
,
# Indicates subsequent arguments are keyword-only
out_idx
:
Any
=
None
,
target
:
str
|
Target
=
"auto"
,
target_host
:
str
|
Target
=
None
,
execution_backend
:
ExecutionBackend
=
"auto"
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
debug_root_path
:
str
|
None
=
None
,
compile_flags
:
list
[
str
]
|
str
|
None
=
None
):
func
:
Callable
[
_P
,
_T
]
|
PrimFunc
|
None
=
None
,
*
,
# Indicates subsequent arguments are keyword-only
out_idx
:
Any
=
None
,
target
:
str
|
Target
=
"auto"
,
target_host
:
str
|
Target
=
None
,
execution_backend
:
ExecutionBackend
=
"auto"
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
debug_root_path
:
str
|
None
=
None
,
compile_flags
:
list
[
str
]
|
str
|
None
=
None
,
):
"""
Just-In-Time (JIT) compiler decorator for TileLang functions.
...
...
@@ -516,7 +512,8 @@ def jit( # This is the new public interface
compile_flags
=
compile_flags
,
func_source
=
inspect
.
getsource
(
orig_func
),
signature
=
inspect
.
signature
(
orig_func
),
lazy_jit
=
False
)
lazy_jit
=
False
,
)
if
func
is
not
None
:
return
decorator
(
func
)
...
...
@@ -525,8 +522,7 @@ def jit( # This is the new public interface
@
overload
def
lazy_jit
(
func
:
Callable
[
_KP
,
_T
])
->
JITImpl
[
_KP
,
_KP
,
_T
,
_T
]:
...
def
lazy_jit
(
func
:
Callable
[
_KP
,
_T
])
->
JITImpl
[
_KP
,
_KP
,
_T
,
_T
]:
...
@
overload
...
...
@@ -539,9 +535,8 @@ def lazy_jit(
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
debug_root_path
:
str
|
None
=
None
,
compile_flags
:
list
[
str
]
|
str
|
None
=
None
)
->
Callable
[[
Callable
[
_KP
,
_T
]],
JITImpl
[
_KP
,
_KP
,
_T
,
_T
]]:
...
compile_flags
:
list
[
str
]
|
str
|
None
=
None
,
)
->
Callable
[[
Callable
[
_KP
,
_T
]],
JITImpl
[
_KP
,
_KP
,
_T
,
_T
]]:
...
def
lazy_jit
(
...
...
@@ -555,7 +550,6 @@ def lazy_jit(
debug_root_path
:
str
|
None
=
None
,
compile_flags
:
list
[
str
]
|
str
|
None
=
None
,
):
if
isinstance
(
compile_flags
,
str
):
compile_flags
=
[
compile_flags
]
...
...
@@ -567,7 +561,8 @@ def lazy_jit(
verbose
=
verbose
,
pass_configs
=
pass_configs
,
debug_root_path
=
debug_root_path
,
compile_flags
=
compile_flags
)
compile_flags
=
compile_flags
,
)
def
decorator
(
func
:
Callable
[
_P
,
_T
]):
pf
:
PrimFunc
[
_P
,
_T
]
|
PrimFuncCreater
[
_P
,
_T
]
=
prim_func
(
func
,
generator
=
True
)
...
...
@@ -576,10 +571,7 @@ def lazy_jit(
# return compile(pf, **compile_args)
# else:
return
JITImpl
(
func
=
pf
,
**
compile_args
,
func_source
=
inspect
.
getsource
(
pf
.
orig_func
),
signature
=
inspect
.
signature
(
pf
.
orig_func
),
lazy_jit
=
True
)
func
=
pf
,
**
compile_args
,
func_source
=
inspect
.
getsource
(
pf
.
orig_func
),
signature
=
inspect
.
signature
(
pf
.
orig_func
),
lazy_jit
=
True
)
return
decorator
(
func
)
if
func
is
not
None
else
decorator
tilelang/jit/adapter/base.py
View file @
29051439
"""The profiler and convert to torch utils"""
from
__future__
import
annotations
from
abc
import
ABC
,
abstractmethod
...
...
@@ -8,7 +9,6 @@ import torch
class
BaseKernelAdapter
(
ABC
):
func
:
Callable
|
None
=
None
def
__init__
(
self
,
mod
,
params
:
list
[
KernelParam
],
result_idx
:
list
[
int
])
->
None
:
...
...
@@ -24,18 +24,14 @@ class BaseKernelAdapter(ABC):
result_idx
=
[]
elif
isinstance
(
result_idx
,
int
):
if
result_idx
>
len
(
params
)
or
result_idx
<
-
len
(
params
):
raise
ValueError
(
f
"result_idx should be an integer between
{
-
len
(
params
)
-
1
}
and
{
len
(
params
)
-
1
}
"
)
raise
ValueError
(
f
"result_idx should be an integer between
{
-
len
(
params
)
-
1
}
and
{
len
(
params
)
-
1
}
"
)
if
result_idx
<
0
:
result_idx
=
len
(
params
)
+
result_idx
result_idx
=
[
result_idx
]
elif
isinstance
(
result_idx
,
list
):
for
i
,
idx
in
enumerate
(
result_idx
):
if
idx
>=
len
(
params
)
or
idx
<
-
len
(
params
):
raise
ValueError
(
f
"result_idx should be an integer between
{
-
len
(
params
)
-
1
}
and
{
len
(
params
)
-
1
}
"
)
raise
ValueError
(
f
"result_idx should be an integer between
{
-
len
(
params
)
-
1
}
and
{
len
(
params
)
-
1
}
"
)
if
idx
<
0
:
result_idx
[
i
]
=
len
(
params
)
+
idx
else
:
...
...
tilelang/jit/adapter/ctypes/adapter.py
View file @
29051439
"""The profiler and convert to torch utils"""
from
__future__
import
annotations
import
torch
from
..base
import
BaseKernelAdapter
...
...
@@ -41,18 +42,20 @@ class CtypesKernelAdapter(BaseKernelAdapter):
param_dtypes
:
list
[
torch
.
dtype
]
|
None
=
None
# Cache for parameter dtypes
param_shapes
:
list
[
list
]
|
None
=
None
# Cache for parameter shapes
def
__init__
(
self
,
params
:
list
[
TensorType
],
result_idx
:
list
[
int
],
target
:
str
,
func_or_mod
:
tir
.
PrimFunc
|
tvm
.
IRModule
,
host_mod
:
tvm
.
IRModule
|
None
=
None
,
device_mod
:
tvm
.
IRModule
|
None
=
None
,
host_kernel_source
:
str
|
None
=
None
,
device_kernel_source
:
str
|
None
=
None
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
None
=
None
):
def
__init__
(
self
,
params
:
list
[
TensorType
],
result_idx
:
list
[
int
],
target
:
str
,
func_or_mod
:
tir
.
PrimFunc
|
tvm
.
IRModule
,
host_mod
:
tvm
.
IRModule
|
None
=
None
,
device_mod
:
tvm
.
IRModule
|
None
=
None
,
host_kernel_source
:
str
|
None
=
None
,
device_kernel_source
:
str
|
None
=
None
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
None
=
None
,
):
"""Initialize the adapter with the given TIR function or module.
Args:
...
...
@@ -109,17 +112,19 @@ class CtypesKernelAdapter(BaseKernelAdapter):
self
.
_post_init
()
@
classmethod
def
from_database
(
cls
,
params
:
list
[
TensorType
],
result_idx
:
list
[
int
],
target
:
str
,
func_or_mod
:
tir
.
PrimFunc
|
tvm
.
IRModule
,
host_kernel_source
:
str
,
device_kernel_source
:
str
,
kernel_lib_path
:
str
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
None
=
None
):
def
from_database
(
cls
,
params
:
list
[
TensorType
],
result_idx
:
list
[
int
],
target
:
str
,
func_or_mod
:
tir
.
PrimFunc
|
tvm
.
IRModule
,
host_kernel_source
:
str
,
device_kernel_source
:
str
,
kernel_lib_path
:
str
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
None
=
None
,
):
adapter
=
cls
.
__new__
(
cls
)
adapter
.
params
=
params
adapter
.
result_idx
=
adapter
.
_legalize_result_idx
(
result_idx
)
...
...
@@ -175,15 +180,13 @@ class CtypesKernelAdapter(BaseKernelAdapter):
if
param
in
buffer_map
:
buffer
=
buffer_map
[
param
]
for
j
,
shape
in
enumerate
(
buffer
.
shape
):
if
(
isinstance
(
shape
,
tir
.
Var
)
and
(
shape
not
in
dynamic_symbolic_map
)
and
(
shape
not
in
params
)):
if
isinstance
(
shape
,
tir
.
Var
)
and
(
shape
not
in
dynamic_symbolic_map
)
and
(
shape
not
in
params
):
dynamic_symbolic_map
[
shape
]
=
(
0
,
i
,
j
)
for
i
,
param
in
enumerate
(
params
):
if
param
in
buffer_map
:
buffer
=
buffer_map
[
param
]
for
j
,
stride
in
enumerate
(
buffer
.
strides
):
if
(
isinstance
(
stride
,
tir
.
Var
)
and
(
stride
not
in
dynamic_symbolic_map
)
and
(
stride
not
in
params
)):
if
isinstance
(
stride
,
tir
.
Var
)
and
(
stride
not
in
dynamic_symbolic_map
)
and
(
stride
not
in
params
):
dynamic_symbolic_map
[
stride
]
=
(
1
,
i
,
j
)
return
dynamic_symbolic_map
...
...
@@ -192,9 +195,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
Converts PyTorch tensor pointers to C void pointers for ctypes interface.
"""
ctypes_args
=
[
ctypes
.
c_void_p
(
arr
.
data_ptr
())
if
not
isinstance
(
arr
,
int
)
else
arr
for
arr
in
args
]
ctypes_args
=
[
ctypes
.
c_void_p
(
arr
.
data_ptr
())
if
not
isinstance
(
arr
,
int
)
else
arr
for
arr
in
args
]
ctypes_args
.
append
(
ctypes
.
c_void_p
(
stream
))
self
.
lib
.
call
(
*
ctypes_args
)
...
...
@@ -288,7 +289,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
@
property
def
is_dynamic
(
self
):
"""Indicates whether the kernel handles dynamic shapes."""
return
(
self
.
dynamic_symbolic_map
is
not
None
and
len
(
self
.
dynamic_symbolic_map
)
>
0
)
return
self
.
dynamic_symbolic_map
is
not
None
and
len
(
self
.
dynamic_symbolic_map
)
>
0
def
get_kernel_source
(
self
,
kernel_only
:
bool
=
False
):
"""Returns the source code of the compiled kernel."""
...
...
tilelang/jit/adapter/cython/adapter.py
View file @
29051439
"""The profiler and convert to torch utils"""
from
__future__
import
annotations
import
ctypes
import
logging
...
...
@@ -70,17 +71,19 @@ class CythonKernelAdapter(BaseKernelAdapter):
# Pass configs for the compiler
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
def
__init__
(
self
,
params
:
list
[
KernelParam
],
result_idx
:
list
[
int
],
target
:
str
|
Target
,
func_or_mod
:
tir
.
PrimFunc
|
tvm
.
IRModule
,
host_mod
:
tvm
.
IRModule
|
None
=
None
,
device_mod
:
tvm
.
IRModule
|
None
=
None
,
device_kernel_source
:
str
|
None
=
None
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
None
=
None
):
def
__init__
(
self
,
params
:
list
[
KernelParam
],
result_idx
:
list
[
int
],
target
:
str
|
Target
,
func_or_mod
:
tir
.
PrimFunc
|
tvm
.
IRModule
,
host_mod
:
tvm
.
IRModule
|
None
=
None
,
device_mod
:
tvm
.
IRModule
|
None
=
None
,
device_kernel_source
:
str
|
None
=
None
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
None
=
None
,
):
"""Initialize the adapter with the given TIR function or module.
Args:
...
...
@@ -130,7 +133,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
self
.
lib
.
get_last_error
.
restype
=
ctypes
.
c_char_p
result
=
self
.
lib
.
init
()
if
result
!=
0
:
error_msg
=
self
.
lib
.
get_last_error
().
decode
(
'
utf-8
'
)
error_msg
=
self
.
lib
.
get_last_error
().
decode
(
"
utf-8
"
)
error_msg
+=
f
"
\n
{
self
.
lib_code
}
"
raise
RuntimeError
(
f
"Initialization failed:
{
error_msg
}
"
)
...
...
@@ -145,17 +148,19 @@ class CythonKernelAdapter(BaseKernelAdapter):
self
.
_post_init
()
@
classmethod
def
from_database
(
cls
,
params
:
list
[
TensorType
],
result_idx
:
list
[
int
],
target
:
str
,
func_or_mod
:
tir
.
PrimFunc
|
tvm
.
IRModule
,
host_kernel_source
:
str
,
device_kernel_source
:
str
,
kernel_lib_path
:
str
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
None
=
None
):
def
from_database
(
cls
,
params
:
list
[
TensorType
],
result_idx
:
list
[
int
],
target
:
str
,
func_or_mod
:
tir
.
PrimFunc
|
tvm
.
IRModule
,
host_kernel_source
:
str
,
device_kernel_source
:
str
,
kernel_lib_path
:
str
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
None
=
None
,
):
adapter
=
cls
.
__new__
(
cls
)
adapter
.
params
=
params
adapter
.
result_idx
=
adapter
.
_legalize_result_idx
(
result_idx
)
...
...
@@ -190,11 +195,10 @@ class CythonKernelAdapter(BaseKernelAdapter):
adapter
.
lib
.
get_last_error
.
restype
=
ctypes
.
c_char_p
result
=
adapter
.
lib
.
init
()
if
result
!=
0
:
error_msg
=
adapter
.
lib
.
get_last_error
().
decode
(
'
utf-8
'
)
error_msg
=
adapter
.
lib
.
get_last_error
().
decode
(
"
utf-8
"
)
raise
RuntimeError
(
f
"Initialization failed:
{
error_msg
}
"
)
adapter
.
cython_wrapper
=
CythonKernelWrapper
(
adapter
.
result_idx
,
adapter
.
params
,
adapter
.
lib
)
adapter
.
cython_wrapper
=
CythonKernelWrapper
(
adapter
.
result_idx
,
adapter
.
params
,
adapter
.
lib
)
adapter
.
cython_wrapper
.
set_dynamic_symbolic_map
(
adapter
.
dynamic_symbolic_map
)
adapter
.
cython_wrapper
.
set_buffer_dtype_map
(
adapter
.
buffer_dtype_map
)
adapter
.
cython_wrapper
.
set_static_shape_map
(
adapter
.
static_shape_map
)
...
...
@@ -221,15 +225,13 @@ class CythonKernelAdapter(BaseKernelAdapter):
if
param
in
buffer_map
:
buffer
=
buffer_map
[
param
]
for
j
,
shape
in
enumerate
(
buffer
.
shape
):
if
(
isinstance
(
shape
,
tir
.
Var
)
and
(
shape
not
in
dynamic_symbolic_map
)
and
(
shape
not
in
params
)):
if
isinstance
(
shape
,
tir
.
Var
)
and
(
shape
not
in
dynamic_symbolic_map
)
and
(
shape
not
in
params
):
dynamic_symbolic_map
[
shape
]
=
(
0
,
i
,
j
)
for
i
,
param
in
enumerate
(
params
):
if
param
in
buffer_map
:
buffer
=
buffer_map
[
param
]
for
j
,
stride
in
enumerate
(
buffer
.
strides
):
if
(
isinstance
(
stride
,
tir
.
Var
)
and
(
stride
not
in
dynamic_symbolic_map
)
and
(
stride
not
in
params
)):
if
isinstance
(
stride
,
tir
.
Var
)
and
(
stride
not
in
dynamic_symbolic_map
)
and
(
stride
not
in
params
):
dynamic_symbolic_map
[
stride
]
=
(
1
,
i
,
j
)
return
dynamic_symbolic_map
...
...
@@ -259,14 +261,13 @@ class CythonKernelAdapter(BaseKernelAdapter):
params
=
func
.
params
ptr_map
=
{}
for
i
,
param
in
enumerate
(
params
):
if
param
.
dtype
==
'
handle
'
:
if
param
.
dtype
==
"
handle
"
:
ptr_map
[
i
]
=
param
.
name
return
ptr_map
def
_process_static_buffer_infos
(
self
)
->
\
tuple
[
dict
[
tir
.
Var
,
tuple
[
int
,
list
[
tuple
[
int
,
int
]]]],
dict
[
tir
.
Var
,
tuple
[
int
,
list
[
tuple
[
int
,
int
]]]],
list
[
tuple
[
tir
.
Var
]]]:
def
_process_static_buffer_infos
(
self
,
)
->
tuple
[
dict
[
tir
.
Var
,
tuple
[
int
,
list
[
tuple
[
int
,
int
]]]],
dict
[
tir
.
Var
,
tuple
[
int
,
list
[
tuple
[
int
,
int
]]]],
list
[
tuple
[
tir
.
Var
]]]:
"""Extract information about static shapes from the TIR function.
Maps buffer variables to their corresponding static shapes.
...
...
@@ -332,9 +333,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
Converts PyTorch tensor pointers to C void pointers for ctypes interface.
"""
ctypes_args
=
[
ctypes
.
c_void_p
(
arr
.
data_ptr
())
if
not
isinstance
(
arr
,
int
)
else
arr
for
arr
in
args
]
ctypes_args
=
[
ctypes
.
c_void_p
(
arr
.
data_ptr
())
if
not
isinstance
(
arr
,
int
)
else
arr
for
arr
in
args
]
ctypes_args
.
append
(
ctypes
.
c_void_p
(
stream
))
self
.
lib
.
call
(
*
ctypes_args
)
...
...
@@ -349,9 +348,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
skip_tensor_validation: Whether to skip tensor attributes validation which
includes shape, dtype, device, etc.
"""
return
self
.
cython_wrapper
.
forward
([
*
args
],
stream
=
stream
,
skip_tensor_validation
=
skip_tensor_validation
)
return
self
.
cython_wrapper
.
forward
([
*
args
],
stream
=
stream
,
skip_tensor_validation
=
skip_tensor_validation
)
return
lambda_forward
...
...
tilelang/jit/adapter/libgen.py
View file @
29051439
...
...
@@ -55,6 +55,7 @@ class LibraryGenerator:
verbose
=
self
.
verbose
if
is_cuda_target
(
target
):
from
tilelang.env
import
CUTLASS_INCLUDE_DIR
src
=
tempfile
.
NamedTemporaryFile
(
mode
=
"w"
,
suffix
=
".cu"
,
delete
=
False
)
# noqa: SIM115
target_arch
=
get_target_arch
(
get_target_compute_version
(
target
))
libpath
=
src
.
name
.
replace
(
".cu"
,
".so"
)
...
...
@@ -65,15 +66,12 @@ class LibraryGenerator:
"TL_ENABLE_FAST_MATH"
,
"0.1.7"
,
)
enable_fast_math
=
not
self
.
pass_configs
.
get
(
PassConfigKey
.
TL_DISABLE_FAST_MATH
,
True
)
enable_fast_math
=
not
self
.
pass_configs
.
get
(
PassConfigKey
.
TL_DISABLE_FAST_MATH
,
True
)
else
:
enable_fast_math
=
self
.
pass_configs
.
get
(
PassConfigKey
.
TL_ENABLE_FAST_MATH
,
False
)
ptxas_usage_level
=
self
.
pass_configs
.
get
(
PassConfigKey
.
TL_PTXAS_REGISTER_USAGE_LEVEL
,
None
)
verbose_ptxas_output
=
self
.
pass_configs
.
get
(
PassConfigKey
.
TL_ENABLE_PTXAS_VERBOSE_OUTPUT
,
False
)
ptxas_usage_level
=
self
.
pass_configs
.
get
(
PassConfigKey
.
TL_PTXAS_REGISTER_USAGE_LEVEL
,
None
)
verbose_ptxas_output
=
self
.
pass_configs
.
get
(
PassConfigKey
.
TL_ENABLE_PTXAS_VERBOSE_OUTPUT
,
False
)
command
=
[
get_nvcc_compiler
(),
...
...
@@ -102,6 +100,7 @@ class LibraryGenerator:
elif
is_hip_target
(
target
):
from
tilelang.env
import
COMPOSABLE_KERNEL_INCLUDE_DIR
src
=
tempfile
.
NamedTemporaryFile
(
mode
=
"w"
,
suffix
=
".cpp"
,
delete
=
False
)
# noqa: SIM115
libpath
=
src
.
name
.
replace
(
".cpp"
,
".so"
)
rocm_path
=
find_rocm_path
()
...
...
@@ -119,6 +118,7 @@ class LibraryGenerator:
]
elif
is_cpu_target
(
target
):
from
tilelang.contrib.cc
import
get_cplus_compiler
src
=
tempfile
.
NamedTemporaryFile
(
mode
=
"w"
,
suffix
=
".cpp"
,
delete
=
False
)
# noqa: SIM115
libpath
=
src
.
name
.
replace
(
".cpp"
,
".so"
)
...
...
@@ -134,9 +134,7 @@ class LibraryGenerator:
]
if
self
.
compile_flags
:
command
+=
[
item
for
flag
in
self
.
compile_flags
for
item
in
flag
.
split
()
if
item
not
in
command
]
command
+=
[
item
for
flag
in
self
.
compile_flags
for
item
in
flag
.
split
()
if
item
not
in
command
]
command
+=
[
"-o"
,
libpath
]
...
...
@@ -151,8 +149,7 @@ class LibraryGenerator:
raise
RuntimeError
(
f
"Compile kernel failed because of
{
e
}
"
)
from
e
if
ret
.
returncode
!=
0
:
raise
RuntimeError
(
f
"Compilation Failed!
{
command
}
"
f
"
\n
{
self
.
lib_code
}
"
)
raise
RuntimeError
(
f
"Compilation Failed!
{
command
}
\n
{
self
.
lib_code
}
"
)
self
.
srcpath
=
src
.
name
self
.
libpath
=
libpath
...
...
tilelang/jit/adapter/nvrtc/__init__.py
View file @
29051439
...
...
@@ -5,22 +5,22 @@ This module provides runtime compilation support using NVIDIA's NVRTC API.
import
logging
__all__
=
[
'NVRTCKernelAdapter'
,
'TLNVRTCSourceWrapper'
,
'NVRTCLibraryGenerator'
,
'is_nvrtc_available'
,
'check_nvrtc_available'
]
__all__
=
[
"NVRTCKernelAdapter"
,
"TLNVRTCSourceWrapper"
,
"NVRTCLibraryGenerator"
,
"is_nvrtc_available"
,
"check_nvrtc_available"
]
logger
=
logging
.
getLogger
(
__name__
)
# Check if cuda-python is available
is_nvrtc_available
=
False
NVRTC_UNAVAILABLE_MESSAGE
=
(
"cuda-python is not available, NVRTC backend cannot be used. "
"Please install cuda-python via `pip install cuda-python` "
"if you want to use the NVRTC backend."
)
NVRTC_UNAVAILABLE_MESSAGE
=
(
"cuda-python is not available, NVRTC backend cannot be used. "
"Please install cuda-python via `pip install cuda-python` "
"if you want to use the NVRTC backend."
)
try
:
import
cuda.bindings.driver
as
cuda
# noqa: F401
import
cuda.bindings.nvrtc
as
nvrtc
# noqa: F401
is_nvrtc_available
=
True
except
ImportError
as
e
:
logger
.
debug
(
f
"cuda-python import failed:
{
e
}
"
)
...
...
tilelang/jit/adapter/nvrtc/adapter.py
View file @
29051439
...
...
@@ -27,18 +27,19 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
pymodule
=
None
kernels
=
{}
def
__init__
(
self
,
params
:
list
[
KernelParam
],
result_idx
:
list
[
int
],
target
:
str
|
Target
,
func_or_mod
:
tir
.
PrimFunc
|
tvm
.
IRModule
,
host_mod
:
tvm
.
IRModule
|
None
=
None
,
device_mod
:
tvm
.
IRModule
|
None
=
None
,
device_kernel_source
:
str
|
None
=
None
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
None
=
None
):
def
__init__
(
self
,
params
:
list
[
KernelParam
],
result_idx
:
list
[
int
],
target
:
str
|
Target
,
func_or_mod
:
tir
.
PrimFunc
|
tvm
.
IRModule
,
host_mod
:
tvm
.
IRModule
|
None
=
None
,
device_mod
:
tvm
.
IRModule
|
None
=
None
,
device_kernel_source
:
str
|
None
=
None
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
None
=
None
,
):
check_nvrtc_available
()
self
.
params
=
params
...
...
@@ -92,17 +93,19 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
self
.
_post_init
()
@
classmethod
def
from_database
(
cls
,
params
:
list
[
KernelParam
],
result_idx
:
list
[
int
],
target
:
str
,
func_or_mod
:
tir
.
PrimFunc
|
tvm
.
IRModule
,
host_kernel_source
:
str
,
device_kernel_source
:
str
,
kernel_lib_path
:
str
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
None
=
None
):
def
from_database
(
cls
,
params
:
list
[
KernelParam
],
result_idx
:
list
[
int
],
target
:
str
,
func_or_mod
:
tir
.
PrimFunc
|
tvm
.
IRModule
,
host_kernel_source
:
str
,
device_kernel_source
:
str
,
kernel_lib_path
:
str
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
None
=
None
,
):
adapter
=
cls
.
__new__
(
cls
)
adapter
.
params
=
params
adapter
.
result_idx
=
adapter
.
_legalize_result_idx
(
result_idx
)
...
...
@@ -183,8 +186,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
return
self
.
host_func
def
_forward_from_prebuild_lib
(
self
,
*
args
,
stream
:
int
|
None
=
None
):
"""Low-level function to call the compiled CUDA kernel.
"""
"""Low-level function to call the compiled CUDA kernel."""
return
self
.
pymodule
.
call
(
self
.
kernels
,
*
args
,
stream
=
stream
)
def
_wrap_forward_from_prebuild_lib
(
self
,
*
ins
:
list
[
torch
.
Tensor
],
stream
:
int
|
None
=
None
):
...
...
tilelang/jit/adapter/nvrtc/libgen.py
View file @
29051439
...
...
@@ -13,6 +13,7 @@ Key responsibilities:
- Load compiled cubin and extract kernel handles
- Manage library lifecycle (load/unload)
"""
from
__future__
import
annotations
import
importlib
import
logging
...
...
@@ -56,6 +57,7 @@ class NVRTCLibraryGenerator(LibraryGenerator):
culib: CUDA library handle (CUlibrary)
pymodule: Imported Python module containing call() function
"""
host_func
:
str
|
None
=
None
culib
:
cuda
.
CUlibrary
|
None
=
None
pymodule
:
ModuleType
|
None
=
None
...
...
@@ -131,10 +133,10 @@ class NVRTCLibraryGenerator(LibraryGenerator):
ctx
=
cuda
.
cuCtxGetCurrent
()[
1
]
if
cuda
.
cuCtxGetApiVersion
(
ctx
)[
0
]
!=
cuda
.
CUresult
.
CUDA_SUCCESS
:
import
torch
torch
.
cuda
.
synchronize
()
result
,
self
.
culib
=
cuda
.
cuLibraryLoadFromFile
(
bytes
(
lib_path
,
"utf-8"
),
[],
[],
0
,
[],
[],
0
)
result
,
self
.
culib
=
cuda
.
cuLibraryLoadFromFile
(
bytes
(
lib_path
,
"utf-8"
),
[],
[],
0
,
[],
[],
0
)
if
result
!=
cuda
.
CUresult
.
CUDA_SUCCESS
:
raise
RuntimeError
(
f
"Failed to load library:
{
lib_path
}
, error:
{
result
}
"
)
...
...
@@ -164,7 +166,8 @@ class NVRTCLibraryGenerator(LibraryGenerator):
target
=
self
.
target
verbose
=
self
.
verbose
if
is_cuda_target
(
target
):
from
tilelang.env
import
(
CUDA_HOME
,
CUTLASS_INCLUDE_DIR
,
TILELANG_TEMPLATE_PATH
)
from
tilelang.env
import
CUDA_HOME
,
CUTLASS_INCLUDE_DIR
,
TILELANG_TEMPLATE_PATH
src
=
tempfile
.
NamedTemporaryFile
(
mode
=
"w"
,
suffix
=
".cu"
,
delete
=
False
)
libpath
=
src
.
name
.
replace
(
".cu"
,
".cubin"
)
...
...
@@ -195,13 +198,9 @@ class NVRTCLibraryGenerator(LibraryGenerator):
f
"-D__CUDACC_VER_MAJOR__=
{
__CUDACC_VER_MAJOR__
}
"
,
]
if
self
.
compile_flags
:
options
+=
[
item
for
flag
in
self
.
compile_flags
for
item
in
flag
.
split
()
if
item
not
in
options
]
options
+=
[
item
for
flag
in
self
.
compile_flags
for
item
in
flag
.
split
()
if
item
not
in
options
]
cubin_bytes
=
compile_cuda
(
self
.
lib_code
,
target_format
=
"cubin"
,
options
=
options
,
verbose
=
verbose
)
cubin_bytes
=
compile_cuda
(
self
.
lib_code
,
target_format
=
"cubin"
,
options
=
options
,
verbose
=
verbose
)
with
open
(
libpath
,
"wb"
)
as
f
:
f
.
write
(
cubin_bytes
)
...
...
@@ -212,8 +211,7 @@ class NVRTCLibraryGenerator(LibraryGenerator):
self
.
libpath
=
libpath
self
.
pypath
=
src
.
name
.
replace
(
".cu"
,
".py"
)
if
self
.
host_func
is
None
:
raise
RuntimeError
(
"Host function is not set, please call update_host_func() first."
)
raise
RuntimeError
(
"Host function is not set, please call update_host_func() first."
)
with
open
(
self
.
pypath
,
"w"
)
as
f
:
f
.
write
(
self
.
host_func
)
else
:
...
...
tilelang/jit/adapter/nvrtc/wrapper.py
View file @
29051439
...
...
@@ -12,6 +12,7 @@ Key design:
- Dict-based deduplication ensures TMA descriptors created only once
- Generates pure Python using cuda.bindings.driver for zero C++ dependency
"""
from
__future__
import
annotations
from
typing
import
Any
,
ClassVar
...
...
@@ -21,8 +22,7 @@ from tvm.tir.stmt_functor import post_order_visit
from
tilelang
import
tvm
as
tvm
from
tilelang.jit.adapter.wrapper
import
TLCUDASourceWrapper
from
tilelang.jit.adapter.utils
import
(
match_declare_kernel
,
pythonic_expr
,
parse_function_call_args
,
parse_tma_descriptor_args
)
from
tilelang.jit.adapter.utils
import
match_declare_kernel
,
pythonic_expr
,
parse_function_call_args
,
parse_tma_descriptor_args
PREDEF_HOST_FUNC_PY
=
"""
from cuda.bindings.driver import (
...
...
@@ -235,13 +235,15 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
_generated_host_func
:
str
|
None
=
None
def
__init__
(
self
,
scheduled_ir_module
:
IRModule
,
source
:
str
,
target
:
Target
,
device_mod
:
IRModule
|
None
=
None
,
host_mod
:
IRModule
|
None
=
None
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
):
def
__init__
(
self
,
scheduled_ir_module
:
IRModule
,
source
:
str
,
target
:
Target
,
device_mod
:
IRModule
|
None
=
None
,
host_mod
:
IRModule
|
None
=
None
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
):
"""Initialize NVRTC wrapper with compiled IR modules.
Args:
...
...
@@ -303,15 +305,16 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
for
param
in
self
.
prim_func
.
params
:
if
param
in
self
.
prim_func
.
buffer_map
:
buffer
=
self
.
prim_func
.
buffer_map
[
param
]
function_args
.
append
({
"name"
:
buffer
.
data
.
name
,
"type"
:
"ctypes.c_void_p"
,
})
function_args
.
append
(
{
"name"
:
buffer
.
data
.
name
,
"type"
:
"ctypes.c_void_p"
,
}
)
elif
isinstance
(
param
,
tvm
.
tir
.
Var
):
function_args
.
append
({
"name"
:
param
.
name
,
"type"
:
self
.
_lookup_type
(
param
.
dtype
)})
else
:
raise
ValueError
(
f
"Parameter
{
param
}
is not in the buffer map of the primary function."
)
raise
ValueError
(
f
"Parameter
{
param
}
is not in the buffer map of the primary function."
)
# Add dynamic symbols as integer arguments
for
dyn_sym
,
dyn_sym_dtype
in
dynamic_symbolic_set
:
if
dyn_sym
not
in
[
arg
[
"name"
]
for
arg
in
function_args
]:
...
...
@@ -359,9 +362,9 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
return
(
f
"
{
name
}
.data_ptr()"
,
arg_type
)
return
(
name
,
arg_type
)
call_args
=
parse_function_call_args
(
declaration
,
function_args
,
function_params
,
desc_name_map
,
desc_name_var_map
,
transform_nvrtc_arg
)
call_args
=
parse_function_call_args
(
declaration
,
function_args
,
function_params
,
desc_name_map
,
desc_name_var_map
,
transform_nvrtc_arg
)
for
arg_name
,
arg_type
in
call_args
:
if
arg_type
==
"ctypes.c_void_p"
:
...
...
@@ -369,26 +372,28 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
break
# Store kernel info for second pass
kernel_info_list
.
append
({
'function_name'
:
function_name
,
'block_info'
:
block_info
,
'grid_info'
:
grid_info
,
'dynamic_smem_buf'
:
dynamic_smem_buf
,
'call_args'
:
call_args
,
'device_index'
:
device_index
,
})
kernel_info_list
.
append
(
{
"function_name"
:
function_name
,
"block_info"
:
block_info
,
"grid_info"
:
grid_info
,
"dynamic_smem_buf"
:
dynamic_smem_buf
,
"call_args"
:
call_args
,
"device_index"
:
device_index
,
}
)
# Generate TMA descriptor initialization code once for all kernels
kernel_launch_code
+=
self
.
generate_tma_descriptor_args
(
desc_name_map
,
desc_name_var_map
)
# Second pass: generate kernel launch code for each kernel
for
kernel_info
in
kernel_info_list
:
function_name
=
kernel_info
[
'
function_name
'
]
block_info
=
kernel_info
[
'
block_info
'
]
grid_info
=
kernel_info
[
'
grid_info
'
]
dynamic_smem_buf
=
kernel_info
[
'
dynamic_smem_buf
'
]
call_args
=
kernel_info
[
'
call_args
'
]
device_index
=
kernel_info
[
'
device_index
'
]
function_name
=
kernel_info
[
"
function_name
"
]
block_info
=
kernel_info
[
"
block_info
"
]
grid_info
=
kernel_info
[
"
grid_info
"
]
dynamic_smem_buf
=
kernel_info
[
"
dynamic_smem_buf
"
]
call_args
=
kernel_info
[
"
call_args
"
]
device_index
=
kernel_info
[
"
device_index
"
]
arg_names
=
", "
.
join
([
arg
[
0
]
for
arg
in
call_args
])
arg_types
=
", "
.
join
([
arg
[
1
]
for
arg
in
call_args
])
...
...
@@ -399,23 +404,26 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
kernel_launch_code
+=
init_l2_persistent_map
# Generate kernel launch code
kernel_launch_code
+=
KERNEL_LAUNCH_FUNC_PY
.
format
(
function_name
,
self
.
_pythonic_expr
(
grid_info
[
0
]),
self
.
_pythonic_expr
(
grid_info
[
1
]),
self
.
_pythonic_expr
(
grid_info
[
2
]),
self
.
_pythonic_expr
(
block_info
[
0
]),
self
.
_pythonic_expr
(
block_info
[
1
]),
self
.
_pythonic_expr
(
block_info
[
2
]),
smem_str
,
arg_names
,
arg_types
,
device_index
)
kernel_launch_code
+=
KERNEL_LAUNCH_FUNC_PY
.
format
(
function_name
,
self
.
_pythonic_expr
(
grid_info
[
0
]),
self
.
_pythonic_expr
(
grid_info
[
1
]),
self
.
_pythonic_expr
(
grid_info
[
2
]),
self
.
_pythonic_expr
(
block_info
[
0
]),
self
.
_pythonic_expr
(
block_info
[
1
]),
self
.
_pythonic_expr
(
block_info
[
2
]),
smem_str
,
arg_names
,
arg_types
,
device_index
,
)
# Reset L2 persistent map after all kernel execution
if
has_l2_persistent_map
:
kernel_launch_code
+=
L2_PERSISTENT_MAP_RESET_HANDLE_PY
# Wrap the kernel dispatch logic in an external C function
host_func
=
PREDEF_HOST_FUNC_PY
.
format
(
repr
(
list
(
function_informations
.
keys
())),
def_args
,
kernel_launch_code
)
host_func
=
PREDEF_HOST_FUNC_PY
.
format
(
repr
(
list
(
function_informations
.
keys
())),
def_args
,
kernel_launch_code
)
return
host_func
def
generate_l2_persistent_map
(
self
,
function_name
:
str
)
->
str
:
...
...
@@ -434,23 +442,21 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
if
function_name
not
in
self
.
l2_persistent_map
:
return
""
init_l2_persistent_map
=
""
for
buffer_name
,
(
hit_ratio
,
size_in_bytes
)
in
self
.
l2_persistent_map
[
function_name
].
items
():
for
buffer_name
,
(
hit_ratio
,
size_in_bytes
)
in
self
.
l2_persistent_map
[
function_name
].
items
():
# Get persisting_l2_cache_max_size
from
tilelang.carver.arch.driver
import
get_persisting_l2_cache_max_size
persisting_l2_cache_max_size
=
get_persisting_l2_cache_max_size
()
try
:
num_bytes
=
min
(
size_in_bytes
,
persisting_l2_cache_max_size
)
except
TypeError
:
# as size_in_bytes may be a symbolic expression
num_bytes
=
persisting_l2_cache_max_size
init_l2_persistent_map
+=
L2_PERSISTENT_MAP_INIT_FUNC_PY
.
format
(
buffer_name
,
float
(
hit_ratio
),
self
.
_pythonic_expr
(
num_bytes
))
init_l2_persistent_map
+=
L2_PERSISTENT_MAP_INIT_FUNC_PY
.
format
(
buffer_name
,
float
(
hit_ratio
),
self
.
_pythonic_expr
(
num_bytes
))
return
init_l2_persistent_map
def
generate_tma_descriptor_args
(
self
,
desc_name_map
:
dict
[
str
,
str
],
desc_name_var_map
:
dict
[
str
,
tvm
.
tir
.
Var
])
->
str
:
def
generate_tma_descriptor_args
(
self
,
desc_name_map
:
dict
[
str
,
str
],
desc_name_var_map
:
dict
[
str
,
tvm
.
tir
.
Var
])
->
str
:
"""Generate Python code to initialize TMA descriptors.
TMA (Tensor Memory Accelerator) descriptors are opaque CUDA objects
...
...
@@ -470,28 +476,43 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
return
tma_descriptor_init
# Parse TMA descriptor arguments using the common utility
parsed_params
=
parse_tma_descriptor_args
(
self
.
tma_descriptor_args
,
desc_name_map
,
desc_name_var_map
,
self
.
_pythonic_expr
)
parsed_params
=
parse_tma_descriptor_args
(
self
.
tma_descriptor_args
,
desc_name_map
,
desc_name_var_map
,
self
.
_pythonic_expr
)
# Generate Python code from parsed parameters
for
params
in
parsed_params
:
if
not
params
.
is_img2col
:
tma_descriptor_init
+=
TMA_DESC_INIT_FUNC_PY
.
format
(
params
.
handle_name
,
params
.
dtype
,
params
.
tensor_rank
,
params
.
global_address
,
params
.
handle_name
,
params
.
dtype
,
params
.
tensor_rank
,
params
.
global_address
,
", "
.
join
(
map
(
lambda
x
:
f
"cuuint64_t(
{
x
}
)"
,
params
.
global_dim
)),
", "
.
join
(
map
(
lambda
x
:
f
"cuuint64_t(
{
x
}
)"
,
params
.
global_stride
)),
", "
.
join
(
map
(
lambda
x
:
f
"cuuint32_t(
{
x
}
)"
,
params
.
box_dim
)),
", "
.
join
(
map
(
lambda
x
:
f
"cuuint32_t(
{
x
}
)"
,
params
.
element_strides
)),
params
.
interleave
,
params
.
swizzle
,
params
.
l2_promotion
,
params
.
oob_fill
)
params
.
interleave
,
params
.
swizzle
,
params
.
l2_promotion
,
params
.
oob_fill
,
)
else
:
tma_descriptor_init
+=
TMA_IM2COL_DESC_INIT_FUNC_PY
.
format
(
params
.
handle_name
,
params
.
dtype
,
params
.
tensor_rank
,
params
.
global_address
,
params
.
handle_name
,
params
.
dtype
,
params
.
tensor_rank
,
params
.
global_address
,
", "
.
join
(
map
(
lambda
x
:
f
"cuuint64_t(
{
x
}
)"
,
params
.
global_dim
)),
", "
.
join
(
map
(
lambda
x
:
f
"cuuint64_t(
{
x
}
)"
,
params
.
global_stride
)),
", "
.
join
(
map
(
lambda
x
:
f
"cuuint32_t(
{
x
}
)"
,
params
.
element_strides
)),
", "
.
join
(
params
.
lower_corner
),
", "
.
join
(
params
.
upper_corner
),
params
.
smem_box_channel
,
params
.
smem_box_pixel
,
params
.
interleave
,
params
.
swizzle
,
params
.
l2_promotion
,
params
.
oob_fill
)
", "
.
join
(
map
(
lambda
x
:
f
"cuuint32_t(
{
x
}
)"
,
params
.
element_strides
)),
", "
.
join
(
params
.
lower_corner
),
", "
.
join
(
params
.
upper_corner
),
params
.
smem_box_channel
,
params
.
smem_box_pixel
,
params
.
interleave
,
params
.
swizzle
,
params
.
l2_promotion
,
params
.
oob_fill
,
)
return
tma_descriptor_init
...
...
@@ -527,17 +548,14 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
def
visitor
(
node
,
fn
=
function_name
,
param_cnt
=
kernel_params_cnt
):
nonlocal
function_params
if
isinstance
(
node
,
tvm
.
tir
.
Call
):
if
not
(
hasattr
(
node
,
"op"
)
and
node
.
op
==
tvm
.
ir
.
Op
.
get
(
"tir.tvm_call_packed"
)):
if
not
(
hasattr
(
node
,
"op"
)
and
node
.
op
==
tvm
.
ir
.
Op
.
get
(
"tir.tvm_call_packed"
)):
return
args
=
node
.
args
if
not
args
or
args
[
0
]
!=
fn
:
return
if
len
(
args
)
<
1
+
param_cnt
:
raise
AssertionError
(
"tvm_call_packed should have at least 1 argument and match device function parameters"
)
function_params
=
args
[
1
:
1
+
param_cnt
]
raise
AssertionError
(
"tvm_call_packed should have at least 1 argument and match device function parameters"
)
function_params
=
args
[
1
:
1
+
param_cnt
]
post_order_visit
(
self
.
host_func
.
body
,
visitor
)
assert
function_params
is
not
None
,
"function_params should not be None"
...
...
tilelang/jit/adapter/torch/__init__.py
View file @
29051439
from
.metal
import
MetalKernelAdapter
__all__
=
[
'
MetalKernelAdapter
'
]
__all__
=
[
"
MetalKernelAdapter
"
]
tilelang/jit/adapter/torch/metal.py
View file @
29051439
...
...
@@ -12,7 +12,6 @@ from tilelang.engine.param import KernelParam
class
MetalKernelAdapter
(
BaseKernelAdapter
):
def
__init__
(
self
,
params
:
list
[
KernelParam
],
...
...
@@ -28,10 +27,10 @@ class MetalKernelAdapter(BaseKernelAdapter):
):
self
.
kernel_global_source
=
kernel_global_source
if
isinstance
(
func_or_mod
,
tir
.
PrimFunc
):
func_name
=
func_or_mod
.
attrs
[
'
global_symbol
'
]
func_name
=
func_or_mod
.
attrs
[
"
global_symbol
"
]
else
:
func_name
=
func_or_mod
.
__name__
self
.
kernel_name
=
func_name
+
'
_kernel
'
self
.
kernel_name
=
func_name
+
"
_kernel
"
self
.
verbose
=
verbose
self
.
block_info
=
[
1
,
1
,
1
]
...
...
@@ -39,7 +38,7 @@ class MetalKernelAdapter(BaseKernelAdapter):
for
var
,
func
in
device_mod
.
functions
.
items
():
assert
var
.
name_hint
==
self
.
kernel_name
thread_extent
=
func
.
attrs
[
'
thread_extent
'
]
thread_extent
=
func
.
attrs
[
"
thread_extent
"
]
for
tag
,
extent
in
thread_extent
.
items
():
if
"threadIdx"
in
tag
:
self
.
block_info
[
"xyz"
.
index
(
tag
[
-
1
])]
=
extent
...
...
@@ -47,7 +46,7 @@ class MetalKernelAdapter(BaseKernelAdapter):
self
.
grid_info
[
"xyz"
.
index
(
tag
[
-
1
])]
=
extent
break
else
:
raise
AssertionError
(
f
'
no kernel with name
{
func_name
}
'
)
raise
AssertionError
(
f
"
no kernel with name
{
func_name
}
"
)
# print(self.block_info, self.grid_info)
super
().
__init__
(
func_or_mod
,
result_idx
=
result_idx
,
params
=
params
)
...
...
@@ -55,15 +54,12 @@ class MetalKernelAdapter(BaseKernelAdapter):
_kernel
=
None
def
_convert_torch_func
(
self
)
->
Callable
:
if
self
.
_kernel
is
None
:
_kernel
=
getattr
(
torch
.
mps
.
compile_shader
(
self
.
kernel_global_source
),
self
.
kernel_name
)
_threads
=
[
x
*
y
for
(
x
,
y
)
in
zip
(
self
.
block_info
,
self
.
grid_info
)]
@
wraps
(
_kernel
)
def
launcher
(
*
args
:
torch
.
Tensor
):
return
_kernel
(
*
args
,
threads
=
_threads
,
...
...
tilelang/jit/adapter/tvm_ffi.py
View file @
29051439
...
...
@@ -5,6 +5,7 @@ via light-weight callables so that, when the wrapped function is invoked,
the execution observes the same stream context as the active Torch code.
On non-CUDA builds, the stream/device fall back to 0/CPU semantics.
"""
from
__future__
import
annotations
from
typing
import
Callable
,
Any
...
...
@@ -31,6 +32,7 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
- The stream pointer returned is a raw CUDA stream handle compatible with
TVM's device API; on CPU or when CUDA is unavailable, we return 0.
"""
# Class attributes to store compiled kernel information
target
:
str
|
Target
=
"cuda"
ir_module
:
tvm
.
IRModule
|
None
=
None
...
...
@@ -51,19 +53,21 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
dynamic_symbolic_map
:
dict
[
tir
.
Var
,
tuple
[
int
,
int
,
int
]]
|
None
=
None
# Stream/device functors are inherited from BaseKernelAdapter
def
__init__
(
self
,
params
:
list
[
KernelParam
],
result_idx
:
list
[
int
],
target
:
str
|
Target
,
func_or_mod
:
tir
.
PrimFunc
|
tvm
.
IRModule
,
host_mod
:
tvm
.
IRModule
|
None
=
None
,
device_mod
:
tvm
.
IRModule
|
None
=
None
,
rt_mod
:
tvm
.
runtime
.
Module
|
None
=
None
,
host_kernel_source
:
str
|
None
=
None
,
device_kernel_source
:
str
|
None
=
None
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
None
=
None
):
def
__init__
(
self
,
params
:
list
[
KernelParam
],
result_idx
:
list
[
int
],
target
:
str
|
Target
,
func_or_mod
:
tir
.
PrimFunc
|
tvm
.
IRModule
,
host_mod
:
tvm
.
IRModule
|
None
=
None
,
device_mod
:
tvm
.
IRModule
|
None
=
None
,
rt_mod
:
tvm
.
runtime
.
Module
|
None
=
None
,
host_kernel_source
:
str
|
None
=
None
,
device_kernel_source
:
str
|
None
=
None
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
None
=
None
,
):
"""Initialize the adapter with the given TIR function or module.
Args:
...
...
@@ -113,15 +117,13 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
if
param
in
buffer_map
:
buffer
=
buffer_map
[
param
]
for
j
,
shape
in
enumerate
(
buffer
.
shape
):
if
(
isinstance
(
shape
,
tir
.
Var
)
and
(
shape
not
in
dynamic_symbolic_map
)
and
(
shape
not
in
params
)):
if
isinstance
(
shape
,
tir
.
Var
)
and
(
shape
not
in
dynamic_symbolic_map
)
and
(
shape
not
in
params
):
dynamic_symbolic_map
[
shape
]
=
(
0
,
i
,
j
)
for
i
,
param
in
enumerate
(
params
):
if
param
in
buffer_map
:
buffer
=
buffer_map
[
param
]
for
j
,
stride
in
enumerate
(
buffer
.
strides
):
if
(
isinstance
(
stride
,
tir
.
Var
)
and
(
stride
not
in
dynamic_symbolic_map
)
and
(
stride
not
in
params
)):
if
isinstance
(
stride
,
tir
.
Var
)
and
(
stride
not
in
dynamic_symbolic_map
)
and
(
stride
not
in
params
):
dynamic_symbolic_map
[
stride
]
=
(
1
,
i
,
j
)
return
dynamic_symbolic_map
...
...
@@ -197,8 +199,7 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
# Validate input count strictly
expected_inputs
=
len
(
self
.
params
)
-
len
(
self
.
result_idx
)
if
len
(
inputs
)
!=
expected_inputs
:
raise
ValueError
(
f
"Kernel expected
{
expected_inputs
}
inputs, but
{
len
(
inputs
)
}
are provided."
)
raise
ValueError
(
f
"Kernel expected
{
expected_inputs
}
inputs, but
{
len
(
inputs
)
}
are provided."
)
# Resolve the device used for outputs. Prefer the first tensor input's device
# if available, otherwise use PyTorch's current device.
...
...
@@ -217,17 +218,14 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
for
s
in
param_shapes
[
i
]:
if
isinstance
(
s
,
tir
.
Var
):
for
key
in
dynamic_symbolic_map
:
if
(
str
(
s
)
==
str
(
key
)):
ref_id
,
ref_tensor_idx
,
ref_shape_idx
=
dynamic_symbolic_map
[
key
]
if
str
(
s
)
==
str
(
key
):
ref_id
,
ref_tensor_idx
,
ref_shape_idx
=
dynamic_symbolic_map
[
key
]
if
ref_id
==
2
:
shape
.
append
(
inputs
[
ref_tensor_idx
])
elif
ref_id
==
0
:
shape
.
append
(
tensor_list
[
ref_tensor_idx
].
shape
[
ref_shape_idx
])
shape
.
append
(
tensor_list
[
ref_tensor_idx
].
shape
[
ref_shape_idx
])
elif
ref_id
==
1
:
shape
.
append
(
tensor_list
[
ref_tensor_idx
].
stride
()[
ref_shape_idx
])
shape
.
append
(
tensor_list
[
ref_tensor_idx
].
stride
()[
ref_shape_idx
])
else
:
# Already converted to Python int during initialization
shape
.
append
(
s
)
...
...
@@ -235,11 +233,11 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
out_device
=
current_device_functor
()
if
len
(
shape
)
==
0
:
param_name
=
self
.
params
[
i
].
name
if
hasattr
(
self
.
params
[
i
],
'name'
)
else
f
'parameter_
{
i
}
'
param_name
=
self
.
params
[
i
].
name
if
hasattr
(
self
.
params
[
i
],
"name"
)
else
f
"parameter_
{
i
}
"
raise
ValueError
(
f
"Cannot create output tensor (name=
{
param_name
}
) - 0-dimensional tensors are not supported. "
f
"Expected shape:
{
shape
}
"
)
f
"Expected shape:
{
shape
}
"
)
tensor
=
torch
.
empty
(
*
shape
,
dtype
=
dtype
,
device
=
out_device
)
else
:
tensor
=
inputs
[
ins_idx
]
...
...
@@ -256,17 +254,19 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
return
func
@
classmethod
def
from_database
(
cls
,
params
:
list
[
TensorType
],
result_idx
:
list
[
int
],
target
:
str
,
func_or_mod
:
tir
.
PrimFunc
|
tvm
.
IRModule
,
host_kernel_source
:
str
,
device_kernel_source
:
str
,
kernel_lib_path
:
str
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
None
=
None
):
def
from_database
(
cls
,
params
:
list
[
TensorType
],
result_idx
:
list
[
int
],
target
:
str
,
func_or_mod
:
tir
.
PrimFunc
|
tvm
.
IRModule
,
host_kernel_source
:
str
,
device_kernel_source
:
str
,
kernel_lib_path
:
str
,
verbose
:
bool
=
False
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
None
=
None
,
):
adapter
=
cls
.
__new__
(
cls
)
adapter
.
params
=
params
adapter
.
result_idx
=
adapter
.
_legalize_result_idx
(
result_idx
)
...
...
tilelang/jit/adapter/utils.py
View file @
29051439
...
...
@@ -70,7 +70,6 @@ def get_annotated_mod(
target_host
:
str
|
Target
|
None
=
None
,
model_type
:
Literal
[
"device"
,
"host"
,
"all"
]
=
"all"
,
)
->
IRModule
|
tuple
[
IRModule
,
IRModule
]:
# Validate model_type early
if
model_type
not
in
{
"device"
,
"host"
,
"all"
}:
raise
ValueError
(
f
"Invalid model type:
{
model_type
}
"
)
...
...
@@ -95,21 +94,15 @@ def get_annotated_mod(
# Define dispatch dictionary for different model types
dispatch
=
{
"device"
:
lambda
m
:
tir
.
transform
.
Filter
(
_is_device_call
)(
m
),
"host"
:
lambda
m
:
tir
.
transform
.
Filter
(
_is_host_call
)(
m
),
"all"
:
lambda
m
:
(
tir
.
transform
.
Filter
(
_is_device_call
)(
m
),
tir
.
transform
.
Filter
(
_is_host_call
)
(
m
)),
"device"
:
lambda
m
:
tir
.
transform
.
Filter
(
_is_device_call
)(
m
),
"host"
:
lambda
m
:
tir
.
transform
.
Filter
(
_is_host_call
)(
m
),
"all"
:
lambda
m
:
(
tir
.
transform
.
Filter
(
_is_device_call
)(
m
),
tir
.
transform
.
Filter
(
_is_host_call
)(
m
)),
}
return
dispatch
[
model_type
](
mod
)
def
pythonic_expr
(
expr
:
tvm
.
tir
.
PrimExpr
,
dtype_map
:
dict
[
str
,
str
]
|
None
=
None
,
ignore_cast
:
bool
=
False
)
->
str
:
def
pythonic_expr
(
expr
:
tvm
.
tir
.
PrimExpr
,
dtype_map
:
dict
[
str
,
str
]
|
None
=
None
,
ignore_cast
:
bool
=
False
)
->
str
:
"""
Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.
...
...
@@ -168,9 +161,23 @@ def pythonic_expr(expr: tvm.tir.PrimExpr,
s
=
f
"(
{
type_str
}
)
{
value_str
}
"
p
=
PRECEDENCE
.
get
(
type
(
node
),
ATOMIC_PRECEDENCE
)
elif
isinstance
(
node
,
(
tvm
.
tir
.
Mul
,
tvm
.
tir
.
FloorDiv
,
tvm
.
tir
.
Add
,
tvm
.
tir
.
Sub
,
tvm
.
tir
.
FloorMod
,
tvm
.
tir
.
LT
,
tvm
.
tir
.
LE
,
tvm
.
tir
.
GT
,
tvm
.
tir
.
GE
,
tvm
.
tir
.
EQ
,
tvm
.
tir
.
NE
,
tvm
.
tir
.
And
,
tvm
.
tir
.
Or
)):
node
,
(
tvm
.
tir
.
Mul
,
tvm
.
tir
.
FloorDiv
,
tvm
.
tir
.
Add
,
tvm
.
tir
.
Sub
,
tvm
.
tir
.
FloorMod
,
tvm
.
tir
.
LT
,
tvm
.
tir
.
LE
,
tvm
.
tir
.
GT
,
tvm
.
tir
.
GE
,
tvm
.
tir
.
EQ
,
tvm
.
tir
.
NE
,
tvm
.
tir
.
And
,
tvm
.
tir
.
Or
,
),
):
op_map
=
{
tvm
.
tir
.
Mul
:
"*"
,
tvm
.
tir
.
FloorDiv
:
"/"
,
...
...
@@ -222,10 +229,7 @@ def pythonic_expr(expr: tvm.tir.PrimExpr,
return
next
(
iter
(
node_to_result_map
[
expr
]),
""
)
def
maybe_desc_name
(
name
:
str
,
matches
:
list
[
str
],
i
:
int
,
desc_name_map
:
dict
[
str
,
str
]
|
None
=
None
)
->
bool
:
def
maybe_desc_name
(
name
:
str
,
matches
:
list
[
str
],
i
:
int
,
desc_name_map
:
dict
[
str
,
str
]
|
None
=
None
)
->
bool
:
"""
Check if a parameter name corresponds to a TMA descriptor.
...
...
@@ -290,8 +294,7 @@ def parse_function_call_args(
else
:
call_args
.
append
(
match
)
if
desc_name_var_map
is
not
None
and
function_params
is
not
None
:
assert
len
(
call_args
)
<=
len
(
function_params
),
\
f
"Too many arguments:
{
len
(
call_args
)
}
>
{
len
(
function_params
)
}
"
assert
len
(
call_args
)
<=
len
(
function_params
),
f
"Too many arguments:
{
len
(
call_args
)
}
>
{
len
(
function_params
)
}
"
desc_name_var_map
[
match
]
=
function_params
[
len
(
call_args
)
-
1
]
return
call_args
...
...
@@ -300,12 +303,7 @@ def parse_function_call_args(
class
TMADescriptorParams
:
"""Parsed TMA descriptor parameters."""
def
__init__
(
self
,
handle_name
:
str
,
dtype
:
str
,
tensor_rank
:
int
,
global_address
:
Any
,
is_img2col
:
bool
=
False
):
def
__init__
(
self
,
handle_name
:
str
,
dtype
:
str
,
tensor_rank
:
int
,
global_address
:
Any
,
is_img2col
:
bool
=
False
):
self
.
handle_name
=
handle_name
self
.
dtype
=
dtype
self
.
tensor_rank
=
tensor_rank
...
...
@@ -355,22 +353,19 @@ def parse_tma_descriptor_args(
results
=
[]
for
handle_name
,
_
in
desc_name_map
.
items
():
assert
handle_name
in
desc_name_var_map
,
\
f
"Handle name
{
handle_name
}
not found in desc_name_var_map"
assert
handle_name
in
desc_name_var_map
,
f
"Handle name
{
handle_name
}
not found in desc_name_var_map"
desc_var
=
desc_name_var_map
[
handle_name
]
assert
desc_var
in
tma_descriptor_args
,
\
f
"TMA descriptor
{
desc_var
}
not found in
{
tma_descriptor_args
}
"
assert
desc_var
in
tma_descriptor_args
,
f
"TMA descriptor
{
desc_var
}
not found in
{
tma_descriptor_args
}
"
args
=
tma_descriptor_args
[
desc_var
]
# Skip __tvm_tensormap_create_tiled and second element (like CUDA version)
if
len
(
args
)
<
3
:
raise
ValueError
(
f
"TMA descriptor args too short:
{
len
(
args
)
}
elements, expected at least 3"
)
raise
ValueError
(
f
"TMA descriptor args too short:
{
len
(
args
)
}
elements, expected at least 3"
)
tma_create_str
,
_
,
dtype
,
tensor_rank
,
global_address
,
*
remaining_args
=
args
is_img2col
=
(
tma_create_str
.
value
==
"__tvm_tensormap_create_im2col"
)
is_img2col
=
tma_create_str
.
value
==
"__tvm_tensormap_create_im2col"
# Convert basic fields
dtype
=
pythonic_expr_func
(
dtype
)
...
...
@@ -386,60 +381,45 @@ def parse_tma_descriptor_args(
# Tiled mode
expected_args_len
=
4
*
tensor_rank
+
4
if
len
(
remaining_args
)
<
expected_args_len
:
raise
ValueError
(
f
"Insufficient remaining args: got
{
len
(
remaining_args
)
}
, "
f
"expected
{
expected_args_len
}
for tensor_rank
{
tensor_rank
}
"
)
raise
ValueError
(
f
"Insufficient remaining args: got
{
len
(
remaining_args
)
}
, expected
{
expected_args_len
}
for tensor_rank
{
tensor_rank
}
"
)
# Extract dimensions and strides
params
.
global_dim
=
[
pythonic_expr_func
(
i
)
for
i
in
remaining_args
[:
tensor_rank
]]
params
.
global_stride
=
[
pythonic_expr_func
(
i
)
for
i
in
remaining_args
[
tensor_rank
:
2
*
tensor_rank
]
]
params
.
box_dim
=
[
pythonic_expr_func
(
i
)
for
i
in
remaining_args
[
2
*
tensor_rank
:
3
*
tensor_rank
]
]
params
.
element_strides
=
[
pythonic_expr_func
(
i
)
for
i
in
remaining_args
[
3
*
tensor_rank
:
4
*
tensor_rank
]
]
params
.
global_stride
=
[
pythonic_expr_func
(
i
)
for
i
in
remaining_args
[
tensor_rank
:
2
*
tensor_rank
]]
params
.
box_dim
=
[
pythonic_expr_func
(
i
)
for
i
in
remaining_args
[
2
*
tensor_rank
:
3
*
tensor_rank
]]
params
.
element_strides
=
[
pythonic_expr_func
(
i
)
for
i
in
remaining_args
[
3
*
tensor_rank
:
4
*
tensor_rank
]]
# Extract remaining parameters
try
:
interleave
,
swizzle
,
l2_promotion
,
oob_fill
=
remaining_args
[
4
*
tensor_rank
:
4
*
tensor_rank
+
4
]
interleave
,
swizzle
,
l2_promotion
,
oob_fill
=
remaining_args
[
4
*
tensor_rank
:
4
*
tensor_rank
+
4
]
params
.
interleave
=
pythonic_expr_func
(
interleave
)
params
.
swizzle
=
pythonic_expr_func
(
swizzle
)
params
.
l2_promotion
=
pythonic_expr_func
(
l2_promotion
)
params
.
oob_fill
=
pythonic_expr_func
(
oob_fill
)
except
ValueError
as
e
:
raise
ValueError
(
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
)
from
e
raise
ValueError
(
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
)
from
e
else
:
# Im2col mode
expected_args_len
=
5
*
tensor_rank
+
2
if
len
(
remaining_args
)
<
expected_args_len
:
raise
ValueError
(
f
"Insufficient remaining args: got
{
len
(
remaining_args
)
}
, "
f
"expected
{
expected_args_len
}
for tensor_rank
{
tensor_rank
}
"
)
raise
ValueError
(
f
"Insufficient remaining args: got
{
len
(
remaining_args
)
}
, expected
{
expected_args_len
}
for tensor_rank
{
tensor_rank
}
"
)
# Extract dimensions and strides
params
.
global_dim
=
[
pythonic_expr_func
(
i
)
for
i
in
remaining_args
[:
tensor_rank
]]
params
.
global_stride
=
[
pythonic_expr_func
(
i
)
for
i
in
remaining_args
[
tensor_rank
:
2
*
tensor_rank
]
]
params
.
element_strides
=
[
pythonic_expr_func
(
i
)
for
i
in
remaining_args
[
2
*
tensor_rank
:
3
*
tensor_rank
]
]
params
.
lower_corner
=
[
pythonic_expr_func
(
i
)
for
i
in
remaining_args
[
3
*
tensor_rank
:
4
*
tensor_rank
-
2
]
]
params
.
upper_corner
=
[
pythonic_expr_func
(
i
)
for
i
in
remaining_args
[
4
*
tensor_rank
-
2
:
5
*
tensor_rank
-
4
]
]
params
.
global_stride
=
[
pythonic_expr_func
(
i
)
for
i
in
remaining_args
[
tensor_rank
:
2
*
tensor_rank
]]
params
.
element_strides
=
[
pythonic_expr_func
(
i
)
for
i
in
remaining_args
[
2
*
tensor_rank
:
3
*
tensor_rank
]]
params
.
lower_corner
=
[
pythonic_expr_func
(
i
)
for
i
in
remaining_args
[
3
*
tensor_rank
:
4
*
tensor_rank
-
2
]]
params
.
upper_corner
=
[
pythonic_expr_func
(
i
)
for
i
in
remaining_args
[
4
*
tensor_rank
-
2
:
5
*
tensor_rank
-
4
]]
# Extract remaining parameters
try
:
smem_box_pixel
,
smem_box_channel
,
interleave
,
swizzle
,
l2_promotion
,
oob_fill
=
\
remaining_args
[
5
*
tensor_rank
-
4
:
5
*
tensor_rank
+
2
]
smem_box_pixel
,
smem_box_channel
,
interleave
,
swizzle
,
l2_promotion
,
oob_fill
=
remaining_args
[
5
*
tensor_rank
-
4
:
5
*
tensor_rank
+
2
]
params
.
smem_box_pixel
=
pythonic_expr_func
(
smem_box_pixel
)
params
.
smem_box_channel
=
pythonic_expr_func
(
smem_box_channel
)
params
.
interleave
=
pythonic_expr_func
(
interleave
)
...
...
tilelang/jit/adapter/wrapper.py
View file @
29051439
...
...
@@ -4,9 +4,18 @@ from tilelang import tvm as tvm
from
typing
import
Any
from
tvm
import
IRModule
from
tvm.target
import
Target
from
.utils
import
(
is_metal_target
,
match_declare_kernel
,
match_declare_kernel_cpu
,
is_cuda_target
,
is_hip_target
,
is_cpu_target
,
get_annotated_mod
,
pythonic_expr
,
parse_function_call_args
,
parse_tma_descriptor_args
)
from
.utils
import
(
is_metal_target
,
match_declare_kernel
,
match_declare_kernel_cpu
,
is_cuda_target
,
is_hip_target
,
is_cpu_target
,
get_annotated_mod
,
pythonic_expr
,
parse_function_call_args
,
parse_tma_descriptor_args
,
)
import
re
import
logging
import
textwrap
...
...
@@ -129,7 +138,6 @@ TMA_IM2COL_DESC_INIT_FUNC = """
class
BaseWrapper
(
ABC
):
@
abstractmethod
def
wrap
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
...
...
@@ -163,13 +171,15 @@ class TLCUDASourceWrapper:
host_mod
:
IRModule
|
None
=
None
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
def
__init__
(
self
,
scheduled_ir_module
:
IRModule
,
source
:
str
,
target
:
Target
,
device_mod
:
IRModule
|
None
=
None
,
host_mod
:
IRModule
|
None
=
None
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
):
def
__init__
(
self
,
scheduled_ir_module
:
IRModule
,
source
:
str
,
target
:
Target
,
device_mod
:
IRModule
|
None
=
None
,
host_mod
:
IRModule
|
None
=
None
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
):
self
.
mod
=
scheduled_ir_module
self
.
target
=
target
self
.
source
=
source
...
...
@@ -211,15 +221,16 @@ class TLCUDASourceWrapper:
for
param
in
self
.
prim_func
.
params
:
if
param
in
self
.
prim_func
.
buffer_map
:
buffer
=
self
.
prim_func
.
buffer_map
[
param
]
function_args
.
append
({
"name"
:
buffer
.
data
.
name
,
"type"
:
self
.
_lookup_type
(
buffer
.
dtype
)
+
"* __restrict__"
,
})
function_args
.
append
(
{
"name"
:
buffer
.
data
.
name
,
"type"
:
self
.
_lookup_type
(
buffer
.
dtype
)
+
"* __restrict__"
,
}
)
elif
isinstance
(
param
,
tvm
.
tir
.
Var
):
function_args
.
append
({
"name"
:
param
.
name
,
"type"
:
self
.
_lookup_type
(
param
.
dtype
)})
else
:
raise
ValueError
(
f
"Parameter
{
param
}
is not in the buffer map of the primary function."
)
raise
ValueError
(
f
"Parameter
{
param
}
is not in the buffer map of the primary function."
)
# Add dynamic symbols as integer arguments
for
dyn_sym
,
dyn_sym_dtype
in
dynamic_symbolic_set
:
if
dyn_sym
not
in
[
arg
[
"name"
]
for
arg
in
function_args
]:
...
...
@@ -256,38 +267,40 @@ class TLCUDASourceWrapper:
# Identify the start of the function body to insert arguments
index
=
code
.
index
(
"{"
,
index
)
block_str
=
f
"dim3(
{
self
.
_pythonic_expr
(
block_info
[
0
])
}
,
{
self
.
_pythonic_expr
(
block_info
[
1
])
}
,
{
self
.
_pythonic_expr
(
block_info
[
2
])
}
)"
grid_str
=
f
"dim3(
{
self
.
_pythonic_expr
(
grid_info
[
0
])
}
,
{
self
.
_pythonic_expr
(
grid_info
[
1
])
}
,
{
self
.
_pythonic_expr
(
grid_info
[
2
])
}
)"
block_str
=
(
f
"dim3(
{
self
.
_pythonic_expr
(
block_info
[
0
])
}
,
{
self
.
_pythonic_expr
(
block_info
[
1
])
}
,
{
self
.
_pythonic_expr
(
block_info
[
2
])
}
)"
)
grid_str
=
(
f
"dim3(
{
self
.
_pythonic_expr
(
grid_info
[
0
])
}
,
{
self
.
_pythonic_expr
(
grid_info
[
1
])
}
,
{
self
.
_pythonic_expr
(
grid_info
[
2
])
}
)"
)
smem_str
=
0
if
dynamic_smem_buf
is
None
else
dynamic_smem_buf
init_l2_persistent_map
=
self
.
generate_l2_persistent_map
(
function_name
)
kernel_launch_code
+=
init_l2_persistent_map
if
self
.
use_cooperative_groups
[
function_name
]:
args_list
=
parse_function_call_args
(
declaration
,
function_args
,
function_params
,
desc_name_map
,
desc_name_var_map
)
assert
len
(
function_params
)
==
len
(
args_list
),
f
"Function
{
function_name
}
has
{
len
(
function_params
)
}
parameters, but
{
len
(
args_list
)
}
arguments"
args_list
=
parse_function_call_args
(
declaration
,
function_args
,
function_params
,
desc_name_map
,
desc_name_var_map
)
assert
len
(
function_params
)
==
len
(
args_list
),
(
f
"Function
{
function_name
}
has
{
len
(
function_params
)
}
parameters, but
{
len
(
args_list
)
}
arguments"
)
args_array
=
[
f
"(void*)&
{
arg
}
"
for
arg
in
args_list
]
call_args
=
f
"
\t
void*
{
function_name
}
_args[] = {{
{
', '
.
join
(
args_array
)
}
}};
\n
"
kernel_launch_code
+=
call_args
# Using cudaLaunchCooperativeKernel to launch the kernel
kernel_launch_code
+=
"
\t
TILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));
\n
"
.
format
(
function_name
,
grid_str
,
block_str
,
function_name
+
"_args"
,
smem_str
)
function_name
,
grid_str
,
block_str
,
function_name
+
"_args"
,
smem_str
)
else
:
args_list
=
parse_function_call_args
(
declaration
,
function_args
,
function_params
,
desc_name_map
,
desc_name_var_map
)
assert
len
(
function_params
)
==
len
(
args_list
),
f
"Function
{
function_name
}
has
{
len
(
function_params
)
}
parameters, but
{
len
(
args_list
)
}
arguments"
args_list
=
parse_function_call_args
(
declaration
,
function_args
,
function_params
,
desc_name_map
,
desc_name_var_map
)
assert
len
(
function_params
)
==
len
(
args_list
),
(
f
"Function
{
function_name
}
has
{
len
(
function_params
)
}
parameters, but
{
len
(
args_list
)
}
arguments"
)
call_args
=
", "
.
join
(
args_list
)
kernel_launch_code
+=
f
"
\t
{
function_name
}
<<<
{
grid_str
}
,
{
block_str
}
,
{
smem_str
}
, stream>>>(
{
call_args
}
);
\n
"
kernel_launch_code
+=
f
"
\t
TILELANG_CHECK_LAST_ERROR(
\
"
{
function_name
}
\
"
);
\n
"
kernel_launch_code
+=
f
'
\t
TILELANG_CHECK_LAST_ERROR("
{
function_name
}
");
\n
'
if
has_l2_persistent_map
:
kernel_launch_code
+=
L2_PERSISTENT_MAP_RESET_HANDLE
init_tma_descriptor_args
=
self
.
generate_tma_descriptor_args
(
desc_name_map
,
desc_name_var_map
)
init_tma_descriptor_args
=
self
.
generate_tma_descriptor_args
(
desc_name_map
,
desc_name_var_map
)
kernel_launch_code
=
init_tma_descriptor_args
+
kernel_launch_code
# Wrap the kernel dispatch logic in an external C function
...
...
@@ -298,46 +311,63 @@ class TLCUDASourceWrapper:
if
function_name
not
in
self
.
l2_persistent_map
:
return
""
init_l2_persistent_map
=
""
for
buffer_name
,
(
hit_ratio
,
size_in_bytes
)
in
self
.
l2_persistent_map
[
function_name
].
items
():
for
buffer_name
,
(
hit_ratio
,
size_in_bytes
)
in
self
.
l2_persistent_map
[
function_name
].
items
():
# get persisting_l2_cache_max_size
from
tilelang.carver.arch.driver
import
get_persisting_l2_cache_max_size
persisting_l2_cache_max_size
=
get_persisting_l2_cache_max_size
()
try
:
num_bytes
=
min
(
size_in_bytes
,
persisting_l2_cache_max_size
)
except
Exception
:
# as size_in_bytes maybe a symbolic expression
num_bytes
=
persisting_l2_cache_max_size
init_l2_persistent_map
+=
L2_PERSISTENT_MAP_INIT_FUNC
.
format
(
buffer_name
,
float
(
hit_ratio
),
self
.
_pythonic_expr
(
num_bytes
))
init_l2_persistent_map
+=
L2_PERSISTENT_MAP_INIT_FUNC
.
format
(
buffer_name
,
float
(
hit_ratio
),
self
.
_pythonic_expr
(
num_bytes
))
return
init_l2_persistent_map
def
generate_tma_descriptor_args
(
self
,
desc_name_map
:
dict
[
str
,
str
],
desc_name_var_map
:
dict
[
str
,
tvm
.
tir
.
Var
])
->
str
:
def
generate_tma_descriptor_args
(
self
,
desc_name_map
:
dict
[
str
,
str
],
desc_name_var_map
:
dict
[
str
,
tvm
.
tir
.
Var
])
->
str
:
tma_descripter_init
=
""
if
self
.
tma_descriptor_args
is
None
:
return
tma_descripter_init
# Parse TMA descriptor arguments using the common utility
parsed_params
=
parse_tma_descriptor_args
(
self
.
tma_descriptor_args
,
desc_name_map
,
desc_name_var_map
,
self
.
_pythonic_expr
)
parsed_params
=
parse_tma_descriptor_args
(
self
.
tma_descriptor_args
,
desc_name_map
,
desc_name_var_map
,
self
.
_pythonic_expr
)
# Generate C++ code from parsed parameters
for
params
in
parsed_params
:
if
not
params
.
is_img2col
:
tma_descripter_init
+=
TMA_DESC_INIT_FUNC
.
format
(
params
.
handle_name
,
params
.
dtype
,
params
.
tensor_rank
,
params
.
global_address
,
","
.
join
(
params
.
global_dim
),
","
.
join
(
params
.
global_stride
),
","
.
join
(
params
.
box_dim
),
","
.
join
(
params
.
element_strides
),
params
.
interleave
,
params
.
swizzle
,
params
.
l2_promotion
,
params
.
oob_fill
)
params
.
handle_name
,
params
.
dtype
,
params
.
tensor_rank
,
params
.
global_address
,
","
.
join
(
params
.
global_dim
),
","
.
join
(
params
.
global_stride
),
","
.
join
(
params
.
box_dim
),
","
.
join
(
params
.
element_strides
),
params
.
interleave
,
params
.
swizzle
,
params
.
l2_promotion
,
params
.
oob_fill
,
)
else
:
tma_descripter_init
+=
TMA_IM2COL_DESC_INIT_FUNC
.
format
(
params
.
handle_name
,
params
.
dtype
,
params
.
tensor_rank
,
params
.
global_address
,
","
.
join
(
params
.
global_dim
),
","
.
join
(
params
.
global_stride
),
","
.
join
(
params
.
element_strides
),
","
.
join
(
params
.
lower_corner
),
","
.
join
(
params
.
upper_corner
),
params
.
smem_box_channel
,
params
.
smem_box_pixel
,
params
.
interleave
,
params
.
swizzle
,
params
.
l2_promotion
,
params
.
oob_fill
)
params
.
handle_name
,
params
.
dtype
,
params
.
tensor_rank
,
params
.
global_address
,
","
.
join
(
params
.
global_dim
),
","
.
join
(
params
.
global_stride
),
","
.
join
(
params
.
element_strides
),
","
.
join
(
params
.
lower_corner
),
","
.
join
(
params
.
upper_corner
),
params
.
smem_box_channel
,
params
.
smem_box_pixel
,
params
.
interleave
,
params
.
swizzle
,
params
.
l2_promotion
,
params
.
oob_fill
,
)
return
tma_descripter_init
...
...
@@ -347,9 +377,8 @@ class TLCUDASourceWrapper:
device_mod
,
host_mod
=
get_annotated_mod
(
self
.
mod
,
self
.
target
)
self
.
device_mod
=
device_mod
self
.
host_mod
=
host_mod
assert
(
len
(
self
.
device_mod
.
functions
)
>=
1
),
"Device module should have at least one function."
assert
(
len
(
self
.
host_mod
.
functions
)
==
1
),
"Only support one function in host module."
assert
len
(
self
.
device_mod
.
functions
)
>=
1
,
"Device module should have at least one function."
assert
len
(
self
.
host_mod
.
functions
)
==
1
,
"Only support one function in host module."
block_info_map
=
{}
grid_info_map
=
{}
...
...
@@ -438,8 +467,7 @@ class TLCUDASourceWrapper:
for
function_name
,
dynamic_smem_buf
in
self
.
dynamic_smem_buf
.
items
():
if
dynamic_smem_buf
is
not
None
:
# Format the cudaFuncSetAttribute call for dynamic shared memory
call_str
+=
PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY
.
format
(
function_name
,
dynamic_smem_buf
)
call_str
+=
PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY
.
format
(
function_name
,
dynamic_smem_buf
)
# Format the initialization function using the call_str
init_funcs
=
PREDEF_INIT_FUNC
.
format
(
call_str
)
return
init_funcs
...
...
@@ -466,17 +494,14 @@ class TLCUDASourceWrapper:
def
visitor
(
node
,
fn
=
function_name
,
param_cnt
=
kernel_params_cnt
):
nonlocal
function_params
if
isinstance
(
node
,
tvm
.
tir
.
Call
):
if
not
(
hasattr
(
node
,
"op"
)
and
node
.
op
==
tvm
.
ir
.
Op
.
get
(
"tir.tvm_call_packed"
)):
if
not
(
hasattr
(
node
,
"op"
)
and
node
.
op
==
tvm
.
ir
.
Op
.
get
(
"tir.tvm_call_packed"
)):
return
args
=
node
.
args
if
not
args
or
args
[
0
]
!=
fn
:
return
if
len
(
args
)
<
1
+
param_cnt
:
raise
AssertionError
(
"tvm_call_packed should have at least 1 argument and match device function parameters"
)
function_params
=
args
[
1
:
1
+
param_cnt
]
raise
AssertionError
(
"tvm_call_packed should have at least 1 argument and match device function parameters"
)
function_params
=
args
[
1
:
1
+
param_cnt
]
post_order_visit
(
self
.
host_func
.
body
,
visitor
)
assert
function_params
is
not
None
,
"function_params should not be None"
...
...
@@ -564,13 +589,15 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
"uchar"
:
"uint8_t"
,
}
def
__init__
(
self
,
scheduled_ir_module
:
IRModule
,
source
:
str
,
target
:
Target
,
device_mod
:
IRModule
|
None
=
None
,
host_mod
:
IRModule
|
None
=
None
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
):
def
__init__
(
self
,
scheduled_ir_module
:
IRModule
,
source
:
str
,
target
:
Target
,
device_mod
:
IRModule
|
None
=
None
,
host_mod
:
IRModule
|
None
=
None
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
):
super
().
__init__
(
scheduled_ir_module
,
source
,
target
,
device_mod
,
host_mod
,
pass_configs
)
def
get_init_func
(
self
):
...
...
@@ -580,8 +607,7 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
for
function_name
,
dynamic_smem_buf
in
self
.
dynamic_smem_buf
.
items
():
if
dynamic_smem_buf
is
not
None
:
# Format the cudaFuncSetAttribute call for dynamic shared memory
call_str
+=
PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP
.
format
(
function_name
,
dynamic_smem_buf
)
call_str
+=
PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP
.
format
(
function_name
,
dynamic_smem_buf
)
# Format the initialization function using the call_str
init_funcs
=
PREDEF_INIT_FUNC
.
format
(
call_str
)
return
init_funcs
...
...
@@ -623,13 +649,15 @@ class TLCPUSourceWrapper:
host_mod
:
IRModule
|
None
=
None
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
def
__init__
(
self
,
scheduled_ir_module
:
IRModule
,
source
:
str
,
target
:
Target
,
device_mod
:
IRModule
|
None
=
None
,
host_mod
:
IRModule
|
None
=
None
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
):
def
__init__
(
self
,
scheduled_ir_module
:
IRModule
,
source
:
str
,
target
:
Target
,
device_mod
:
IRModule
|
None
=
None
,
host_mod
:
IRModule
|
None
=
None
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
):
self
.
mod
=
scheduled_ir_module
self
.
target
=
target
self
.
source
=
source
...
...
@@ -658,15 +686,16 @@ class TLCPUSourceWrapper:
for
param
in
self
.
prim_func
.
params
:
if
param
in
self
.
prim_func
.
buffer_map
:
buffer
=
self
.
prim_func
.
buffer_map
[
param
]
function_args
.
append
({
"name"
:
buffer
.
name
,
"type"
:
self
.
_lookup_type
(
buffer
.
dtype
)
+
"*"
,
})
function_args
.
append
(
{
"name"
:
buffer
.
name
,
"type"
:
self
.
_lookup_type
(
buffer
.
dtype
)
+
"*"
,
}
)
elif
isinstance
(
param
,
tvm
.
tir
.
Var
):
function_args
.
append
({
"name"
:
param
.
name
,
"type"
:
self
.
_lookup_type
(
param
.
dtype
)})
else
:
raise
ValueError
(
f
"Parameter
{
param
}
is not in the buffer map of the primary function."
)
raise
ValueError
(
f
"Parameter
{
param
}
is not in the buffer map of the primary function."
)
# Add dynamic symbols as integer arguments
for
dyn_sym
,
dyn_sym_dtype
in
dynamic_symbolic_set
:
function_args
.
append
({
"name"
:
dyn_sym
,
"type"
:
self
.
_lookup_type
(
dyn_sym_dtype
)})
...
...
@@ -686,7 +715,6 @@ class TLCPUSourceWrapper:
_call_str
=
""""""
for
function_name
,
_
in
function_informations
.
items
():
# Find the location of the global kernel function in the code
index
=
match_declare_kernel_cpu
(
code
,
function_name
+
"("
)
...
...
@@ -706,8 +734,8 @@ class TLCPUSourceWrapper:
def
parse_source_information
(
self
):
with
tvm
.
transform
.
PassContext
(
opt_level
=
3
,
config
=
self
.
pass_configs
):
device_mod
,
host_mod
=
get_annotated_mod
(
self
.
mod
,
self
.
target
)
assert
(
len
(
device_mod
.
functions
)
>=
1
)
,
"Device module should have at least one function."
assert
(
len
(
host_mod
.
functions
)
==
1
)
,
"Only support one function in host module."
assert
len
(
device_mod
.
functions
)
>=
1
,
"Device module should have at least one function."
assert
len
(
host_mod
.
functions
)
==
1
,
"Only support one function in host module."
function_names
=
[]
for
g_var
,
_
in
device_mod
.
functions
.
items
():
...
...
@@ -767,14 +795,15 @@ class TLCPUSourceWrapper:
class
TLMetalSourceWrapper
:
def
__init__
(
self
,
scheduled_ir_module
:
IRModule
,
source
:
str
,
target
:
Target
,
device_mod
:
IRModule
|
None
=
None
,
host_mod
:
IRModule
|
None
=
None
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
):
def
__init__
(
self
,
scheduled_ir_module
:
IRModule
,
source
:
str
,
target
:
Target
,
device_mod
:
IRModule
|
None
=
None
,
host_mod
:
IRModule
|
None
=
None
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
):
self
.
mod
=
scheduled_ir_module
self
.
target
=
target
self
.
source
=
source
...
...
@@ -792,6 +821,7 @@ class TLWrapper(BaseWrapper):
"""
A wrapper class for the TileLang backend.
"""
device_mod
:
IRModule
|
None
=
None
host_mod
:
IRModule
|
None
=
None
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
...
...
@@ -836,12 +866,12 @@ class TLWrapper(BaseWrapper):
target
=
self
.
target
,
device_mod
=
self
.
device_mod
,
host_mod
=
self
.
host_mod
,
pass_configs
=
self
.
pass_configs
)
pass_configs
=
self
.
pass_configs
,
)
return
wrapper
.
lib_code
class
TLPyWrapper
(
TLWrapper
):
def
__init__
(
self
,
target
:
Target
):
super
().
__init__
(
target
)
...
...
@@ -849,6 +879,7 @@ class TLPyWrapper(TLWrapper):
# assert self.scheduled_ir_module is not None, "Please assign optimized module first."
if
is_cuda_target
(
self
.
target
):
from
tilelang.jit.adapter.nvrtc
import
TLNVRTCSourceWrapper
wrapper_class
=
TLNVRTCSourceWrapper
else
:
raise
ValueError
(
f
"Unsupported target for NVRTC backend:
{
self
.
target
}
"
)
...
...
@@ -858,5 +889,6 @@ class TLPyWrapper(TLWrapper):
target
=
self
.
target
,
device_mod
=
self
.
device_mod
,
host_mod
=
self
.
host_mod
,
pass_configs
=
self
.
pass_configs
)
pass_configs
=
self
.
pass_configs
,
)
return
wrapper
.
host_func
,
wrapper
.
function_names
tilelang/jit/execution_backend.py
View file @
29051439
...
...
@@ -46,6 +46,7 @@ def allowed_backends_for_target(target: Target, *, include_unavailable: bool = T
# Drop NVRTC if not importable
try
:
from
tilelang.jit.adapter.nvrtc
import
is_nvrtc_available
# lazy
if
not
is_nvrtc_available
and
"nvrtc"
in
allowed
:
allowed
=
[
b
for
b
in
allowed
if
b
!=
"nvrtc"
]
except
Exception
:
...
...
@@ -89,12 +90,14 @@ def resolve_execution_backend(requested: str | None, target: Target) -> str:
if
req
not
in
allowed_all
:
raise
ValueError
(
f
"Invalid execution backend '
{
requested
}
' for target '
{
_target_kind
(
target
)
}
'. "
f
"Allowed:
{
_format_options
(
allowed_all
)
}
. Tip: use execution_backend='auto'."
)
f
"Allowed:
{
_format_options
(
allowed_all
)
}
. Tip: use execution_backend='auto'."
)
# Promote to availability-aware set for nicer errors (e.g., nvrtc not installed)
if
req
not
in
allowed_avail
:
raise
ValueError
(
f
"Execution backend '
{
requested
}
' requires extra dependencies and is not available now. "
f
"Try one of:
{
_format_options
(
allowed_avail
)
}
."
)
f
"Try one of:
{
_format_options
(
allowed_avail
)
}
."
)
return
req
tilelang/jit/kernel.py
View file @
29051439
from
__future__
import
annotations
from
typing
import
Any
,
Callable
,
Generic
,
Literal
,
TypeVar
# Python 3.9 compatibility for ParamSpec
try
:
from
typing
import
ParamSpec
...
...
@@ -14,8 +15,7 @@ import tilelang
from
tilelang
import
tvm
from
tilelang
import
env
from
tilelang.engine.param
import
CompiledArtifact
,
KernelParam
from
tilelang.jit.adapter
import
(
BaseKernelAdapter
,
CtypesKernelAdapter
,
CythonKernelAdapter
,
TVMFFIKernelAdapter
,
MetalKernelAdapter
)
from
tilelang.jit.adapter
import
BaseKernelAdapter
,
CtypesKernelAdapter
,
CythonKernelAdapter
,
TVMFFIKernelAdapter
,
MetalKernelAdapter
from
tilelang.profiler
import
Profiler
,
TensorSupplyType
from
tilelang.utils.target
import
determine_target
from
tilelang.contrib
import
nvcc
as
tl_nvcc
...
...
@@ -24,8 +24,8 @@ import os
logger
=
logging
.
getLogger
(
__name__
)
_P
=
ParamSpec
(
'
_P
'
)
_T
=
TypeVar
(
'
_T
'
)
_P
=
ParamSpec
(
"
_P
"
)
_T
=
TypeVar
(
"
_T
"
)
class
JITKernel
(
Generic
[
_P
,
_T
]):
...
...
@@ -41,6 +41,7 @@ class JITKernel(Generic[_P, _T]):
torch_function : Callable
The compiled function that can be invoked as a PyTorch-compatible function.
"""
prim_func
:
PrimFunc
=
None
artifact
:
CompiledArtifact
=
None
adapter
:
BaseKernelAdapter
=
None
...
...
@@ -111,9 +112,7 @@ class JITKernel(Generic[_P, _T]):
if
execution_backend
==
"cython"
:
from
tilelang.contrib.cc
import
get_cplus_compiler
assert
(
get_cplus_compiler
()
is
not
None
),
"Cython backend requires a C++ compiler, please install or use other backends."
assert
get_cplus_compiler
()
is
not
None
,
"Cython backend requires a C++ compiler, please install or use other backends."
if
from_database
:
return
...
...
@@ -200,8 +199,7 @@ class JITKernel(Generic[_P, _T]):
"""
return
self
.
torch_function
(
*
args
,
**
kwds
)
def
_compile_and_create_adapter
(
self
,
tilelang_func
:
PrimFunc
,
out_idx
:
list
[
int
])
->
BaseKernelAdapter
:
def
_compile_and_create_adapter
(
self
,
tilelang_func
:
PrimFunc
,
out_idx
:
list
[
int
])
->
BaseKernelAdapter
:
"""
Compiles the given TileLang PrimFunc using TVM and creates a kernel adapter.
...
...
@@ -233,7 +231,8 @@ class JITKernel(Generic[_P, _T]):
target
=
target
,
target_host
=
target_host
,
enable_host_codegen
=
enable_host_codegen
,
enable_device_compile
=
enable_device_compile
)
enable_device_compile
=
enable_device_compile
,
)
self
.
artifact
=
artifact
...
...
@@ -241,7 +240,7 @@ class JITKernel(Generic[_P, _T]):
if
execution_backend
==
"tvm_ffi"
:
# Use TVMFFIKernelAdapter for interoperability with PyTorch via DLPack.
# But we need to ensure that the runtime is enabled and the runtime module is not None.
assert
(
artifact
.
rt_mod
is
not
None
)
,
"tvm_ffi backend requires a runtime module."
assert
artifact
.
rt_mod
is
not
None
,
"tvm_ffi backend requires a runtime module."
adapter
=
TVMFFIKernelAdapter
(
params
=
artifact
.
params
,
result_idx
=
out_idx
,
...
...
@@ -283,6 +282,7 @@ class JITKernel(Generic[_P, _T]):
)
elif
execution_backend
==
"nvrtc"
:
from
tilelang.jit.adapter
import
NVRTCKernelAdapter
adapter
=
NVRTCKernelAdapter
(
params
=
artifact
.
params
,
result_idx
=
out_idx
,
...
...
@@ -315,16 +315,18 @@ class JITKernel(Generic[_P, _T]):
return
adapter
def
_create_adapter_from_database
(
self
,
params
:
list
[
KernelParam
],
result_idx
:
list
[
int
]
|
int
,
target
:
str
|
Target
,
func_or_mod
:
PrimFunc
|
tvm
.
runtime
.
Module
,
host_kernel_source
:
str
,
device_kernel_source
:
str
,
kernel_lib_path
:
str
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
None
=
None
)
->
BaseKernelAdapter
:
def
_create_adapter_from_database
(
self
,
params
:
list
[
KernelParam
],
result_idx
:
list
[
int
]
|
int
,
target
:
str
|
Target
,
func_or_mod
:
PrimFunc
|
tvm
.
runtime
.
Module
,
host_kernel_source
:
str
,
device_kernel_source
:
str
,
kernel_lib_path
:
str
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
compile_flags
:
list
[
str
]
|
None
=
None
,
)
->
BaseKernelAdapter
:
target
=
self
.
target
execution_backend
=
self
.
execution_backend
...
...
@@ -366,6 +368,7 @@ class JITKernel(Generic[_P, _T]):
)
elif
execution_backend
==
"nvrtc"
:
from
tilelang.jit.adapter
import
NVRTCKernelAdapter
adapter
=
NVRTCKernelAdapter
.
from_database
(
params
=
params
,
result_idx
=
result_idx
,
...
...
@@ -402,8 +405,7 @@ class JITKernel(Generic[_P, _T]):
"""
return
cls
(
func
=
tilelang_func
,
**
kwargs
)
def
get_profiler
(
self
,
tensor_supply_type
:
TensorSupplyType
=
TensorSupplyType
.
Auto
)
->
Profiler
:
def
get_profiler
(
self
,
tensor_supply_type
:
TensorSupplyType
=
TensorSupplyType
.
Auto
)
->
Profiler
:
"""
Creates a profiler to benchmark the compiled runtime module.
...
...
@@ -417,8 +419,7 @@ class JITKernel(Generic[_P, _T]):
Profiler
A Profiler instance for benchmarking the runtime module.
"""
return
Profiler
(
self
.
params
,
self
.
out_idx
,
tensor_supply_type
).
with_default_adapter
(
self
.
adapter
)
return
Profiler
(
self
.
params
,
self
.
out_idx
,
tensor_supply_type
).
with_default_adapter
(
self
.
adapter
)
def
get_kernel_source
(
self
,
kernel_only
:
bool
=
True
)
->
str
:
"""
...
...
@@ -507,21 +508,19 @@ class JITKernel(Generic[_P, _T]):
dir_path
=
os
.
path
.
dirname
(
kernel_path
)
if
dir_path
:
os
.
makedirs
(
dir_path
,
exist_ok
=
True
)
with
open
(
kernel_path
,
'w'
)
as
f
:
with
open
(
kernel_path
,
"w"
)
as
f
:
f
.
write
(
self
.
get_kernel_source
())
if
host_path
is
not
None
:
dir_path
=
os
.
path
.
dirname
(
host_path
)
if
dir_path
:
os
.
makedirs
(
dir_path
,
exist_ok
=
True
)
with
open
(
host_path
,
'w'
)
as
f
:
with
open
(
host_path
,
"w"
)
as
f
:
f
.
write
(
self
.
get_host_source
())
except
Exception
as
e
:
logger
.
error
(
f
"Failed to export sources:
{
e
}
"
)
# Backward compatibility alias (deprecated)
def
print_source_code
(
self
,
which
:
Literal
[
"kernel"
,
"host"
,
"both"
]
=
"kernel"
,
file
:
str
|
None
=
None
)
->
None
:
def
print_source_code
(
self
,
which
:
Literal
[
"kernel"
,
"host"
,
"both"
]
=
"kernel"
,
file
:
str
|
None
=
None
)
->
None
:
"""
Deprecated: use show_source() or export_sources() instead.
...
...
@@ -541,16 +540,14 @@ class JITKernel(Generic[_P, _T]):
>>> # Old API (still works but deprecated)
>>> jit_kernel.print_source_code(file="/tmp/kernel.cu")
"""
logger
.
warning
(
"print_source_code is deprecated; use show_source() or export_sources() instead."
)
logger
.
warning
(
"print_source_code is deprecated; use show_source() or export_sources() instead."
)
if
file
is
not
None
:
# Historical behavior wrote only kernel source when file provided
self
.
export_sources
(
kernel_path
=
file
)
else
:
self
.
show_source
(
which
=
which
)
def
update_tuner_result
(
self
,
latency
:
float
,
config
:
dict
[
str
,
Any
],
ref_latency
:
float
)
->
JITKernel
:
def
update_tuner_result
(
self
,
latency
:
float
,
config
:
dict
[
str
,
Any
],
ref_latency
:
float
)
->
JITKernel
:
"""
Updates the tuning results for this kernel.
...
...
@@ -651,8 +648,7 @@ class JITKernel(Generic[_P, _T]):
verbose
=
self
.
verbose
# Ensure target is set so nvcc picks correct arch via Target.current()
with
self
.
target
:
return
tl_nvcc
.
get_ptx_from_source
(
code
,
compile_flags
=
self
.
compile_flags
,
verbose
=
verbose
)
return
tl_nvcc
.
get_ptx_from_source
(
code
,
compile_flags
=
self
.
compile_flags
,
verbose
=
verbose
)
def
show_ptx
(
self
)
->
None
:
"""
...
...
@@ -714,8 +710,7 @@ class JITKernel(Generic[_P, _T]):
if
verbose
is
None
:
verbose
=
self
.
verbose
with
self
.
target
:
return
tl_nvcc
.
get_sass_from_source
(
code
,
compile_flags
=
self
.
compile_flags
,
verbose
=
verbose
)
return
tl_nvcc
.
get_sass_from_source
(
code
,
compile_flags
=
self
.
compile_flags
,
verbose
=
verbose
)
def
show_sass
(
self
)
->
None
:
"""
...
...
tilelang/language/__init__.py
View file @
29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
# from .parser import *
...
...
@@ -102,7 +103,10 @@ from .utils import index_to_coordinates # noqa: F401
from
.symbolics
import
dynamic
,
symbolic
# noqa: F401
from
.annotations
import
(
# noqa: F401
use_swizzle
,
annotate_layout
,
annotate_safe_value
,
annotate_l2_hit_ratio
,
use_swizzle
,
annotate_layout
,
annotate_safe_value
,
annotate_l2_hit_ratio
,
)
...
...
tilelang/language/allocate.py
View file @
29051439
...
...
@@ -13,8 +13,10 @@ Available allocation functions:
Each function takes shape and dtype parameters and returns a TVM buffer object
with the appropriate memory scope.
"""
from
__future__
import
annotations
from
typing
import
TypeVar
,
overload
,
Literal
,
Callable
# Python 3.9 compatibility for advanced typing features (PEP 646)
try
:
from
typing
import
TypeVarTuple
,
Unpack
# type: ignore[attr-defined]
...
...
@@ -30,13 +32,11 @@ from .v2.dtypes import dtype as tl_dtype
from
.v2.builder
import
OutTensor
from
.v2.annot
import
Tensor
,
SharedBuffer
,
LocalBuffer
,
FragmentBuffer
_Shapes
=
TypeVarTuple
(
'
_Shapes
'
)
_DType
=
TypeVar
(
'
_DType
'
)
_Shapes
=
TypeVarTuple
(
"
_Shapes
"
)
_DType
=
TypeVar
(
"
_DType
"
)
def
alloc_shared
(
shape
:
tuple
[
Unpack
[
_Shapes
]],
dtype
:
_DType
,
scope
=
"shared.dyn"
)
->
SharedBuffer
[
Callable
[[
Unpack
[
_Shapes
]]],
_DType
]:
def
alloc_shared
(
shape
:
tuple
[
Unpack
[
_Shapes
]],
dtype
:
_DType
,
scope
=
"shared.dyn"
)
->
SharedBuffer
[
Callable
[[
Unpack
[
_Shapes
]]],
_DType
]:
"""Allocate a shared memory buffer for inter-thread communication.
Args:
...
...
@@ -54,9 +54,7 @@ def alloc_shared(shape: tuple[Unpack[_Shapes]],
return
T
.
alloc_buffer
(
shape
,
dtype
,
scope
=
scope
)
def
alloc_local
(
shape
:
tuple
[
Unpack
[
_Shapes
]],
dtype
:
_DType
,
scope
=
"local"
)
->
LocalBuffer
[
Callable
[[
Unpack
[
_Shapes
]]],
_DType
]:
def
alloc_local
(
shape
:
tuple
[
Unpack
[
_Shapes
]],
dtype
:
_DType
,
scope
=
"local"
)
->
LocalBuffer
[
Callable
[[
Unpack
[
_Shapes
]]],
_DType
]:
"""Allocate a local memory buffer for thread-private storage.
Args:
...
...
@@ -70,9 +68,9 @@ def alloc_local(shape: tuple[Unpack[_Shapes]],
return
T
.
alloc_buffer
(
shape
,
dtype
,
scope
=
scope
)
def
alloc_fragment
(
shape
:
tuple
[
Unpack
[
_Shapes
]],
dtype
:
_DType
,
scope
=
"local.fragment"
)
->
FragmentBuffer
[
Callable
[[
Unpack
[
_Shapes
]]],
_DType
]:
def
alloc_fragment
(
shape
:
tuple
[
Unpack
[
_Shapes
]],
dtype
:
_DType
,
scope
=
"local.fragment"
)
->
FragmentBuffer
[
Callable
[[
Unpack
[
_Shapes
]]],
_DType
]:
"""Allocate a fragment memory buffer for specialized operations.
Args:
...
...
@@ -87,16 +85,11 @@ def alloc_fragment(shape: tuple[Unpack[_Shapes]],
@
overload
def
alloc_var
(
dtype
:
str
,
init
:
PrimExpr
|
int
|
float
,
scope
:
str
=
'local.var'
)
->
Buffer
:
...
def
alloc_var
(
dtype
:
str
,
init
:
PrimExpr
|
int
|
float
,
scope
:
str
=
"local.var"
)
->
Buffer
:
...
@
overload
def
alloc_var
(
dtype
:
str
,
scope
:
str
=
'local.var'
,
*
,
init
:
PrimExpr
|
int
|
float
|
None
=
None
)
->
Buffer
:
...
def
alloc_var
(
dtype
:
str
,
scope
:
str
=
"local.var"
,
*
,
init
:
PrimExpr
|
int
|
float
|
None
=
None
)
->
Buffer
:
...
def
alloc_var
(
dtype
,
*
args
,
scope
=
"local.var"
,
init
:
PrimExpr
|
None
=
None
):
...
...
@@ -142,8 +135,7 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
raise
TypeError
(
"Scope must be provided as a string in alloc_var."
)
parsed_scope
=
parsed_scope_arg
elif
len
(
args
)
>
2
:
raise
TypeError
(
f
"alloc_var expected at most 3 positional arguments but got
{
len
(
args
)
+
1
}
."
)
raise
TypeError
(
f
"alloc_var expected at most 3 positional arguments but got
{
len
(
args
)
+
1
}
."
)
if
not
isinstance
(
parsed_scope
,
str
):
raise
TypeError
(
"Scope must be a string in alloc_var."
)
...
...
@@ -274,13 +266,10 @@ def alloc_tcgen05_instr_desc(dtype: str = "uint32"):
@
overload
def
empty
(
shape
:
tuple
[
Unpack
[
_Shapes
]],
dtype
:
str
=
'float32'
)
->
Tensor
[
Callable
[[
Unpack
[
_Shapes
]]],
_DType
]:
...
def
empty
(
shape
:
tuple
[
Unpack
[
_Shapes
]],
dtype
:
str
=
"float32"
)
->
Tensor
[
Callable
[[
Unpack
[
_Shapes
]]],
_DType
]:
...
def
empty
(
*
shape
:
Unpack
[
_Shapes
],
dtype
:
str
=
'float32'
)
->
Tensor
[
Callable
[[
Unpack
[
_Shapes
]]],
_DType
]:
def
empty
(
*
shape
:
Unpack
[
_Shapes
],
dtype
:
str
=
"float32"
)
->
Tensor
[
Callable
[[
Unpack
[
_Shapes
]]],
_DType
]:
if
len
(
shape
)
==
1
and
isinstance
(
shape
[
0
],
(
tuple
,
list
)):
return
OutTensor
(
shape
[
0
],
dtype
)
elif
len
(
shape
)
==
2
and
isinstance
(
shape
[
0
],
(
tuple
,
list
))
and
isinstance
(
shape
[
1
],
str
):
...
...
@@ -288,4 +277,4 @@ def empty(*shape: Unpack[_Shapes],
elif
all
([
isinstance
(
x
,
(
int
,
PrimExpr
))
for
x
in
shape
]):
return
OutTensor
(
shape
,
dtype
)
else
:
raise
RuntimeError
(
f
'
Invalid shape
{
shape
}
'
)
raise
RuntimeError
(
f
"
Invalid shape
{
shape
}
"
)
tilelang/language/annotations.py
View file @
29051439
"""Annotation helpers exposed on the TileLang language surface."""
from
typing
import
Callable
from
tilelang.layout
import
Layout
...
...
tilelang/language/ast/__init__.py
View file @
29051439
...
...
@@ -17,6 +17,7 @@
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
"""Package tvm.script.ir_builder.tir"""
from
.ir
import
*
# noqa: F401
from
.ir
import
boolean
as
bool
# noqa: F401
from
.ir
import
buffer
as
Buffer
# noqa: F401
...
...
Prev
1
…
16
17
18
19
20
21
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment