OpenDAS / tilelang / Commits

Commit 29051439 (unverified), authored Dec 12, 2025 by Lei Wang, committed by GitHub on Dec 12, 2025

[Lint] Phaseout Yapf format and embrace ruff format (#1417)
Parent: e84b24bc
Changes: 426

Showing 20 changed files with 562 additions and 560 deletions (+562, -560)
  tilelang/jit/__init__.py                 +76   -84
  tilelang/jit/adapter/base.py              +3    -7
  tilelang/jit/adapter/ctypes/adapter.py   +32   -31
  tilelang/jit/adapter/cython/adapter.py   +38   -41
  tilelang/jit/adapter/libgen.py            +8   -11
  tilelang/jit/adapter/nvrtc/__init__.py    +7    -7
  tilelang/jit/adapter/nvrtc/adapter.py    +27   -25
  tilelang/jit/adapter/nvrtc/libgen.py      +9   -11
  tilelang/jit/adapter/nvrtc/wrapper.py    +82   -64
  tilelang/jit/adapter/torch/__init__.py    +1    -1
  tilelang/jit/adapter/torch/metal.py       +4    -8
  tilelang/jit/adapter/tvm_ffi.py          +40   -40
  tilelang/jit/adapter/utils.py            +46   -66
  tilelang/jit/adapter/wrapper.py         +129   -97
  tilelang/jit/execution_backend.py         +5    -2
  tilelang/jit/kernel.py                   +33   -38
  tilelang/language/__init__.py             +5    -1
  tilelang/language/allocate.py            +15   -26
  tilelang/language/annotations.py          +1    -0
  tilelang/language/ast/__init__.py         +1    -0
Too many changes to show. To preserve performance only 426 of 426+ files are displayed.
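The hunks below are almost entirely mechanical: ruff format normalizes string literals to double quotes and keeps exploded argument lists one item per line with a trailing ("magic") comma. A small hypothetical snippet, written the way ruff format emits code in this commit; the function and its values are illustrative and not taken from the diff:

    def compile_kernel(
        name: str,
        target: str = "auto",
        verbose: bool = False,
    ):  # under the old yapf style the signature would end "verbose: bool = False):" on one line
        mode = "debug" if verbose else "release"  # yapf kept 'single quotes' here
        return f"{name}:{target}:{mode}"


    print(compile_kernel("matmul"))  # -> matmul:auto:release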
tilelang/jit/__init__.py  (view file @ 29051439)

@@ -3,6 +3,7 @@ This module provides an auto-tuning infrastructure for TileLang (tl) programs.
 It includes functionality to JIT-compile TileLang programs into a runnable
 kernel adapter using TVM.
 """
+from __future__ import annotations
 from dataclasses import dataclass

@@ -39,17 +40,16 @@ from tqdm.auto import tqdm
 logger = getLogger(__name__)

-_P = ParamSpec('_P')
-_KP = ParamSpec('_KP')
-_T = TypeVar('_T')
-_Ret = TypeVar('_Ret')
+_P = ParamSpec("_P")
+_KP = ParamSpec("_KP")
+_T = TypeVar("_T")
+_Ret = TypeVar("_Ret")

 def compile(
     func: PrimFunc[_KP, _T] = None,
     out_idx: list[int] | int | None = None,
     execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto",
     target: str | Target = "auto",
     target_host: str | Target | None = None,
     verbose: bool = False,

@@ -83,11 +83,9 @@ def compile(
     if isinstance(compile_flags, str):
         compile_flags = [compile_flags]
-    if hasattr(func, 'out_idx_override'):
+    if hasattr(func, "out_idx_override"):
         if func.out_idx_override is not None and out_idx is not None:
             raise ValueError(
                 "Out index conflict: out_idx is specified and prim_func have returned `T.empty` tensors"
             )
         out_idx = func.out_idx_override or out_idx
     # This path is not a performance critical path, so we can afford to convert the target.

@@ -96,6 +94,7 @@ def compile(
     # Resolve execution backend (handles aliases, auto, validation per target)
     requested_backend = execution_backend
     from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target

     execution_backend = resolve_execution_backend(requested_backend, target)
     if verbose:
         allowed_now = allowed_backends_for_target(target, include_unavailable=False)

@@ -119,17 +118,18 @@ def compile(
     )

 def par_compile(
     funcs: Iterable[PrimFunc[_KP, _T]],
     out_idx: list[int] | int | None = None,
     execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto",
     target: str | Target = "auto",
     target_host: str | Target | None = None,
     verbose: bool = False,
     pass_configs: dict[str, Any] | None = None,
     compile_flags: list[str] | str | None = None,
     num_workers: int = None,
-    ignore_error: bool = False) -> list[JITKernel[_KP, _T]]:
+    ignore_error: bool = False,
+) -> list[JITKernel[_KP, _T]]:
     """
     Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels.

     Parameters

@@ -151,7 +151,7 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
         Additional keyword arguments to pass to the Compiler PassContext.
         Refer to `tilelang.transform.PassConfigKey` for supported options.
     """
-    with concurrent.futures.ThreadPoolExecutor(num_workers, 'tl-par-comp') as executor:
+    with concurrent.futures.ThreadPoolExecutor(num_workers, "tl-par-comp") as executor:
         futures = []
         future_map = {}
         for i, func in enumerate(funcs):

@@ -170,9 +170,9 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
             futures.append(future)
         results = [... for _ in futures]
         for future in tqdm(
             concurrent.futures.as_completed(futures), total=len(futures), desc="Parallel Compiling",
         ):
             idx = future_map[future]
             if ignore_error:

@@ -189,7 +189,7 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
 @dataclass
 class JITImpl(Generic[_P, _KP, _T, _Ret]):
-    '''
+    """
     Detailed Just-In-Time wrapper for TileLang programs.

     This dataclass encapsulates the configuration and runtime helpers used by the

@@ -256,7 +256,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
     PrimFunc and the resulting set is compiled in parallel via the
     module-level `par_compile` helper. Returns a list of JITKernel objects
     in the same order as the provided configs.
-    '''
+    """

     out_idx: list[int] | int | None
     execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"]

@@ -302,10 +302,9 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
         assert isinstance(tir, PrimFunc), f"target function must be a PrimFunc but got {type(tir)}"
         return tir

     def par_compile(self, configs: Iterable[dict[str, Any] | tuple[str, Any]], num_workers: int = None, ignore_error: bool = False) -> list[JITKernel[_KP, _T]]:
         """
         Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels.

         Parameters

@@ -328,7 +327,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
         """
         configs = list(configs)
         funcs = []
-        for cfg in tqdm(configs, desc='Elaborating'):
+        for cfg in tqdm(configs, desc="Elaborating"):
             if isinstance(cfg, tuple):
                 funcs.append(self.get_tir(*cfg))
             elif isinstance(cfg, dict):

@@ -345,7 +344,8 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
             pass_configs=self.pass_configs,
             compile_flags=self.compile_flags,
             num_workers=num_workers,
-            ignore_error=ignore_error)
+            ignore_error=ignore_error,
+        )

     def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret:
         func = self.get_tir(*args, **kwargs)

@@ -362,25 +362,25 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
         if self.debug_root_path:
             if isinstance(self.func, PrimFunc):
-                func_name = self.func.attrs['global_symbol']
+                func_name = self.func.attrs["global_symbol"]
             else:
-                func_name = getattr(self.func, '__name__', 'jit_kernel')
-            kernel_file = f'tilelang_jit_kernel_{func_name}.c'
-            program_file = f'tilelang_jit_program_{func_name}.py'
+                func_name = getattr(self.func, "__name__", "jit_kernel")
+            kernel_file = f"tilelang_jit_kernel_{func_name}.c"
+            program_file = f"tilelang_jit_program_{func_name}.py"
             makedirs(self.debug_root_path, exist_ok=True)
-            with open(path.join(self.debug_root_path, kernel_file), 'w') as f:
+            with open(path.join(self.debug_root_path, kernel_file), "w") as f:
                 print(kernel_result.get_kernel_source(), file=f)
-            with open(path.join(self.debug_root_path, program_file), 'w') as f:
+            with open(path.join(self.debug_root_path, program_file), "w") as f:
                 print(func.script(), file=f)
         return kernel_result

     def parse_cache_key(self, *args: _P.args, **kwargs: _P.kwargs):
         if isinstance(self.func, PrimFuncCreater):
-            tune_params = kwargs.pop('__tune_params', {})
+            tune_params = kwargs.pop("__tune_params", {})
             return self.func.func_annot.parse_key(*args, **kwargs, **tune_params)
         else:
-            tune_params = kwargs.pop('__tune_params', {})
+            tune_params = kwargs.pop("__tune_params", {})
             key_args_tuple = args
             key_kwargs_tuple = tuple(sorted(kwargs.items()))
             tuned_key_kwargs_tuple = tuple(sorted(tune_params.items()))

@@ -389,34 +389,31 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
     def convert_kernel_args(self, *args: _P.args, **kwargs: _P.kwargs):
         if isinstance(self.func, PrimFuncCreater):
-            tune_params = kwargs.pop('__tune_params', {})
+            tune_params = kwargs.pop("__tune_params", {})
             return self.func.func_annot.convert_to_kernel_args(*args, **kwargs, **tune_params)
         else:
             raise NotImplementedError("convert_arg_to_kernel_args is only implemented for PrimFuncCreater.")

     def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret:
         # Separate out the tuning parameters from the user's kwargs
         # Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache
-        return_compile_arguments = kwargs.pop('__return_compile_arguments', False)
+        return_compile_arguments = kwargs.pop("__return_compile_arguments", False)
         if return_compile_arguments:
             logger.warning("`__return_compile_arguments` is deprecated and will be removed in future versions.")
             compile_args = {
-                'out_idx': self.out_idx,
-                'execution_backend': self.execution_backend,
-                'target': self.target,
-                'target_host': self.target_host,
-                'verbose': self.verbose,
-                'pass_configs': self.pass_configs,
-                'compile_flags': self.compile_flags,
+                "out_idx": self.out_idx,
+                "execution_backend": self.execution_backend,
+                "target": self.target,
+                "target_host": self.target_host,
+                "verbose": self.verbose,
+                "pass_configs": self.pass_configs,
+                "compile_flags": self.compile_flags,
             }
             return compile_args
         key = self.parse_cache_key(*args, **kwargs)
-        tune_params = kwargs.pop('__tune_params', {})
+        tune_params = kwargs.pop("__tune_params", {})
         kernel = self._kernel_cache.get(key, None)
         if kernel is None:

@@ -434,8 +431,7 @@ ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvr
 @overload
-def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]:
-    ...
+def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]: ...

 @overload

@@ -448,22 +444,22 @@ def jit(
     verbose: bool = False,
     pass_configs: dict[str, Any] | None = None,
     debug_root_path: str | None = None,
-    compile_flags: list[str] | str | None = None) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]]:
-    ...
+    compile_flags: list[str] | str | None = None,
+) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]]: ...

 def jit(  # This is the new public interface
     func: Callable[_P, _T] | PrimFunc | None = None,
     *,  # Indicates subsequent arguments are keyword-only
     out_idx: Any = None,
     target: str | Target = "auto",
     target_host: str | Target = None,
     execution_backend: ExecutionBackend = "auto",
     verbose: bool = False,
     pass_configs: dict[str, Any] | None = None,
     debug_root_path: str | None = None,
-    compile_flags: list[str] | str | None = None):
+    compile_flags: list[str] | str | None = None,
+):
     """
     Just-In-Time (JIT) compiler decorator for TileLang functions.

@@ -516,7 +512,8 @@ def jit(  # This is the new public interface
             compile_flags=compile_flags,
             func_source=inspect.getsource(orig_func),
             signature=inspect.signature(orig_func),
-            lazy_jit=False)
+            lazy_jit=False,
+        )

     if func is not None:
         return decorator(func)

@@ -525,8 +522,7 @@ def jit(  # This is the new public interface
 @overload
-def lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]:
-    ...
+def lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]: ...

 @overload

@@ -539,9 +535,8 @@ def lazy_jit(
     verbose: bool = False,
     pass_configs: dict[str, Any] | None = None,
     debug_root_path: str | None = None,
-    compile_flags: list[str] | str | None = None) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]:
-    ...
+    compile_flags: list[str] | str | None = None,
+) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]: ...

 def lazy_jit(

@@ -555,7 +550,6 @@ def lazy_jit(
     debug_root_path: str | None = None,
     compile_flags: list[str] | str | None = None,
 ):
     if isinstance(compile_flags, str):
         compile_flags = [compile_flags]

@@ -567,7 +561,8 @@ def lazy_jit(
         verbose=verbose,
         pass_configs=pass_configs,
         debug_root_path=debug_root_path,
-        compile_flags=compile_flags)
+        compile_flags=compile_flags,
+    )

     def decorator(func: Callable[_P, _T]):
         pf: PrimFunc[_P, _T] | PrimFuncCreater[_P, _T] = prim_func(func, generator=True)

@@ -576,10 +571,7 @@ def lazy_jit(
         # return compile(pf, **compile_args)
         # else:
         return JITImpl(func=pf, **compile_args, func_source=inspect.getsource(pf.orig_func), signature=inspect.signature(pf.orig_func), lazy_jit=True)

     return decorator(func) if func is not None else decorator
tilelang/jit/adapter/base.py  (view file @ 29051439)

 """The profiler and convert to torch utils"""
+from __future__ import annotations
 from abc import ABC, abstractmethod

@@ -8,7 +9,6 @@ import torch
 class BaseKernelAdapter(ABC):
     func: Callable | None = None

     def __init__(self, mod, params: list[KernelParam], result_idx: list[int]) -> None:

@@ -24,18 +24,14 @@ class BaseKernelAdapter(ABC):
             result_idx = []
         elif isinstance(result_idx, int):
             if result_idx > len(params) or result_idx < -len(params):
                 raise ValueError(f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}")
             if result_idx < 0:
                 result_idx = len(params) + result_idx
             result_idx = [result_idx]
         elif isinstance(result_idx, list):
             for i, idx in enumerate(result_idx):
                 if idx >= len(params) or idx < -len(params):
                     raise ValueError(f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}")
                 if idx < 0:
                     result_idx[i] = len(params) + idx
         else:
tilelang/jit/adapter/ctypes/adapter.py  (view file @ 29051439)

 """The profiler and convert to torch utils"""
+from __future__ import annotations
 import torch
 from ..base import BaseKernelAdapter

@@ -41,18 +42,20 @@ class CtypesKernelAdapter(BaseKernelAdapter):
     param_dtypes: list[torch.dtype] | None = None  # Cache for parameter dtypes
     param_shapes: list[list] | None = None  # Cache for parameter shapes

     def __init__(
         self,
         params: list[TensorType],
         result_idx: list[int],
         target: str,
         func_or_mod: tir.PrimFunc | tvm.IRModule,
         host_mod: tvm.IRModule | None = None,
         device_mod: tvm.IRModule | None = None,
         host_kernel_source: str | None = None,
         device_kernel_source: str | None = None,
         verbose: bool = False,
         pass_configs: dict[str, Any] | None = None,
-        compile_flags: list[str] | None = None):
+        compile_flags: list[str] | None = None,
+    ):
         """Initialize the adapter with the given TIR function or module.

         Args:

@@ -109,17 +112,19 @@ class CtypesKernelAdapter(BaseKernelAdapter):
         self._post_init()

     @classmethod
     def from_database(
         cls,
         params: list[TensorType],
         result_idx: list[int],
         target: str,
         func_or_mod: tir.PrimFunc | tvm.IRModule,
         host_kernel_source: str,
         device_kernel_source: str,
         kernel_lib_path: str,
         verbose: bool = False,
         pass_configs: dict[str, Any] | None = None,
-        compile_flags: list[str] | None = None):
+        compile_flags: list[str] | None = None,
+    ):
         adapter = cls.__new__(cls)
         adapter.params = params
         adapter.result_idx = adapter._legalize_result_idx(result_idx)

@@ -175,15 +180,13 @@ class CtypesKernelAdapter(BaseKernelAdapter):
             if param in buffer_map:
                 buffer = buffer_map[param]
                 for j, shape in enumerate(buffer.shape):
-                    if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params)):
+                    if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params):
                         dynamic_symbolic_map[shape] = (0, i, j)
         for i, param in enumerate(params):
             if param in buffer_map:
                 buffer = buffer_map[param]
                 for j, stride in enumerate(buffer.strides):
-                    if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params)):
+                    if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params):
                         dynamic_symbolic_map[stride] = (1, i, j)
         return dynamic_symbolic_map

@@ -192,9 +195,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
         Converts PyTorch tensor pointers to C void pointers for ctypes interface.
         """
         ctypes_args = [ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args]
         ctypes_args.append(ctypes.c_void_p(stream))
         self.lib.call(*ctypes_args)

@@ -288,7 +289,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
     @property
     def is_dynamic(self):
         """Indicates whether the kernel handles dynamic shapes."""
-        return (self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0)
+        return self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0

     def get_kernel_source(self, kernel_only: bool = False):
         """Returns the source code of the compiled kernel."""
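The dynamic-symbol maps built above associate each symbolic tir.Var with a (kind, tensor_index, dim_index) tuple, where kind 0 refers to a tensor's shape and kind 1 to its strides (the tvm_ffi adapter later in this commit additionally uses kind 2 for scalar inputs). A standalone sketch of how such a map resolves concrete values at call time; the variable names and tensor sizes are illustrative only:

    import torch

    # Illustrative map: symbol name -> (kind, tensor_index, dim_index).
    # kind 0 -> read from tensor.shape, kind 1 -> read from tensor.stride().
    dynamic_symbolic_map = {"m": (0, 0, 0), "n": (0, 0, 1), "stride_m": (1, 0, 0)}
    tensors = [torch.empty(8, 16)]

    resolved = {}
    for name, (kind, t_idx, d_idx) in dynamic_symbolic_map.items():
        source = tensors[t_idx].shape if kind == 0 else tensors[t_idx].stride()
        resolved[name] = source[d_idx]

    print(resolved)  # {'m': 8, 'n': 16, 'stride_m': 16}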
tilelang/jit/adapter/cython/adapter.py  (view file @ 29051439)

 """The profiler and convert to torch utils"""
+from __future__ import annotations
 import ctypes
 import logging

@@ -70,17 +71,19 @@ class CythonKernelAdapter(BaseKernelAdapter):
     # Pass configs for the compiler
     pass_configs: dict[str, Any] | None = None

     def __init__(
         self,
         params: list[KernelParam],
         result_idx: list[int],
         target: str | Target,
         func_or_mod: tir.PrimFunc | tvm.IRModule,
         host_mod: tvm.IRModule | None = None,
         device_mod: tvm.IRModule | None = None,
         device_kernel_source: str | None = None,
         verbose: bool = False,
         pass_configs: dict[str, Any] | None = None,
-        compile_flags: list[str] | None = None):
+        compile_flags: list[str] | None = None,
+    ):
         """Initialize the adapter with the given TIR function or module.

         Args:

@@ -130,7 +133,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
         self.lib.get_last_error.restype = ctypes.c_char_p
         result = self.lib.init()
         if result != 0:
-            error_msg = self.lib.get_last_error().decode('utf-8')
+            error_msg = self.lib.get_last_error().decode("utf-8")
             error_msg += f"\n{self.lib_code}"
             raise RuntimeError(f"Initialization failed: {error_msg}")

@@ -145,17 +148,19 @@ class CythonKernelAdapter(BaseKernelAdapter):
         self._post_init()

     @classmethod
     def from_database(
         cls,
         params: list[TensorType],
         result_idx: list[int],
         target: str,
         func_or_mod: tir.PrimFunc | tvm.IRModule,
         host_kernel_source: str,
         device_kernel_source: str,
         kernel_lib_path: str,
         verbose: bool = False,
         pass_configs: dict[str, Any] | None = None,
-        compile_flags: list[str] | None = None):
+        compile_flags: list[str] | None = None,
+    ):
         adapter = cls.__new__(cls)
         adapter.params = params
         adapter.result_idx = adapter._legalize_result_idx(result_idx)

@@ -190,11 +195,10 @@ class CythonKernelAdapter(BaseKernelAdapter):
         adapter.lib.get_last_error.restype = ctypes.c_char_p
         result = adapter.lib.init()
         if result != 0:
-            error_msg = adapter.lib.get_last_error().decode('utf-8')
+            error_msg = adapter.lib.get_last_error().decode("utf-8")
             raise RuntimeError(f"Initialization failed: {error_msg}")

         adapter.cython_wrapper = CythonKernelWrapper(adapter.result_idx, adapter.params, adapter.lib)
         adapter.cython_wrapper.set_dynamic_symbolic_map(adapter.dynamic_symbolic_map)
         adapter.cython_wrapper.set_buffer_dtype_map(adapter.buffer_dtype_map)
         adapter.cython_wrapper.set_static_shape_map(adapter.static_shape_map)

@@ -221,15 +225,13 @@ class CythonKernelAdapter(BaseKernelAdapter):
             if param in buffer_map:
                 buffer = buffer_map[param]
                 for j, shape in enumerate(buffer.shape):
-                    if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params)):
+                    if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params):
                         dynamic_symbolic_map[shape] = (0, i, j)
         for i, param in enumerate(params):
             if param in buffer_map:
                 buffer = buffer_map[param]
                 for j, stride in enumerate(buffer.strides):
-                    if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params)):
+                    if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params):
                         dynamic_symbolic_map[stride] = (1, i, j)
         return dynamic_symbolic_map

@@ -259,14 +261,13 @@ class CythonKernelAdapter(BaseKernelAdapter):
         params = func.params
         ptr_map = {}
         for i, param in enumerate(params):
-            if param.dtype == 'handle':
+            if param.dtype == "handle":
                 ptr_map[i] = param.name
         return ptr_map

-    def _process_static_buffer_infos(self) -> \
-            tuple[dict[tir.Var, tuple[int, list[tuple[int, int]]]], dict[tir.Var, tuple[int, list[tuple[int, int]]]], list[tuple[tir.Var]]]:
+    def _process_static_buffer_infos(
+        self,
+    ) -> tuple[dict[tir.Var, tuple[int, list[tuple[int, int]]]], dict[tir.Var, tuple[int, list[tuple[int, int]]]], list[tuple[tir.Var]]]:
         """Extract information about static shapes from the TIR function.

         Maps buffer variables to their corresponding static shapes.

@@ -332,9 +333,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
         Converts PyTorch tensor pointers to C void pointers for ctypes interface.
         """
         ctypes_args = [ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args]
         ctypes_args.append(ctypes.c_void_p(stream))
         self.lib.call(*ctypes_args)

@@ -349,9 +348,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
             skip_tensor_validation: Whether to skip tensor attributes validation which
                 includes shape, dtype, device, etc.
         """
             return self.cython_wrapper.forward([*args], stream=stream, skip_tensor_validation=skip_tensor_validation)

         return lambda_forward
tilelang/jit/adapter/libgen.py  (view file @ 29051439)

@@ -55,6 +55,7 @@ class LibraryGenerator:
         verbose = self.verbose
         if is_cuda_target(target):
             from tilelang.env import CUTLASS_INCLUDE_DIR

             src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)  # noqa: SIM115
             target_arch = get_target_arch(get_target_compute_version(target))
             libpath = src.name.replace(".cu", ".so")

@@ -65,15 +66,12 @@ class LibraryGenerator:
                     "TL_ENABLE_FAST_MATH",
                     "0.1.7",
                 )
                 enable_fast_math = not self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH, True)
             else:
                 enable_fast_math = self.pass_configs.get(PassConfigKey.TL_ENABLE_FAST_MATH, False)
             ptxas_usage_level = self.pass_configs.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL, None)
             verbose_ptxas_output = self.pass_configs.get(PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT, False)
             command = [
                 get_nvcc_compiler(),

@@ -102,6 +100,7 @@ class LibraryGenerator:
         elif is_hip_target(target):
             from tilelang.env import COMPOSABLE_KERNEL_INCLUDE_DIR

             src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)  # noqa: SIM115
             libpath = src.name.replace(".cpp", ".so")
             rocm_path = find_rocm_path()

@@ -119,6 +118,7 @@ class LibraryGenerator:
             ]
         elif is_cpu_target(target):
             from tilelang.contrib.cc import get_cplus_compiler

             src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)  # noqa: SIM115
             libpath = src.name.replace(".cpp", ".so")

@@ -134,9 +134,7 @@ class LibraryGenerator:
             ]
         if self.compile_flags:
             command += [item for flag in self.compile_flags for item in flag.split() if item not in command]

         command += ["-o", libpath]

@@ -151,8 +149,7 @@ class LibraryGenerator:
             raise RuntimeError(f"Compile kernel failed because of {e}") from e
         if ret.returncode != 0:
-            raise RuntimeError(f"Compilation Failed! {command}" f"\n{self.lib_code}")
+            raise RuntimeError(f"Compilation Failed! {command}\n{self.lib_code}")

         self.srcpath = src.name
         self.libpath = libpath
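Both LibraryGenerator above and NVRTCLibraryGenerator below extend the compiler invocation with user flags via the same one-line comprehension. A standalone sketch with made-up values shows what it does: split each flag string on whitespace and append only the pieces not already present in the base command.

    command = ["nvcc", "-O3"]
    compile_flags = ["-O3 --use_fast_math", "-lineinfo"]

    # Same expression as in the diff, with `self.` removed for the standalone example.
    command += [item for flag in compile_flags for item in flag.split() if item not in command]

    print(command)  # ['nvcc', '-O3', '--use_fast_math', '-lineinfo']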
tilelang/jit/adapter/nvrtc/__init__.py  (view file @ 29051439)

@@ -5,22 +5,22 @@ This module provides runtime compilation support using NVIDIA's NVRTC API.
 import logging

-__all__ = ['NVRTCKernelAdapter', 'TLNVRTCSourceWrapper', 'NVRTCLibraryGenerator', 'is_nvrtc_available', 'check_nvrtc_available']
+__all__ = ["NVRTCKernelAdapter", "TLNVRTCSourceWrapper", "NVRTCLibraryGenerator", "is_nvrtc_available", "check_nvrtc_available"]

 logger = logging.getLogger(__name__)

 # Check if cuda-python is available
 is_nvrtc_available = False
 NVRTC_UNAVAILABLE_MESSAGE = (
     "cuda-python is not available, NVRTC backend cannot be used. "
     "Please install cuda-python via `pip install cuda-python` "
     "if you want to use the NVRTC backend."
 )

 try:
     import cuda.bindings.driver as cuda  # noqa: F401
     import cuda.bindings.nvrtc as nvrtc  # noqa: F401

     is_nvrtc_available = True
 except ImportError as e:
     logger.debug(f"cuda-python import failed: {e}")
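A short sketch of how a caller might consume the availability flag exported above (NVRTC_UNAVAILABLE_MESSAGE is module-level but not listed in __all__); the "cython" fallback here is an assumption for illustration, not something this hunk prescribes:

    from tilelang.jit.adapter.nvrtc import NVRTC_UNAVAILABLE_MESSAGE, is_nvrtc_available

    # Prefer the NVRTC backend when cuda-python is importable, otherwise fall back.
    backend = "nvrtc" if is_nvrtc_available else "cython"
    if not is_nvrtc_available:
        print(NVRTC_UNAVAILABLE_MESSAGE)
    print(f"Selected execution backend: {backend}")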
tilelang/jit/adapter/nvrtc/adapter.py  (view file @ 29051439)

@@ -27,18 +27,19 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
     pymodule = None
     kernels = {}

     def __init__(
         self,
         params: list[KernelParam],
         result_idx: list[int],
         target: str | Target,
         func_or_mod: tir.PrimFunc | tvm.IRModule,
         host_mod: tvm.IRModule | None = None,
         device_mod: tvm.IRModule | None = None,
         device_kernel_source: str | None = None,
         verbose: bool = False,
         pass_configs: dict[str, Any] | None = None,
-        compile_flags: list[str] | None = None):
+        compile_flags: list[str] | None = None,
+    ):
         check_nvrtc_available()
         self.params = params

@@ -92,17 +93,19 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
         self._post_init()

     @classmethod
     def from_database(
         cls,
         params: list[KernelParam],
         result_idx: list[int],
         target: str,
         func_or_mod: tir.PrimFunc | tvm.IRModule,
         host_kernel_source: str,
         device_kernel_source: str,
         kernel_lib_path: str,
         verbose: bool = False,
         pass_configs: dict[str, Any] | None = None,
-        compile_flags: list[str] | None = None):
+        compile_flags: list[str] | None = None,
+    ):
         adapter = cls.__new__(cls)
         adapter.params = params
         adapter.result_idx = adapter._legalize_result_idx(result_idx)

@@ -183,8 +186,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
         return self.host_func

     def _forward_from_prebuild_lib(self, *args, stream: int | None = None):
-        """Low-level function to call the compiled CUDA kernel.
-        """
+        """Low-level function to call the compiled CUDA kernel."""
         return self.pymodule.call(self.kernels, *args, stream=stream)

     def _wrap_forward_from_prebuild_lib(self, *ins: list[torch.Tensor], stream: int | None = None):
tilelang/jit/adapter/nvrtc/libgen.py  (view file @ 29051439)

@@ -13,6 +13,7 @@ Key responsibilities:
 - Load compiled cubin and extract kernel handles
 - Manage library lifecycle (load/unload)
 """
+from __future__ import annotations
 import importlib
 import logging

@@ -56,6 +57,7 @@ class NVRTCLibraryGenerator(LibraryGenerator):
         culib: CUDA library handle (CUlibrary)
         pymodule: Imported Python module containing call() function
     """

     host_func: str | None = None
     culib: cuda.CUlibrary | None = None
     pymodule: ModuleType | None = None

@@ -131,10 +133,10 @@ class NVRTCLibraryGenerator(LibraryGenerator):
         ctx = cuda.cuCtxGetCurrent()[1]
         if cuda.cuCtxGetApiVersion(ctx)[0] != cuda.CUresult.CUDA_SUCCESS:
             import torch

             torch.cuda.synchronize()
         result, self.culib = cuda.cuLibraryLoadFromFile(bytes(lib_path, "utf-8"), [], [], 0, [], [], 0)
         if result != cuda.CUresult.CUDA_SUCCESS:
             raise RuntimeError(f"Failed to load library: {lib_path}, error: {result}")

@@ -164,7 +166,8 @@ class NVRTCLibraryGenerator(LibraryGenerator):
         target = self.target
         verbose = self.verbose
         if is_cuda_target(target):
-            from tilelang.env import (CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH)
+            from tilelang.env import CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH

             src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)
             libpath = src.name.replace(".cu", ".cubin")

@@ -195,13 +198,9 @@ class NVRTCLibraryGenerator(LibraryGenerator):
                 f"-D__CUDACC_VER_MAJOR__={__CUDACC_VER_MAJOR__}",
             ]
             if self.compile_flags:
                 options += [item for flag in self.compile_flags for item in flag.split() if item not in options]

             cubin_bytes = compile_cuda(self.lib_code, target_format="cubin", options=options, verbose=verbose)
             with open(libpath, "wb") as f:
                 f.write(cubin_bytes)

@@ -212,8 +211,7 @@ class NVRTCLibraryGenerator(LibraryGenerator):
             self.libpath = libpath
             self.pypath = src.name.replace(".cu", ".py")
             if self.host_func is None:
                 raise RuntimeError("Host function is not set, please call update_host_func() first.")
             with open(self.pypath, "w") as f:
                 f.write(self.host_func)
         else:
tilelang/jit/adapter/nvrtc/wrapper.py  (view file @ 29051439)

@@ -12,6 +12,7 @@ Key design:
 - Dict-based deduplication ensures TMA descriptors created only once
 - Generates pure Python using cuda.bindings.driver for zero C++ dependency
 """
+from __future__ import annotations
 from typing import Any, ClassVar

@@ -21,8 +22,7 @@ from tvm.tir.stmt_functor import post_order_visit
 from tilelang import tvm as tvm
 from tilelang.jit.adapter.wrapper import TLCUDASourceWrapper
-from tilelang.jit.adapter.utils import (match_declare_kernel, pythonic_expr, parse_function_call_args, parse_tma_descriptor_args)
+from tilelang.jit.adapter.utils import match_declare_kernel, pythonic_expr, parse_function_call_args, parse_tma_descriptor_args

 PREDEF_HOST_FUNC_PY = """
 from cuda.bindings.driver import (

@@ -235,13 +235,15 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
     _generated_host_func: str | None = None

     def __init__(
         self,
         scheduled_ir_module: IRModule,
         source: str,
         target: Target,
         device_mod: IRModule | None = None,
         host_mod: IRModule | None = None,
-        pass_configs: dict[str, Any] | None = None):
+        pass_configs: dict[str, Any] | None = None,
+    ):
         """Initialize NVRTC wrapper with compiled IR modules.

         Args:

@@ -303,15 +305,16 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
         for param in self.prim_func.params:
             if param in self.prim_func.buffer_map:
                 buffer = self.prim_func.buffer_map[param]
                 function_args.append(
                     {
                         "name": buffer.data.name,
                         "type": "ctypes.c_void_p",
                     }
                 )
             elif isinstance(param, tvm.tir.Var):
                 function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)})
             else:
                 raise ValueError(f"Parameter {param} is not in the buffer map of the primary function.")
         # Add dynamic symbols as integer arguments
         for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set:
             if dyn_sym not in [arg["name"] for arg in function_args]:

@@ -359,9 +362,9 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
                 return (f"{name}.data_ptr()", arg_type)
             return (name, arg_type)

         call_args = parse_function_call_args(declaration, function_args, function_params, desc_name_map, desc_name_var_map, transform_nvrtc_arg)

         for arg_name, arg_type in call_args:
             if arg_type == "ctypes.c_void_p":

@@ -369,26 +372,28 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
                     break
             # Store kernel info for second pass
-            kernel_info_list.append({
-                'function_name': function_name,
-                'block_info': block_info,
-                'grid_info': grid_info,
-                'dynamic_smem_buf': dynamic_smem_buf,
-                'call_args': call_args,
-                'device_index': device_index,
-            })
+            kernel_info_list.append(
+                {
+                    "function_name": function_name,
+                    "block_info": block_info,
+                    "grid_info": grid_info,
+                    "dynamic_smem_buf": dynamic_smem_buf,
+                    "call_args": call_args,
+                    "device_index": device_index,
+                }
+            )

         # Generate TMA descriptor initialization code once for all kernels
         kernel_launch_code += self.generate_tma_descriptor_args(desc_name_map, desc_name_var_map)

         # Second pass: generate kernel launch code for each kernel
         for kernel_info in kernel_info_list:
-            function_name = kernel_info['function_name']
-            block_info = kernel_info['block_info']
-            grid_info = kernel_info['grid_info']
-            dynamic_smem_buf = kernel_info['dynamic_smem_buf']
-            call_args = kernel_info['call_args']
-            device_index = kernel_info['device_index']
+            function_name = kernel_info["function_name"]
+            block_info = kernel_info["block_info"]
+            grid_info = kernel_info["grid_info"]
+            dynamic_smem_buf = kernel_info["dynamic_smem_buf"]
+            call_args = kernel_info["call_args"]
+            device_index = kernel_info["device_index"]

             arg_names = ", ".join([arg[0] for arg in call_args])
             arg_types = ", ".join([arg[1] for arg in call_args])

@@ -399,23 +404,26 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
                 kernel_launch_code += init_l2_persistent_map

             # Generate kernel launch code
             kernel_launch_code += KERNEL_LAUNCH_FUNC_PY.format(
                 function_name,
                 self._pythonic_expr(grid_info[0]),
                 self._pythonic_expr(grid_info[1]),
                 self._pythonic_expr(grid_info[2]),
                 self._pythonic_expr(block_info[0]),
                 self._pythonic_expr(block_info[1]),
                 self._pythonic_expr(block_info[2]),
                 smem_str,
                 arg_names,
                 arg_types,
-                device_index)
+                device_index,
+            )

         # Reset L2 persistent map after all kernel execution
         if has_l2_persistent_map:
             kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE_PY

         # Wrap the kernel dispatch logic in an external C function
         host_func = PREDEF_HOST_FUNC_PY.format(repr(list(function_informations.keys())), def_args, kernel_launch_code)
         return host_func

     def generate_l2_persistent_map(self, function_name: str) -> str:

@@ -434,23 +442,21 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
         if function_name not in self.l2_persistent_map:
             return ""
         init_l2_persistent_map = ""
         for buffer_name, (hit_ratio, size_in_bytes) in self.l2_persistent_map[function_name].items():
             # Get persisting_l2_cache_max_size
             from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size

             persisting_l2_cache_max_size = get_persisting_l2_cache_max_size()
             try:
                 num_bytes = min(size_in_bytes, persisting_l2_cache_max_size)
             except TypeError:
                 # as size_in_bytes may be a symbolic expression
                 num_bytes = persisting_l2_cache_max_size
             init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC_PY.format(buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes))
         return init_l2_persistent_map

     def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], desc_name_var_map: dict[str, tvm.tir.Var]) -> str:
         """Generate Python code to initialize TMA descriptors.

         TMA (Tensor Memory Accelerator) descriptors are opaque CUDA objects

@@ -470,28 +476,43 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
             return tma_descriptor_init

         # Parse TMA descriptor arguments using the common utility
         parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map, desc_name_var_map, self._pythonic_expr)

         # Generate Python code from parsed parameters
         for params in parsed_params:
             if not params.is_img2col:
                 tma_descriptor_init += TMA_DESC_INIT_FUNC_PY.format(
                     params.handle_name,
                     params.dtype,
                     params.tensor_rank,
                     params.global_address,
                     ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_dim)),
                     ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_stride)),
                     ", ".join(map(lambda x: f"cuuint32_t({x})", params.box_dim)),
                     ", ".join(map(lambda x: f"cuuint32_t({x})", params.element_strides)),
-                    params.interleave, params.swizzle, params.l2_promotion, params.oob_fill)
+                    params.interleave,
+                    params.swizzle,
+                    params.l2_promotion,
+                    params.oob_fill,
+                )
             else:
                 tma_descriptor_init += TMA_IM2COL_DESC_INIT_FUNC_PY.format(
                     params.handle_name,
                     params.dtype,
                     params.tensor_rank,
                     params.global_address,
                     ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_dim)),
                     ", ".join(map(lambda x: f"cuuint64_t({x})", params.global_stride)),
                     ", ".join(map(lambda x: f"cuuint32_t({x})", params.element_strides)),
                     ", ".join(params.lower_corner),
                     ", ".join(params.upper_corner),
-                    params.smem_box_channel, params.smem_box_pixel,
-                    params.interleave, params.swizzle, params.l2_promotion, params.oob_fill)
+                    params.smem_box_channel,
+                    params.smem_box_pixel,
+                    params.interleave,
+                    params.swizzle,
+                    params.l2_promotion,
+                    params.oob_fill,
+                )

         return tma_descriptor_init

@@ -527,17 +548,14 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
         def visitor(node, fn=function_name, param_cnt=kernel_params_cnt):
             nonlocal function_params
             if isinstance(node, tvm.tir.Call):
                 if not (hasattr(node, "op") and node.op == tvm.ir.Op.get("tir.tvm_call_packed")):
                     return
                 args = node.args
                 if not args or args[0] != fn:
                     return
                 if len(args) < 1 + param_cnt:
                     raise AssertionError("tvm_call_packed should have at least 1 argument and match device function parameters")
                 function_params = args[1 : 1 + param_cnt]

         post_order_visit(self.host_func.body, visitor)
         assert function_params is not None, "function_params should not be None"
tilelang/jit/adapter/torch/__init__.py  (view file @ 29051439)

 from .metal import MetalKernelAdapter

-__all__ = ['MetalKernelAdapter']
+__all__ = ["MetalKernelAdapter"]
tilelang/jit/adapter/torch/metal.py  (view file @ 29051439)

@@ -12,7 +12,6 @@ from tilelang.engine.param import KernelParam
 class MetalKernelAdapter(BaseKernelAdapter):

     def __init__(
         self,
         params: list[KernelParam],

@@ -28,10 +27,10 @@ class MetalKernelAdapter(BaseKernelAdapter):
     ):
         self.kernel_global_source = kernel_global_source
         if isinstance(func_or_mod, tir.PrimFunc):
-            func_name = func_or_mod.attrs['global_symbol']
+            func_name = func_or_mod.attrs["global_symbol"]
         else:
             func_name = func_or_mod.__name__
-        self.kernel_name = func_name + '_kernel'
+        self.kernel_name = func_name + "_kernel"
         self.verbose = verbose
         self.block_info = [1, 1, 1]

@@ -39,7 +38,7 @@ class MetalKernelAdapter(BaseKernelAdapter):
         for var, func in device_mod.functions.items():
             assert var.name_hint == self.kernel_name
-            thread_extent = func.attrs['thread_extent']
+            thread_extent = func.attrs["thread_extent"]
             for tag, extent in thread_extent.items():
                 if "threadIdx" in tag:
                     self.block_info["xyz".index(tag[-1])] = extent

@@ -47,7 +46,7 @@ class MetalKernelAdapter(BaseKernelAdapter):
                     self.grid_info["xyz".index(tag[-1])] = extent
             break
         else:
-            raise AssertionError(f'no kernel with name {func_name}')
+            raise AssertionError(f"no kernel with name {func_name}")
         # print(self.block_info, self.grid_info)
         super().__init__(func_or_mod, result_idx=result_idx, params=params)

@@ -55,15 +54,12 @@ class MetalKernelAdapter(BaseKernelAdapter):
     _kernel = None

     def _convert_torch_func(self) -> Callable:
         if self._kernel is None:
             _kernel = getattr(torch.mps.compile_shader(self.kernel_global_source), self.kernel_name)
         _threads = [x * y for (x, y) in zip(self.block_info, self.grid_info)]

         @wraps(_kernel)
         def launcher(*args: torch.Tensor):
             return _kernel(*args, threads=_threads,
tilelang/jit/adapter/tvm_ffi.py  (view file @ 29051439)

@@ -5,6 +5,7 @@ via light-weight callables so that, when the wrapped function is invoked,
 the execution observes the same stream context as the active Torch code.
 On non-CUDA builds, the stream/device fall back to 0/CPU semantics.
 """
+from __future__ import annotations
 from typing import Callable, Any

@@ -31,6 +32,7 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
     - The stream pointer returned is a raw CUDA stream handle compatible with
       TVM's device API; on CPU or when CUDA is unavailable, we return 0.
     """

     # Class attributes to store compiled kernel information
     target: str | Target = "cuda"
     ir_module: tvm.IRModule | None = None

@@ -51,19 +53,21 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
     dynamic_symbolic_map: dict[tir.Var, tuple[int, int, int]] | None = None
     # Stream/device functors are inherited from BaseKernelAdapter

     def __init__(
         self,
         params: list[KernelParam],
         result_idx: list[int],
         target: str | Target,
         func_or_mod: tir.PrimFunc | tvm.IRModule,
         host_mod: tvm.IRModule | None = None,
         device_mod: tvm.IRModule | None = None,
         rt_mod: tvm.runtime.Module | None = None,
         host_kernel_source: str | None = None,
         device_kernel_source: str | None = None,
         verbose: bool = False,
         pass_configs: dict[str, Any] | None = None,
-        compile_flags: list[str] | None = None):
+        compile_flags: list[str] | None = None,
+    ):
         """Initialize the adapter with the given TIR function or module.

         Args:

@@ -113,15 +117,13 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
             if param in buffer_map:
                 buffer = buffer_map[param]
                 for j, shape in enumerate(buffer.shape):
-                    if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params)):
+                    if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params):
                         dynamic_symbolic_map[shape] = (0, i, j)
         for i, param in enumerate(params):
             if param in buffer_map:
                 buffer = buffer_map[param]
                 for j, stride in enumerate(buffer.strides):
-                    if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params)):
+                    if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params):
                         dynamic_symbolic_map[stride] = (1, i, j)
         return dynamic_symbolic_map

@@ -197,8 +199,7 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
         # Validate input count strictly
         expected_inputs = len(self.params) - len(self.result_idx)
         if len(inputs) != expected_inputs:
             raise ValueError(f"Kernel expected {expected_inputs} inputs, but {len(inputs)} are provided.")

         # Resolve the device used for outputs. Prefer the first tensor input's device
         # if available, otherwise use PyTorch's current device.

@@ -217,17 +218,14 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
                 for s in param_shapes[i]:
                     if isinstance(s, tir.Var):
                         for key in dynamic_symbolic_map:
-                            if (str(s) == str(key)):
+                            if str(s) == str(key):
                                 ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[key]
                                 if ref_id == 2:
                                     shape.append(inputs[ref_tensor_idx])
                                 elif ref_id == 0:
                                     shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx])
                                 elif ref_id == 1:
                                     shape.append(tensor_list[ref_tensor_idx].stride()[ref_shape_idx])
                     else:
                         # Already converted to Python int during initialization
                         shape.append(s)

@@ -235,11 +233,11 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
             out_device = current_device_functor()
             if len(shape) == 0:
-                param_name = self.params[i].name if hasattr(self.params[i], 'name') else f'parameter_{i}'
+                param_name = self.params[i].name if hasattr(self.params[i], "name") else f"parameter_{i}"
                 raise ValueError(
                     f"Cannot create output tensor (name={param_name}) - 0-dimensional tensors are not supported. "
                     f"Expected shape: {shape}"
                 )
             tensor = torch.empty(*shape, dtype=dtype, device=out_device)
         else:
             tensor = inputs[ins_idx]

@@ -256,17 +254,19 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
         return func

     @classmethod
     def from_database(
         cls,
         params: list[TensorType],
         result_idx: list[int],
         target: str,
         func_or_mod: tir.PrimFunc | tvm.IRModule,
         host_kernel_source: str,
         device_kernel_source: str,
         kernel_lib_path: str,
         verbose: bool = False,
         pass_configs: dict[str, Any] | None = None,
-        compile_flags: list[str] | None = None):
+        compile_flags: list[str] | None = None,
+    ):
         adapter = cls.__new__(cls)
         adapter.params = params
         adapter.result_idx = adapter._legalize_result_idx(result_idx)
tilelang/jit/adapter/utils.py
View file @
29051439
...
...
@@ -70,7 +70,6 @@ def get_annotated_mod(
target_host
:
str
|
Target
|
None
=
None
,
model_type
:
Literal
[
"device"
,
"host"
,
"all"
]
=
"all"
,
)
->
IRModule
|
tuple
[
IRModule
,
IRModule
]:
# Validate model_type early
if
model_type
not
in
{
"device"
,
"host"
,
"all"
}:
raise
ValueError
(
f
"Invalid model type:
{
model_type
}
"
)
...
...
@@ -95,21 +94,15 @@ def get_annotated_mod(
# Define dispatch dictionary for different model types
dispatch
=
{
"device"
:
lambda
m
:
tir
.
transform
.
Filter
(
_is_device_call
)(
m
),
"host"
:
lambda
m
:
tir
.
transform
.
Filter
(
_is_host_call
)(
m
),
"all"
:
lambda
m
:
(
tir
.
transform
.
Filter
(
_is_device_call
)(
m
),
tir
.
transform
.
Filter
(
_is_host_call
)
(
m
)),
"device"
:
lambda
m
:
tir
.
transform
.
Filter
(
_is_device_call
)(
m
),
"host"
:
lambda
m
:
tir
.
transform
.
Filter
(
_is_host_call
)(
m
),
"all"
:
lambda
m
:
(
tir
.
transform
.
Filter
(
_is_device_call
)(
m
),
tir
.
transform
.
Filter
(
_is_host_call
)(
m
)),
}
return
dispatch
[
model_type
](
mod
)
def
pythonic_expr
(
expr
:
tvm
.
tir
.
PrimExpr
,
dtype_map
:
dict
[
str
,
str
]
|
None
=
None
,
ignore_cast
:
bool
=
False
)
->
str
:
def
pythonic_expr
(
expr
:
tvm
.
tir
.
PrimExpr
,
dtype_map
:
dict
[
str
,
str
]
|
None
=
None
,
ignore_cast
:
bool
=
False
)
->
str
:
"""
Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.
...
...
@@ -168,9 +161,23 @@ def pythonic_expr(expr: tvm.tir.PrimExpr,
            s = f"({type_str}){value_str}"
            p = PRECEDENCE.get(type(node), ATOMIC_PRECEDENCE)
-       elif isinstance(node, (tvm.tir.Mul, tvm.tir.FloorDiv, tvm.tir.Add, tvm.tir.Sub,
-                              tvm.tir.FloorMod, tvm.tir.LT, tvm.tir.LE, tvm.tir.GT, tvm.tir.GE,
-                              tvm.tir.EQ, tvm.tir.NE, tvm.tir.And, tvm.tir.Or)):
+       elif isinstance(
+           node,
+           (
+               tvm.tir.Mul,
+               tvm.tir.FloorDiv,
+               tvm.tir.Add,
+               tvm.tir.Sub,
+               tvm.tir.FloorMod,
+               tvm.tir.LT,
+               tvm.tir.LE,
+               tvm.tir.GT,
+               tvm.tir.GE,
+               tvm.tir.EQ,
+               tvm.tir.NE,
+               tvm.tir.And,
+               tvm.tir.Or,
+           ),
+       ):
            op_map = {
                tvm.tir.Mul: "*",
                tvm.tir.FloorDiv: "/",
...
@@ -222,10 +229,7 @@ def pythonic_expr(expr: tvm.tir.PrimExpr,
    return next(iter(node_to_result_map[expr]), "")


-def maybe_desc_name(name: str,
-                    matches: list[str],
-                    i: int,
-                    desc_name_map: dict[str, str] | None = None) -> bool:
+def maybe_desc_name(
+    name: str, matches: list[str], i: int, desc_name_map: dict[str, str] | None = None
+) -> bool:
    """
    Check if a parameter name corresponds to a TMA descriptor.
...
@@ -290,8 +294,7 @@ def parse_function_call_args(
        else:
            call_args.append(match)
        if desc_name_var_map is not None and function_params is not None:
-           assert len(call_args) <= len(function_params), \
-               f"Too many arguments: {len(call_args)} > {len(function_params)}"
+           assert len(call_args) <= len(function_params), f"Too many arguments: {len(call_args)} > {len(function_params)}"
            desc_name_var_map[match] = function_params[len(call_args) - 1]
    return call_args
...
@@ -300,12 +303,7 @@ def parse_function_call_args(
class TMADescriptorParams:
    """Parsed TMA descriptor parameters."""

-   def __init__(self,
-                handle_name: str,
-                dtype: str,
-                tensor_rank: int,
-                global_address: Any,
-                is_img2col: bool = False):
+   def __init__(
+       self, handle_name: str, dtype: str, tensor_rank: int, global_address: Any, is_img2col: bool = False
+   ):
        self.handle_name = handle_name
        self.dtype = dtype
        self.tensor_rank = tensor_rank
...
@@ -355,22 +353,19 @@ def parse_tma_descriptor_args(
    results = []
    for handle_name, _ in desc_name_map.items():
-       assert handle_name in desc_name_var_map, \
-           f"Handle name {handle_name} not found in desc_name_var_map"
+       assert handle_name in desc_name_var_map, f"Handle name {handle_name} not found in desc_name_var_map"
        desc_var = desc_name_var_map[handle_name]
-       assert desc_var in tma_descriptor_args, \
-           f"TMA descriptor {desc_var} not found in {tma_descriptor_args}"
+       assert desc_var in tma_descriptor_args, f"TMA descriptor {desc_var} not found in {tma_descriptor_args}"
        args = tma_descriptor_args[desc_var]

        # Skip __tvm_tensormap_create_tiled and second element (like CUDA version)
        if len(args) < 3:
-           raise ValueError(
-               f"TMA descriptor args too short: {len(args)} elements, expected at least 3")
+           raise ValueError(f"TMA descriptor args too short: {len(args)} elements, expected at least 3")
        tma_create_str, _, dtype, tensor_rank, global_address, *remaining_args = args
-       is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col")
+       is_img2col = tma_create_str.value == "__tvm_tensormap_create_im2col"

        # Convert basic fields
        dtype = pythonic_expr_func(dtype)
...
@@ -386,60 +381,45 @@ def parse_tma_descriptor_args(
            # Tiled mode
            expected_args_len = 4 * tensor_rank + 4
            if len(remaining_args) < expected_args_len:
-               raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
-                                f"expected {expected_args_len} for tensor_rank {tensor_rank}")
+               raise ValueError(
+                   f"Insufficient remaining args: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}"
+               )

            # Extract dimensions and strides
            params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]]
-           params.global_stride = [
-               pythonic_expr_func(i) for i in remaining_args[tensor_rank:2 * tensor_rank]
-           ]
-           params.box_dim = [
-               pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank:3 * tensor_rank]
-           ]
-           params.element_strides = [
-               pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank:4 * tensor_rank]
-           ]
+           params.global_stride = [pythonic_expr_func(i) for i in remaining_args[tensor_rank : 2 * tensor_rank]]
+           params.box_dim = [pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank : 3 * tensor_rank]]
+           params.element_strides = [pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank : 4 * tensor_rank]]

            # Extract remaining parameters
            try:
-               interleave, swizzle, l2_promotion, oob_fill = remaining_args[4 * tensor_rank:4 * tensor_rank + 4]
+               interleave, swizzle, l2_promotion, oob_fill = remaining_args[4 * tensor_rank : 4 * tensor_rank + 4]
                params.interleave = pythonic_expr_func(interleave)
                params.swizzle = pythonic_expr_func(swizzle)
                params.l2_promotion = pythonic_expr_func(l2_promotion)
                params.oob_fill = pythonic_expr_func(oob_fill)
            except ValueError as e:
-               raise ValueError(
-                   "Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)") from e
+               raise ValueError(
+                   "Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
+               ) from e
        else:
            # Im2col mode
            expected_args_len = 5 * tensor_rank + 2
            if len(remaining_args) < expected_args_len:
-               raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
-                                f"expected {expected_args_len} for tensor_rank {tensor_rank}")
+               raise ValueError(
+                   f"Insufficient remaining args: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}"
+               )

            # Extract dimensions and strides
            params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]]
-           params.global_stride = [
-               pythonic_expr_func(i) for i in remaining_args[tensor_rank:2 * tensor_rank]
-           ]
-           params.element_strides = [
-               pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank:3 * tensor_rank]
-           ]
-           params.lower_corner = [
-               pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank:4 * tensor_rank - 2]
-           ]
-           params.upper_corner = [
-               pythonic_expr_func(i) for i in remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4]
-           ]
+           params.global_stride = [pythonic_expr_func(i) for i in remaining_args[tensor_rank : 2 * tensor_rank]]
+           params.element_strides = [pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank : 3 * tensor_rank]]
+           params.lower_corner = [pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank : 4 * tensor_rank - 2]]
+           params.upper_corner = [pythonic_expr_func(i) for i in remaining_args[4 * tensor_rank - 2 : 5 * tensor_rank - 4]]

            # Extract remaining parameters
            try:
-               smem_box_pixel, smem_box_channel, interleave, swizzle, l2_promotion, oob_fill = \
-                   remaining_args[5 * tensor_rank - 4:5 * tensor_rank + 2]
+               smem_box_pixel, smem_box_channel, interleave, swizzle, l2_promotion, oob_fill = remaining_args[
+                   5 * tensor_rank - 4 : 5 * tensor_rank + 2
+               ]
                params.smem_box_pixel = pythonic_expr_func(smem_box_pixel)
                params.smem_box_channel = pythonic_expr_func(smem_box_channel)
                params.interleave = pythonic_expr_func(interleave)
...
...
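To make the slicing in parse_tma_descriptor_args concrete: for a tiled descriptor of rank r, remaining_args holds r global dims, r global strides, r box dims and r element strides, followed by interleave, swizzle, l2Promotion and oobFill, which is exactly the 4 * r + 4 length being checked; im2col mode swaps the box dims for (r - 2) lower and (r - 2) upper corner offsets plus two shared-memory box sizes, giving 5 * r + 2. A quick arithmetic check with an illustrative rank:

    tensor_rank = 2
    tiled_len = 4 * tensor_rank + 4    # 12 entries after the leading descriptor fields
    im2col_len = 5 * tensor_rank + 2   # also 12 entries for rank 2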
tilelang/jit/adapter/wrapper.py
View file @ 29051439
...
...
@@ -4,9 +4,18 @@ from tilelang import tvm as tvm
from typing import Any
from tvm import IRModule
from tvm.target import Target
-from .utils import (is_metal_target, match_declare_kernel, match_declare_kernel_cpu, is_cuda_target,
-                    is_hip_target, is_cpu_target, get_annotated_mod, pythonic_expr,
-                    parse_function_call_args, parse_tma_descriptor_args)
+from .utils import (
+    is_metal_target,
+    match_declare_kernel,
+    match_declare_kernel_cpu,
+    is_cuda_target,
+    is_hip_target,
+    is_cpu_target,
+    get_annotated_mod,
+    pythonic_expr,
+    parse_function_call_args,
+    parse_tma_descriptor_args,
+)
import re
import logging
import textwrap
...
@@ -129,7 +138,6 @@ TMA_IM2COL_DESC_INIT_FUNC = """
class BaseWrapper(ABC):
    @abstractmethod
    def wrap(self, *args, **kwargs):
        raise NotImplementedError
...
@@ -163,13 +171,15 @@ class TLCUDASourceWrapper:
    host_mod: IRModule | None = None
    pass_configs: dict[str, Any] | None = None

-   def __init__(self,
-                scheduled_ir_module: IRModule,
-                source: str,
-                target: Target,
-                device_mod: IRModule | None = None,
-                host_mod: IRModule | None = None,
-                pass_configs: dict[str, Any] | None = None):
+   def __init__(
+       self,
+       scheduled_ir_module: IRModule,
+       source: str,
+       target: Target,
+       device_mod: IRModule | None = None,
+       host_mod: IRModule | None = None,
+       pass_configs: dict[str, Any] | None = None,
+   ):
        self.mod = scheduled_ir_module
        self.target = target
        self.source = source
...
@@ -211,15 +221,16 @@ class TLCUDASourceWrapper:
        for param in self.prim_func.params:
            if param in self.prim_func.buffer_map:
                buffer = self.prim_func.buffer_map[param]
-               function_args.append({
-                   "name": buffer.data.name,
-                   "type": self._lookup_type(buffer.dtype) + "* __restrict__",
-               })
+               function_args.append(
+                   {
+                       "name": buffer.data.name,
+                       "type": self._lookup_type(buffer.dtype) + "* __restrict__",
+                   }
+               )
            elif isinstance(param, tvm.tir.Var):
                function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)})
            else:
-               raise ValueError(
-                   f"Parameter {param} is not in the buffer map of the primary function.")
+               raise ValueError(f"Parameter {param} is not in the buffer map of the primary function.")
        # Add dynamic symbols as integer arguments
        for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set:
            if dyn_sym not in [arg["name"] for arg in function_args]:
...
@@ -256,38 +267,40 @@ class TLCUDASourceWrapper:
            # Identify the start of the function body to insert arguments
            index = code.index("{", index)

-           block_str = f"dim3({self._pythonic_expr(block_info[0])}, {self._pythonic_expr(block_info[1])}, {self._pythonic_expr(block_info[2])})"
-           grid_str = f"dim3({self._pythonic_expr(grid_info[0])}, {self._pythonic_expr(grid_info[1])}, {self._pythonic_expr(grid_info[2])})"
+           block_str = (
+               f"dim3({self._pythonic_expr(block_info[0])}, {self._pythonic_expr(block_info[1])}, {self._pythonic_expr(block_info[2])})"
+           )
+           grid_str = (
+               f"dim3({self._pythonic_expr(grid_info[0])}, {self._pythonic_expr(grid_info[1])}, {self._pythonic_expr(grid_info[2])})"
+           )
            smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
            init_l2_persistent_map = self.generate_l2_persistent_map(function_name)
            kernel_launch_code += init_l2_persistent_map

            if self.use_cooperative_groups[function_name]:
-               args_list = parse_function_call_args(declaration, function_args, function_params,
-                                                    desc_name_map, desc_name_var_map)
-               assert len(function_params) == len(args_list), \
-                   f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
+               args_list = parse_function_call_args(
+                   declaration, function_args, function_params, desc_name_map, desc_name_var_map
+               )
+               assert len(function_params) == len(args_list), (
+                   f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
+               )
                args_array = [f"(void*)&{arg}" for arg in args_list]
                call_args = f"\tvoid* {function_name}_args[] = {{{', '.join(args_array)}}};\n"
                kernel_launch_code += call_args
                # Using cudaLaunchCooperativeKernel to launch the kernel
                kernel_launch_code += "\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));\n".format(
-                   function_name, grid_str, block_str, function_name + "_args", smem_str)
+                   function_name, grid_str, block_str, function_name + "_args", smem_str
+               )
            else:
-               args_list = parse_function_call_args(declaration, function_args, function_params,
-                                                    desc_name_map, desc_name_var_map)
-               assert len(function_params) == len(args_list), \
-                   f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
+               args_list = parse_function_call_args(
+                   declaration, function_args, function_params, desc_name_map, desc_name_var_map
+               )
+               assert len(function_params) == len(args_list), (
+                   f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
+               )
                call_args = ", ".join(args_list)
                kernel_launch_code += f"\t{function_name}<<<{grid_str}, {block_str}, {smem_str}, stream>>>({call_args});\n"
-           kernel_launch_code += f"\tTILELANG_CHECK_LAST_ERROR(\"{function_name}\");\n"
+           kernel_launch_code += f'\tTILELANG_CHECK_LAST_ERROR("{function_name}");\n'
            if has_l2_persistent_map:
                kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE
-           init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map,
-                                                                        desc_name_var_map)
+           init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map, desc_name_var_map)
            kernel_launch_code = init_tma_descriptor_args + kernel_launch_code

            # Wrap the kernel dispatch logic in an external C function
...
...
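For reference, the two branches above differ only in how the generated C++ launch line is assembled; a simplified sketch of the non-cooperative formatting, with placeholder names and values rather than real wrapper state:

    function_name = "main_kernel"                          # hypothetical kernel name
    grid_str, block_str, smem_str = "dim3(256, 1, 1)", "dim3(128, 1, 1)", 0
    call_args = ", ".join(["A", "B", "n"])                 # hypothetical kernel arguments
    line = f"\t{function_name}<<<{grid_str}, {block_str}, {smem_str}, stream>>>({call_args});\n"
    # -> "\tmain_kernel<<<dim3(256, 1, 1), dim3(128, 1, 1), 0, stream>>>(A, B, n);\n"

The cooperative-groups branch instead packs argument addresses into a void* array and emits a cudaLaunchCooperativeKernel call wrapped in TILELANG_CHECK.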
@@ -298,46 +311,63 @@ class TLCUDASourceWrapper:
        if function_name not in self.l2_persistent_map:
            return ""
        init_l2_persistent_map = ""
-       for buffer_name, (hit_ratio,
-                         size_in_bytes) in self.l2_persistent_map[function_name].items():
+       for buffer_name, (hit_ratio, size_in_bytes) in self.l2_persistent_map[function_name].items():
            # get persisting_l2_cache_max_size
            from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size

            persisting_l2_cache_max_size = get_persisting_l2_cache_max_size()
            try:
                num_bytes = min(size_in_bytes, persisting_l2_cache_max_size)
            except Exception:
                # as size_in_bytes maybe a symbolic expression
                num_bytes = persisting_l2_cache_max_size
-           init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC.format(
-               buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes))
+           init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC.format(
+               buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes)
+           )
        return init_l2_persistent_map

-   def generate_tma_descriptor_args(self, desc_name_map: dict[str, str],
-                                    desc_name_var_map: dict[str, tvm.tir.Var]) -> str:
+   def generate_tma_descriptor_args(
+       self, desc_name_map: dict[str, str], desc_name_var_map: dict[str, tvm.tir.Var]
+   ) -> str:
        tma_descripter_init = ""
        if self.tma_descriptor_args is None:
            return tma_descripter_init

        # Parse TMA descriptor arguments using the common utility
-       parsed_params = parse_tma_descriptor_args(self.tma_descriptor_args, desc_name_map,
-                                                 desc_name_var_map, self._pythonic_expr)
+       parsed_params = parse_tma_descriptor_args(
+           self.tma_descriptor_args, desc_name_map, desc_name_var_map, self._pythonic_expr
+       )

        # Generate C++ code from parsed parameters
        for params in parsed_params:
            if not params.is_img2col:
                tma_descripter_init += TMA_DESC_INIT_FUNC.format(
-                   params.handle_name, params.dtype, params.tensor_rank, params.global_address,
-                   ",".join(params.global_dim), ",".join(params.global_stride),
-                   ",".join(params.box_dim), ",".join(params.element_strides), params.interleave,
-                   params.swizzle, params.l2_promotion, params.oob_fill)
+                   params.handle_name,
+                   params.dtype,
+                   params.tensor_rank,
+                   params.global_address,
+                   ",".join(params.global_dim),
+                   ",".join(params.global_stride),
+                   ",".join(params.box_dim),
+                   ",".join(params.element_strides),
+                   params.interleave,
+                   params.swizzle,
+                   params.l2_promotion,
+                   params.oob_fill,
+               )
            else:
                tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC.format(
-                   params.handle_name, params.dtype, params.tensor_rank, params.global_address,
-                   ",".join(params.global_dim), ",".join(params.global_stride),
-                   ",".join(params.element_strides), ",".join(params.lower_corner),
-                   ",".join(params.upper_corner), params.smem_box_channel, params.smem_box_pixel,
-                   params.interleave, params.swizzle, params.l2_promotion, params.oob_fill)
+                   params.handle_name,
+                   params.dtype,
+                   params.tensor_rank,
+                   params.global_address,
+                   ",".join(params.global_dim),
+                   ",".join(params.global_stride),
+                   ",".join(params.element_strides),
+                   ",".join(params.lower_corner),
+                   ",".join(params.upper_corner),
+                   params.smem_box_channel,
+                   params.smem_box_pixel,
+                   params.interleave,
+                   params.swizzle,
+                   params.l2_promotion,
+                   params.oob_fill,
+               )
        return tma_descripter_init
...
@@ -347,9 +377,8 @@ class TLCUDASourceWrapper:
            device_mod, host_mod = get_annotated_mod(self.mod, self.target)
        self.device_mod = device_mod
        self.host_mod = host_mod
-       assert (len(self.device_mod.functions) >= 1), "Device module should have at least one function."
-       assert (len(self.host_mod.functions) == 1), "Only support one function in host module."
+       assert len(self.device_mod.functions) >= 1, "Device module should have at least one function."
+       assert len(self.host_mod.functions) == 1, "Only support one function in host module."

        block_info_map = {}
        grid_info_map = {}
...
@@ -438,8 +467,7 @@ class TLCUDASourceWrapper:
        for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items():
            if dynamic_smem_buf is not None:
                # Format the cudaFuncSetAttribute call for dynamic shared memory
-               call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY.format(
-                   function_name, dynamic_smem_buf)
+               call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY.format(function_name, dynamic_smem_buf)
        # Format the initialization function using the call_str
        init_funcs = PREDEF_INIT_FUNC.format(call_str)
        return init_funcs
...
...
@@ -466,17 +494,14 @@ class TLCUDASourceWrapper:
        def visitor(node, fn=function_name, param_cnt=kernel_params_cnt):
            nonlocal function_params
            if isinstance(node, tvm.tir.Call):
-               if not (hasattr(node, "op") and
-                       node.op == tvm.ir.Op.get("tir.tvm_call_packed")):
+               if not (hasattr(node, "op") and node.op == tvm.ir.Op.get("tir.tvm_call_packed")):
                    return
                args = node.args
                if not args or args[0] != fn:
                    return
                if len(args) < 1 + param_cnt:
-                   raise AssertionError(
-                       "tvm_call_packed should have at least 1 argument and match device function parameters")
-               function_params = args[1:1 + param_cnt]
+                   raise AssertionError(
+                       "tvm_call_packed should have at least 1 argument and match device function parameters"
+                   )
+               function_params = args[1 : 1 + param_cnt]

        post_order_visit(self.host_func.body, visitor)
        assert function_params is not None, "function_params should not be None"
...
@@ -564,13 +589,15 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
        "uchar": "uint8_t",
    }

-   def __init__(self,
-                scheduled_ir_module: IRModule,
-                source: str,
-                target: Target,
-                device_mod: IRModule | None = None,
-                host_mod: IRModule | None = None,
-                pass_configs: dict[str, Any] | None = None):
+   def __init__(
+       self,
+       scheduled_ir_module: IRModule,
+       source: str,
+       target: Target,
+       device_mod: IRModule | None = None,
+       host_mod: IRModule | None = None,
+       pass_configs: dict[str, Any] | None = None,
+   ):
        super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs)

    def get_init_func(self):
...
@@ -580,8 +607,7 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
        for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items():
            if dynamic_smem_buf is not None:
                # Format the cudaFuncSetAttribute call for dynamic shared memory
-               call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP.format(
-                   function_name, dynamic_smem_buf)
+               call_str += PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP.format(function_name, dynamic_smem_buf)
        # Format the initialization function using the call_str
        init_funcs = PREDEF_INIT_FUNC.format(call_str)
        return init_funcs
...
@@ -623,13 +649,15 @@ class TLCPUSourceWrapper:
    host_mod: IRModule | None = None
    pass_configs: dict[str, Any] | None = None

-   def __init__(self,
-                scheduled_ir_module: IRModule,
-                source: str,
-                target: Target,
-                device_mod: IRModule | None = None,
-                host_mod: IRModule | None = None,
-                pass_configs: dict[str, Any] | None = None):
+   def __init__(
+       self,
+       scheduled_ir_module: IRModule,
+       source: str,
+       target: Target,
+       device_mod: IRModule | None = None,
+       host_mod: IRModule | None = None,
+       pass_configs: dict[str, Any] | None = None,
+   ):
        self.mod = scheduled_ir_module
        self.target = target
        self.source = source
...
@@ -658,15 +686,16 @@ class TLCPUSourceWrapper:
        for param in self.prim_func.params:
            if param in self.prim_func.buffer_map:
                buffer = self.prim_func.buffer_map[param]
-               function_args.append({
-                   "name": buffer.name,
-                   "type": self._lookup_type(buffer.dtype) + "*",
-               })
+               function_args.append(
+                   {
+                       "name": buffer.name,
+                       "type": self._lookup_type(buffer.dtype) + "*",
+                   }
+               )
            elif isinstance(param, tvm.tir.Var):
                function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)})
            else:
-               raise ValueError(
-                   f"Parameter {param} is not in the buffer map of the primary function.")
+               raise ValueError(f"Parameter {param} is not in the buffer map of the primary function.")
        # Add dynamic symbols as integer arguments
        for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set:
            function_args.append({"name": dyn_sym, "type": self._lookup_type(dyn_sym_dtype)})
...
@@ -686,7 +715,6 @@ class TLCPUSourceWrapper:
        _call_str = """"""
        for function_name, _ in function_informations.items():
            # Find the location of the global kernel function in the code
            index = match_declare_kernel_cpu(code, function_name + "(")
...
@@ -706,8 +734,8 @@ class TLCPUSourceWrapper:
    def parse_source_information(self):
        with tvm.transform.PassContext(opt_level=3, config=self.pass_configs):
            device_mod, host_mod = get_annotated_mod(self.mod, self.target)
-       assert (len(device_mod.functions) >= 1), "Device module should have at least one function."
-       assert (len(host_mod.functions) == 1), "Only support one function in host module."
+       assert len(device_mod.functions) >= 1, "Device module should have at least one function."
+       assert len(host_mod.functions) == 1, "Only support one function in host module."

        function_names = []
        for g_var, _ in device_mod.functions.items():
...
@@ -767,14 +795,15 @@ class TLCPUSourceWrapper:
class TLMetalSourceWrapper:
-   def __init__(self,
-                scheduled_ir_module: IRModule,
-                source: str,
-                target: Target,
-                device_mod: IRModule | None = None,
-                host_mod: IRModule | None = None,
-                pass_configs: dict[str, Any] | None = None):
+   def __init__(
+       self,
+       scheduled_ir_module: IRModule,
+       source: str,
+       target: Target,
+       device_mod: IRModule | None = None,
+       host_mod: IRModule | None = None,
+       pass_configs: dict[str, Any] | None = None,
+   ):
        self.mod = scheduled_ir_module
        self.target = target
        self.source = source
...
@@ -792,6 +821,7 @@ class TLWrapper(BaseWrapper):
    """
    A wrapper class for the TileLang backend.
    """

    device_mod: IRModule | None = None
    host_mod: IRModule | None = None
    pass_configs: dict[str, Any] | None = None
...
@@ -836,12 +866,12 @@ class TLWrapper(BaseWrapper):
            target=self.target,
            device_mod=self.device_mod,
            host_mod=self.host_mod,
-           pass_configs=self.pass_configs)
+           pass_configs=self.pass_configs,
+       )
        return wrapper.lib_code


class TLPyWrapper(TLWrapper):
    def __init__(self, target: Target):
        super().__init__(target)
...
@@ -849,6 +879,7 @@ class TLPyWrapper(TLWrapper):
        # assert self.scheduled_ir_module is not None, "Please assign optimized module first."
        if is_cuda_target(self.target):
            from tilelang.jit.adapter.nvrtc import TLNVRTCSourceWrapper

            wrapper_class = TLNVRTCSourceWrapper
        else:
            raise ValueError(f"Unsupported target for NVRTC backend: {self.target}")
...
@@ -858,5 +889,6 @@ class TLPyWrapper(TLWrapper):
            target=self.target,
            device_mod=self.device_mod,
            host_mod=self.host_mod,
-           pass_configs=self.pass_configs)
+           pass_configs=self.pass_configs,
+       )
        return wrapper.host_func, wrapper.function_names
tilelang/jit/execution_backend.py
View file @ 29051439
...
...
@@ -46,6 +46,7 @@ def allowed_backends_for_target(target: Target, *, include_unavailable: bool = T
    # Drop NVRTC if not importable
    try:
        from tilelang.jit.adapter.nvrtc import is_nvrtc_available  # lazy

        if not is_nvrtc_available and "nvrtc" in allowed:
            allowed = [b for b in allowed if b != "nvrtc"]
    except Exception:
...
@@ -89,12 +90,14 @@ def resolve_execution_backend(requested: str | None, target: Target) -> str:
    if req not in allowed_all:
        raise ValueError(
            f"Invalid execution backend '{requested}' for target '{_target_kind(target)}'. "
-           f"Allowed: {_format_options(allowed_all)}. Tip: use execution_backend='auto'.")
+           f"Allowed: {_format_options(allowed_all)}. Tip: use execution_backend='auto'."
+       )

    # Promote to availability-aware set for nicer errors (e.g., nvrtc not installed)
    if req not in allowed_avail:
        raise ValueError(
            f"Execution backend '{requested}' requires extra dependencies and is not available now. "
-           f"Try one of: {_format_options(allowed_avail)}.")
+           f"Try one of: {_format_options(allowed_avail)}."
+       )
    return req
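A minimal sketch of how this resolver is typically invoked (assumes a CUDA target; "auto" behaves as the error hint above suggests and defers to the preferred available backend):

    from tvm.target import Target
    from tilelang.jit.execution_backend import resolve_execution_backend

    backend = resolve_execution_backend("auto", Target("cuda"))
    # Requesting an unavailable backend (e.g. "nvrtc" without its dependencies)
    # raises the ValueError with the "Try one of: ..." hint shown above.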
tilelang/jit/kernel.py
View file @ 29051439
from __future__ import annotations

from typing import Any, Callable, Generic, Literal, TypeVar

# Python 3.9 compatibility for ParamSpec
try:
    from typing import ParamSpec
...
@@ -14,8 +15,7 @@ import tilelang
from tilelang import tvm
from tilelang import env
from tilelang.engine.param import CompiledArtifact, KernelParam
-from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter,
-                                  TVMFFIKernelAdapter, MetalKernelAdapter)
+from tilelang.jit.adapter import BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter, TVMFFIKernelAdapter, MetalKernelAdapter
from tilelang.profiler import Profiler, TensorSupplyType
from tilelang.utils.target import determine_target
from tilelang.contrib import nvcc as tl_nvcc
...
@@ -24,8 +24,8 @@ import os
logger = logging.getLogger(__name__)

-_P = ParamSpec('_P')
-_T = TypeVar('_T')
+_P = ParamSpec("_P")
+_T = TypeVar("_T")


class JITKernel(Generic[_P, _T]):
...
@@ -41,6 +41,7 @@ class JITKernel(Generic[_P, _T]):
    torch_function : Callable
        The compiled function that can be invoked as a PyTorch-compatible function.
    """

    prim_func: PrimFunc = None
    artifact: CompiledArtifact = None
    adapter: BaseKernelAdapter = None
...
@@ -111,9 +112,7 @@ class JITKernel(Generic[_P, _T]):
        if execution_backend == "cython":
            from tilelang.contrib.cc import get_cplus_compiler

-           assert (get_cplus_compiler() is not None
-                  ), "Cython backend requires a C++ compiler, please install or use other backends."
+           assert get_cplus_compiler() is not None, "Cython backend requires a C++ compiler, please install or use other backends."

        if from_database:
            return
...
@@ -200,8 +199,7 @@ class JITKernel(Generic[_P, _T]):
        """
        return self.torch_function(*args, **kwds)

-   def _compile_and_create_adapter(self, tilelang_func: PrimFunc,
-                                   out_idx: list[int]) -> BaseKernelAdapter:
+   def _compile_and_create_adapter(self, tilelang_func: PrimFunc, out_idx: list[int]) -> BaseKernelAdapter:
        """
        Compiles the given TileLang PrimFunc using TVM and creates a kernel adapter.
...
@@ -233,7 +231,8 @@ class JITKernel(Generic[_P, _T]):
                target=target,
                target_host=target_host,
                enable_host_codegen=enable_host_codegen,
-               enable_device_compile=enable_device_compile)
+               enable_device_compile=enable_device_compile,
+           )
        self.artifact = artifact
...
@@ -241,7 +240,7 @@ class JITKernel(Generic[_P, _T]):
        if execution_backend == "tvm_ffi":
            # Use TVMFFIKernelAdapter for interoperability with PyTorch via DLPack.
            # But we need to ensure that the runtime is enabled and the runtime module is not None.
-           assert (artifact.rt_mod is not None), "tvm_ffi backend requires a runtime module."
+           assert artifact.rt_mod is not None, "tvm_ffi backend requires a runtime module."
            adapter = TVMFFIKernelAdapter(
                params=artifact.params,
                result_idx=out_idx,
...
@@ -283,6 +282,7 @@ class JITKernel(Generic[_P, _T]):
            )
        elif execution_backend == "nvrtc":
            from tilelang.jit.adapter import NVRTCKernelAdapter

            adapter = NVRTCKernelAdapter(
                params=artifact.params,
                result_idx=out_idx,
...
@@ -315,16 +315,18 @@ class JITKernel(Generic[_P, _T]):
        return adapter

-   def _create_adapter_from_database(self,
-                                     params: list[KernelParam],
-                                     result_idx: list[int] | int,
-                                     target: str | Target,
-                                     func_or_mod: PrimFunc | tvm.runtime.Module,
-                                     host_kernel_source: str,
-                                     device_kernel_source: str,
-                                     kernel_lib_path: str,
-                                     pass_configs: dict[str, Any] | None = None,
-                                     compile_flags: list[str] | None = None) -> BaseKernelAdapter:
+   def _create_adapter_from_database(
+       self,
+       params: list[KernelParam],
+       result_idx: list[int] | int,
+       target: str | Target,
+       func_or_mod: PrimFunc | tvm.runtime.Module,
+       host_kernel_source: str,
+       device_kernel_source: str,
+       kernel_lib_path: str,
+       pass_configs: dict[str, Any] | None = None,
+       compile_flags: list[str] | None = None,
+   ) -> BaseKernelAdapter:
        target = self.target
        execution_backend = self.execution_backend
...
@@ -366,6 +368,7 @@ class JITKernel(Generic[_P, _T]):
            )
        elif execution_backend == "nvrtc":
            from tilelang.jit.adapter import NVRTCKernelAdapter

            adapter = NVRTCKernelAdapter.from_database(
                params=params,
                result_idx=result_idx,
...
@@ -402,8 +405,7 @@ class JITKernel(Generic[_P, _T]):
        """
        return cls(func=tilelang_func, **kwargs)

-   def get_profiler(self,
-                    tensor_supply_type: TensorSupplyType = TensorSupplyType.Auto) -> Profiler:
+   def get_profiler(self, tensor_supply_type: TensorSupplyType = TensorSupplyType.Auto) -> Profiler:
        """
        Creates a profiler to benchmark the compiled runtime module.
...
@@ -417,8 +419,7 @@ class JITKernel(Generic[_P, _T]):
        Profiler
            A Profiler instance for benchmarking the runtime module.
        """
-       return Profiler(self.params, self.out_idx,
-                       tensor_supply_type).with_default_adapter(self.adapter)
+       return Profiler(self.params, self.out_idx, tensor_supply_type).with_default_adapter(self.adapter)

    def get_kernel_source(self, kernel_only: bool = True) -> str:
        """
...
...
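For context, a minimal usage sketch of the profiler entry point above (assumes `kernel` is an already-compiled JITKernel; the benchmarking method name is taken from tilelang's Profiler and may differ by version):

    from tilelang.profiler import TensorSupplyType

    profiler = kernel.get_profiler(TensorSupplyType.Auto)
    latency_ms = profiler.do_bench()        # assumed benchmarking entry point
    print(f"avg latency: {latency_ms:.3f} ms")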
@@ -507,21 +508,19 @@ class JITKernel(Generic[_P, _T]):
            dir_path = os.path.dirname(kernel_path)
            if dir_path:
                os.makedirs(dir_path, exist_ok=True)
-           with open(kernel_path, 'w') as f:
+           with open(kernel_path, "w") as f:
                f.write(self.get_kernel_source())
            if host_path is not None:
                dir_path = os.path.dirname(host_path)
                if dir_path:
                    os.makedirs(dir_path, exist_ok=True)
-               with open(host_path, 'w') as f:
+               with open(host_path, "w") as f:
                    f.write(self.get_host_source())
        except Exception as e:
            logger.error(f"Failed to export sources: {e}")

    # Backward compatibility alias (deprecated)
-   def print_source_code(self,
-                         which: Literal["kernel", "host", "both"] = "kernel",
-                         file: str | None = None) -> None:
+   def print_source_code(
+       self, which: Literal["kernel", "host", "both"] = "kernel", file: str | None = None
+   ) -> None:
        """
        Deprecated: use show_source() or export_sources() instead.
...
...
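A short sketch of the non-deprecated replacements referenced here (paths are hypothetical; host_path is optional):

    kernel.show_source(which="kernel")                 # print device source to stdout
    kernel.export_sources(
        kernel_path="/tmp/kernel.cu",                  # hypothetical output path
        host_path="/tmp/host.cpp",
    )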
@@ -541,16 +540,14 @@ class JITKernel(Generic[_P, _T]):
        >>> # Old API (still works but deprecated)
        >>> jit_kernel.print_source_code(file="/tmp/kernel.cu")
        """
-       logger.warning(
-           "print_source_code is deprecated; use show_source() or export_sources() instead.")
+       logger.warning("print_source_code is deprecated; use show_source() or export_sources() instead.")
        if file is not None:
            # Historical behavior wrote only kernel source when file provided
            self.export_sources(kernel_path=file)
        else:
            self.show_source(which=which)

-   def update_tuner_result(self, latency: float, config: dict[str, Any],
-                           ref_latency: float) -> JITKernel:
+   def update_tuner_result(self, latency: float, config: dict[str, Any], ref_latency: float) -> JITKernel:
        """
        Updates the tuning results for this kernel.
...
@@ -651,8 +648,7 @@ class JITKernel(Generic[_P, _T]):
            verbose = self.verbose
        # Ensure target is set so nvcc picks correct arch via Target.current()
        with self.target:
-           return tl_nvcc.get_ptx_from_source(
-               code, compile_flags=self.compile_flags, verbose=verbose)
+           return tl_nvcc.get_ptx_from_source(code, compile_flags=self.compile_flags, verbose=verbose)

    def show_ptx(self) -> None:
        """
...
@@ -714,8 +710,7 @@ class JITKernel(Generic[_P, _T]):
        if verbose is None:
            verbose = self.verbose
        with self.target:
-           return tl_nvcc.get_sass_from_source(
-               code, compile_flags=self.compile_flags, verbose=verbose)
+           return tl_nvcc.get_sass_from_source(code, compile_flags=self.compile_flags, verbose=verbose)

    def show_sass(self) -> None:
        """
...
...
tilelang/language/__init__.py
View file @ 29051439
"""The language interface for tl programs."""
from
__future__
import
annotations
# from .parser import *
...
...
@@ -102,7 +103,10 @@ from .utils import index_to_coordinates  # noqa: F401
from .symbolics import dynamic, symbolic  # noqa: F401
from .annotations import (  # noqa: F401
-   use_swizzle, annotate_layout, annotate_safe_value, annotate_l2_hit_ratio,
+   use_swizzle,
+   annotate_layout,
+   annotate_safe_value,
+   annotate_l2_hit_ratio,
)
...
...
tilelang/language/allocate.py
View file @ 29051439
...
...
@@ -13,8 +13,10 @@ Available allocation functions:
Each function takes shape and dtype parameters and returns a TVM buffer object
with the appropriate memory scope.
"""

from __future__ import annotations

from typing import TypeVar, overload, Literal, Callable

# Python 3.9 compatibility for advanced typing features (PEP 646)
try:
    from typing import TypeVarTuple, Unpack  # type: ignore[attr-defined]
...
@@ -30,13 +32,11 @@ from .v2.dtypes import dtype as tl_dtype
from .v2.builder import OutTensor
from .v2.annot import Tensor, SharedBuffer, LocalBuffer, FragmentBuffer

-_Shapes = TypeVarTuple('_Shapes')
-_DType = TypeVar('_DType')
+_Shapes = TypeVarTuple("_Shapes")
+_DType = TypeVar("_DType")


-def alloc_shared(shape: tuple[Unpack[_Shapes]],
-                 dtype: _DType,
-                 scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]:
+def alloc_shared(
+    shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="shared.dyn"
+) -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]:
    """Allocate a shared memory buffer for inter-thread communication.

    Args:
...
@@ -54,9 +54,7 @@ def alloc_shared(shape: tuple[Unpack[_Shapes]],
    return T.alloc_buffer(shape, dtype, scope=scope)


-def alloc_local(shape: tuple[Unpack[_Shapes]],
-                dtype: _DType,
-                scope="local") -> LocalBuffer[Callable[[Unpack[_Shapes]]], _DType]:
+def alloc_local(
+    shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="local"
+) -> LocalBuffer[Callable[[Unpack[_Shapes]]], _DType]:
    """Allocate a local memory buffer for thread-private storage.

    Args:
...
@@ -70,9 +68,9 @@ def alloc_local(shape: tuple[Unpack[_Shapes]],
    return T.alloc_buffer(shape, dtype, scope=scope)


-def alloc_fragment(shape: tuple[Unpack[_Shapes]],
-                   dtype: _DType,
-                   scope="local.fragment") -> FragmentBuffer[Callable[[Unpack[_Shapes]]], _DType]:
+def alloc_fragment(
+    shape: tuple[Unpack[_Shapes]], dtype: _DType, scope="local.fragment"
+) -> FragmentBuffer[Callable[[Unpack[_Shapes]]], _DType]:
    """Allocate a fragment memory buffer for specialized operations.

    Args:
...
...
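For orientation, a minimal sketch of how these allocators appear in practice (shapes and dtypes are illustrative; the calls are only valid inside a TileLang kernel / prim_func context):

    import tilelang.language as T

    A_shared = T.alloc_shared((128, 64), "float16")    # block-shared staging buffer
    acc = T.alloc_fragment((128, 64), "float32")       # per-thread register fragment
    idx = T.alloc_local((1,), "int32")                 # thread-private scratch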
@@ -87,16 +85,11 @@ def alloc_fragment(shape: tuple[Unpack[_Shapes]],
@overload
-def alloc_var(dtype: str,
-              init: PrimExpr | int | float,
-              scope: str = 'local.var') -> Buffer:
-    ...
+def alloc_var(dtype: str, init: PrimExpr | int | float, scope: str = "local.var") -> Buffer: ...


@overload
-def alloc_var(dtype: str,
-              scope: str = 'local.var',
-              *,
-              init: PrimExpr | int | float | None = None) -> Buffer:
-    ...
+def alloc_var(dtype: str, scope: str = "local.var", *, init: PrimExpr | int | float | None = None) -> Buffer: ...


def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
...
@@ -142,8 +135,7 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
            raise TypeError("Scope must be provided as a string in alloc_var.")
        parsed_scope = parsed_scope_arg
    elif len(args) > 2:
-       raise TypeError(
-           f"alloc_var expected at most 3 positional arguments but got {len(args) + 1}.")
+       raise TypeError(f"alloc_var expected at most 3 positional arguments but got {len(args) + 1}.")
    if not isinstance(parsed_scope, str):
        raise TypeError("Scope must be a string in alloc_var.")
...
@@ -274,13 +266,10 @@ def alloc_tcgen05_instr_desc(dtype: str = "uint32"):
@overload
-def empty(shape: tuple[Unpack[_Shapes]],
-          dtype: str = 'float32') -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
-    ...
+def empty(shape: tuple[Unpack[_Shapes]], dtype: str = "float32") -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ...


-def empty(*shape: Unpack[_Shapes],
-          dtype: str = 'float32') -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
+def empty(*shape: Unpack[_Shapes], dtype: str = "float32") -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
    if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
        return OutTensor(shape[0], dtype)
    elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str):
...
@@ -288,4 +277,4 @@ def empty(*shape: Unpack[_Shapes],
    elif all([isinstance(x, (int, PrimExpr)) for x in shape]):
        return OutTensor(shape, dtype)
    else:
-       raise RuntimeError(f'Invalid shape {shape}')
+       raise RuntimeError(f"Invalid shape {shape}")
tilelang/language/annotations.py
View file @ 29051439
"""Annotation helpers exposed on the TileLang language surface."""
from
typing
import
Callable
from
tilelang.layout
import
Layout
...
...
tilelang/language/ast/__init__.py
View file @ 29051439
...
...
@@ -17,6 +17,7 @@
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
"""Package tvm.script.ir_builder.tir"""

from .ir import *  # noqa: F401
from .ir import boolean as bool  # noqa: F401
from .ir import buffer as Buffer  # noqa: F401
...
...