OpenDAS / tilelang / Commits / 29051439

Commit 29051439 (Unverified)
Authored Dec 12, 2025 by Lei Wang; committed by GitHub, Dec 12, 2025
[Lint] Phaseout Yapf format and embrace ruff format (#1417)
Parent: e84b24bc
Changes: 426
Showing 20 changed files with 562 additions and 560 deletions (+562, -560)
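Nearly everything in the diff below is mechanical reformatting rather than a behavior change: single-quoted strings become double-quoted, long yapf-style argument lists are exploded one argument per line with a trailing comma, and parenthesized line continuations are collapsed where a line now fits. A minimal sketch of the before/after style follows; the function and its names are hypothetical and not taken from the repository:

# Hypothetical illustration, not from tilelang.
# The yapf-era style this commit moves away from looked roughly like:
#
#     def make_kernel(name: str,
#                     block: int = 128,
#                     verbose: bool = False) -> str:
#         return f'kernel_{name}_{block}'
#
# ruff format rewrites it along these lines:
def make_kernel(
    name: str,
    block: int = 128,
    verbose: bool = False,
) -> str:
    """Build a kernel identifier string."""
    if verbose:
        print(f"building {name} with block={block}")
    return f"kernel_{name}_{block}"


if __name__ == "__main__":
    print(make_kernel("gemm", block=256))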
tilelang/jit/__init__.py                  +76   -84
tilelang/jit/adapter/base.py              +3    -7
tilelang/jit/adapter/ctypes/adapter.py    +32   -31
tilelang/jit/adapter/cython/adapter.py    +38   -41
tilelang/jit/adapter/libgen.py            +8    -11
tilelang/jit/adapter/nvrtc/__init__.py    +7    -7
tilelang/jit/adapter/nvrtc/adapter.py     +27   -25
tilelang/jit/adapter/nvrtc/libgen.py      +9    -11
tilelang/jit/adapter/nvrtc/wrapper.py     +82   -64
tilelang/jit/adapter/torch/__init__.py    +1    -1
tilelang/jit/adapter/torch/metal.py       +4    -8
tilelang/jit/adapter/tvm_ffi.py           +40   -40
tilelang/jit/adapter/utils.py             +46   -66
tilelang/jit/adapter/wrapper.py           +129  -97
tilelang/jit/execution_backend.py         +5    -2
tilelang/jit/kernel.py                    +33   -38
tilelang/language/__init__.py             +5    -1
tilelang/language/allocate.py             +15   -26
tilelang/language/annotations.py          +1    -0
tilelang/language/ast/__init__.py         +1    -0
Too many changes to show. To preserve performance only 426 of 426+ files are displayed.
tilelang/jit/__init__.py
...
@@ -3,6 +3,7 @@ This module provides an auto-tuning infrastructure for TileLang (tl) programs.
 It includes functionality to JIT-compile TileLang programs into a runnable
 kernel adapter using TVM.
 """
 from __future__ import annotations
 from dataclasses import dataclass
...
@@ -39,17 +40,16 @@ from tqdm.auto import tqdm
 logger = getLogger(__name__)
-_P = ParamSpec('_P')
+_P = ParamSpec("_P")
-_KP = ParamSpec('_KP')
+_KP = ParamSpec("_KP")
-_T = TypeVar('_T')
+_T = TypeVar("_T")
-_Ret = TypeVar('_Ret')
+_Ret = TypeVar("_Ret")
 def compile(
     func: PrimFunc[_KP, _T] = None,
     out_idx: list[int] | int | None = None,
     execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc",
                                "torch"] = "auto",
     target: str | Target = "auto",
     target_host: str | Target | None = None,
     verbose: bool = False,
...
@@ -83,11 +83,9 @@ def compile(
     if isinstance(compile_flags, str):
         compile_flags = [compile_flags]
-    if hasattr(func, 'out_idx_override'):
+    if hasattr(func, "out_idx_override"):
         if func.out_idx_override is not None and out_idx is not None:
             raise ValueError(
                 "Out index conflict: out_idx is specified and prim_func have returned `T.empty` tensors")
         out_idx = func.out_idx_override or out_idx
     # This path is not a performance critical path, so we can afford to convert the target.
...
@@ -96,6 +94,7 @@ def compile(
     # Resolve execution backend (handles aliases, auto, validation per target)
     requested_backend = execution_backend
     from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target
     execution_backend = resolve_execution_backend(requested_backend, target)
     if verbose:
         allowed_now = allowed_backends_for_target(target, include_unavailable=False)
...
@@ -119,17 +118,18 @@ def compile(
 )
-def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
-                out_idx: list[int] | int | None = None,
-                execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc",
-                                           "torch"] = "auto",
-                target: str | Target = "auto",
-                target_host: str | Target | None = None,
-                verbose: bool = False,
-                pass_configs: dict[str, Any] | None = None,
-                compile_flags: list[str] | str | None = None,
-                num_workers: int = None,
-                ignore_error: bool = False) -> list[JITKernel[_KP, _T]]:
+def par_compile(
+    funcs: Iterable[PrimFunc[_KP, _T]],
+    out_idx: list[int] | int | None = None,
+    execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc",
+                               "torch"] = "auto",
+    target: str | Target = "auto",
+    target_host: str | Target | None = None,
+    verbose: bool = False,
+    pass_configs: dict[str, Any] | None = None,
+    compile_flags: list[str] | str | None = None,
+    num_workers: int = None,
+    ignore_error: bool = False,
+) -> list[JITKernel[_KP, _T]]:
     """
     Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels.
     Parameters
...
@@ -151,7 +151,7 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
         Additional keyword arguments to pass to the Compiler PassContext.
         Refer to `tilelang.transform.PassConfigKey` for supported options.
     """
-    with concurrent.futures.ThreadPoolExecutor(num_workers, 'tl-par-comp') as executor:
+    with concurrent.futures.ThreadPoolExecutor(num_workers, "tl-par-comp") as executor:
         futures = []
         future_map = {}
         for i, func in enumerate(funcs):
...
@@ -170,9 +170,9 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
             futures.append(future)
         results = [... for _ in futures]
         for future in tqdm(
             concurrent.futures.as_completed(futures),
             total=len(futures),
             desc="Parallel Compiling",
         ):
             idx = future_map[future]
             if ignore_error:
...
@@ -189,7 +189,7 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
 @dataclass
 class JITImpl(Generic[_P, _KP, _T, _Ret]):
-    '''
+    """
     Detailed Just-In-Time wrapper for TileLang programs.
     This dataclass encapsulates the configuration and runtime helpers used by the
...
@@ -256,7 +256,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
     PrimFunc and the resulting set is compiled in parallel via the
     module-level `par_compile` helper. Returns a list of JITKernel objects
     in the same order as the provided configs.
-    '''
+    """
     out_idx: list[int] | int | None
     execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"]
...
@@ -302,10 +302,9 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
         assert isinstance(tir, PrimFunc), f"target function must be a PrimFunc but got {type(tir)}"
         return tir
-    def par_compile(self,
-                    configs: Iterable[dict[str, Any] | tuple[str, Any]],
-                    num_workers: int = None,
-                    ignore_error: bool = False) -> list[JITKernel[_KP, _T]]:
+    def par_compile(
+        self,
+        configs: Iterable[dict[str, Any] | tuple[str, Any]],
+        num_workers: int = None,
+        ignore_error: bool = False,
+    ) -> list[JITKernel[_KP, _T]]:
         """
         Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels.
         Parameters
...
@@ -328,7 +327,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
         """
         configs = list(configs)
         funcs = []
-        for cfg in tqdm(configs, desc='Elaborating'):
+        for cfg in tqdm(configs, desc="Elaborating"):
             if isinstance(cfg, tuple):
                 funcs.append(self.get_tir(*cfg))
             elif isinstance(cfg, dict):
...
@@ -345,7 +344,8 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
             pass_configs=self.pass_configs,
             compile_flags=self.compile_flags,
             num_workers=num_workers,
-            ignore_error=ignore_error)
+            ignore_error=ignore_error,
+        )
     def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret:
         func = self.get_tir(*args, **kwargs)
...
@@ -362,25 +362,25 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
         if self.debug_root_path:
             if isinstance(self.func, PrimFunc):
-                func_name = self.func.attrs['global_symbol']
+                func_name = self.func.attrs["global_symbol"]
             else:
-                func_name = getattr(self.func, '__name__', 'jit_kernel')
+                func_name = getattr(self.func, "__name__", "jit_kernel")
-            kernel_file = f'tilelang_jit_kernel_{func_name}.c'
+            kernel_file = f"tilelang_jit_kernel_{func_name}.c"
-            program_file = f'tilelang_jit_program_{func_name}.py'
+            program_file = f"tilelang_jit_program_{func_name}.py"
             makedirs(self.debug_root_path, exist_ok=True)
-            with open(path.join(self.debug_root_path, kernel_file), 'w') as f:
+            with open(path.join(self.debug_root_path, kernel_file), "w") as f:
                 print(kernel_result.get_kernel_source(), file=f)
-            with open(path.join(self.debug_root_path, program_file), 'w') as f:
+            with open(path.join(self.debug_root_path, program_file), "w") as f:
                 print(func.script(), file=f)
         return kernel_result
     def parse_cache_key(self, *args: _P.args, **kwargs: _P.kwargs):
         if isinstance(self.func, PrimFuncCreater):
-            tune_params = kwargs.pop('__tune_params', {})
+            tune_params = kwargs.pop("__tune_params", {})
             return self.func.func_annot.parse_key(*args, **kwargs, **tune_params)
         else:
-            tune_params = kwargs.pop('__tune_params', {})
+            tune_params = kwargs.pop("__tune_params", {})
             key_args_tuple = args
             key_kwargs_tuple = tuple(sorted(kwargs.items()))
             tuned_key_kwargs_tuple = tuple(sorted(tune_params.items()))
...
@@ -389,34 +389,31 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
     def convert_kernel_args(self, *args: _P.args, **kwargs: _P.kwargs):
         if isinstance(self.func, PrimFuncCreater):
-            tune_params = kwargs.pop('__tune_params', {})
+            tune_params = kwargs.pop("__tune_params", {})
             return self.func.func_annot.convert_to_kernel_args(*args, **kwargs, **tune_params)
         else:
             raise NotImplementedError(
                 "convert_arg_to_kernel_args is only implemented for PrimFuncCreater.")
     def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret:
         # Separate out the tuning parameters from the user's kwargs
         # Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache
-        return_compile_arguments = kwargs.pop('__return_compile_arguments', False)
+        return_compile_arguments = kwargs.pop("__return_compile_arguments", False)
         if return_compile_arguments:
             logger.warning(
                 "`__return_compile_arguments` is deprecated and will be removed in future versions.")
             compile_args = {
-                'out_idx': self.out_idx,
+                "out_idx": self.out_idx,
-                'execution_backend': self.execution_backend,
+                "execution_backend": self.execution_backend,
-                'target': self.target,
+                "target": self.target,
-                'target_host': self.target_host,
+                "target_host": self.target_host,
-                'verbose': self.verbose,
+                "verbose": self.verbose,
-                'pass_configs': self.pass_configs,
+                "pass_configs": self.pass_configs,
-                'compile_flags': self.compile_flags,
+                "compile_flags": self.compile_flags,
             }
             return compile_args
         key = self.parse_cache_key(*args, **kwargs)
-        tune_params = kwargs.pop('__tune_params', {})
+        tune_params = kwargs.pop("__tune_params", {})
         kernel = self._kernel_cache.get(key, None)
         if kernel is None:
...
@@ -434,8 +431,7 @@ ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvr
 @overload
 def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]:
     ...
 @overload
...
@@ -448,22 +444,22 @@ def jit(
     verbose: bool = False,
     pass_configs: dict[str, Any] | None = None,
     debug_root_path: str | None = None,
-    compile_flags: list[str] | str | None = None
+    compile_flags: list[str] | str | None = None,
 ) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]]:
     ...
 def jit(  # This is the new public interface
     func: Callable[_P, _T] | PrimFunc | None = None,
     *,  # Indicates subsequent arguments are keyword-only
     out_idx: Any = None,
     target: str | Target = "auto",
     target_host: str | Target = None,
     execution_backend: ExecutionBackend = "auto",
     verbose: bool = False,
     pass_configs: dict[str, Any] | None = None,
     debug_root_path: str | None = None,
-    compile_flags: list[str] | str | None = None):
+    compile_flags: list[str] | str | None = None,
+):
     """
     Just-In-Time (JIT) compiler decorator for TileLang functions.
...
@@ -516,7 +512,8 @@ def jit(  # This is the new public interface
             compile_flags=compile_flags,
             func_source=inspect.getsource(orig_func),
             signature=inspect.signature(orig_func),
-            lazy_jit=False)
+            lazy_jit=False,
+        )
     if func is not None:
         return decorator(func)
...
@@ -525,8 +522,7 @@ def jit(  # This is the new public interface
 @overload
 def lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]:
     ...
 @overload
...
@@ -539,9 +535,8 @@ def lazy_jit(
     verbose: bool = False,
     pass_configs: dict[str, Any] | None = None,
     debug_root_path: str | None = None,
-    compile_flags: list[str] | str | None = None
+    compile_flags: list[str] | str | None = None,
 ) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]:
     ...
 def lazy_jit(
...
@@ -555,7 +550,6 @@ def lazy_jit(
     debug_root_path: str | None = None,
     compile_flags: list[str] | str | None = None,
 ):
     if isinstance(compile_flags, str):
         compile_flags = [compile_flags]
...
@@ -567,7 +561,8 @@ def lazy_jit(
         verbose=verbose,
         pass_configs=pass_configs,
         debug_root_path=debug_root_path,
-        compile_flags=compile_flags)
+        compile_flags=compile_flags,
+    )
     def decorator(func: Callable[_P, _T]):
         pf: PrimFunc[_P, _T] | PrimFuncCreater[_P, _T] = prim_func(func, generator=True)
...
@@ -576,10 +571,7 @@ def lazy_jit(
         # return compile(pf, **compile_args)
         # else:
         return JITImpl(
             func=pf,
-            **compile_args,
-            func_source=inspect.getsource(pf.orig_func),
-            signature=inspect.signature(pf.orig_func),
-            lazy_jit=True)
+            **compile_args,
+            func_source=inspect.getsource(pf.orig_func),
+            signature=inspect.signature(pf.orig_func),
+            lazy_jit=True,
+        )
     return decorator(func) if func is not None else decorator
tilelang/jit/adapter/base.py
 """The profiler and convert to torch utils"""
 from __future__ import annotations
 from abc import ABC, abstractmethod
...
@@ -8,7 +9,6 @@ import torch
 class BaseKernelAdapter(ABC):
     func: Callable | None = None
     def __init__(self, mod, params: list[KernelParam], result_idx: list[int]) -> None:
...
@@ -24,18 +24,14 @@ class BaseKernelAdapter(ABC):
             result_idx = []
         elif isinstance(result_idx, int):
             if result_idx > len(params) or result_idx < -len(params):
                 raise ValueError(
                     f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}")
             if result_idx < 0:
                 result_idx = len(params) + result_idx
             result_idx = [result_idx]
         elif isinstance(result_idx, list):
             for i, idx in enumerate(result_idx):
                 if idx >= len(params) or idx < -len(params):
                     raise ValueError(
                         f"result_idx should be an integer between {-len(params) - 1} and {len(params) - 1}")
                 if idx < 0:
                     result_idx[i] = len(params) + idx
         else:
...
tilelang/jit/adapter/ctypes/adapter.py
 """The profiler and convert to torch utils"""
 from __future__ import annotations
 import torch
 from ..base import BaseKernelAdapter
...
@@ -41,18 +42,20 @@ class CtypesKernelAdapter(BaseKernelAdapter):
     param_dtypes: list[torch.dtype] | None = None  # Cache for parameter dtypes
     param_shapes: list[list] | None = None  # Cache for parameter shapes
-    def __init__(self,
-                 params: list[TensorType],
-                 result_idx: list[int],
-                 target: str,
-                 func_or_mod: tir.PrimFunc | tvm.IRModule,
-                 host_mod: tvm.IRModule | None = None,
-                 device_mod: tvm.IRModule | None = None,
-                 host_kernel_source: str | None = None,
-                 device_kernel_source: str | None = None,
-                 verbose: bool = False,
-                 pass_configs: dict[str, Any] | None = None,
-                 compile_flags: list[str] | None = None):
+    def __init__(
+        self,
+        params: list[TensorType],
+        result_idx: list[int],
+        target: str,
+        func_or_mod: tir.PrimFunc | tvm.IRModule,
+        host_mod: tvm.IRModule | None = None,
+        device_mod: tvm.IRModule | None = None,
+        host_kernel_source: str | None = None,
+        device_kernel_source: str | None = None,
+        verbose: bool = False,
+        pass_configs: dict[str, Any] | None = None,
+        compile_flags: list[str] | None = None,
+    ):
         """Initialize the adapter with the given TIR function or module.
         Args:
...
@@ -109,17 +112,19 @@ class CtypesKernelAdapter(BaseKernelAdapter):
         self._post_init()
     @classmethod
-    def from_database(cls,
-                      params: list[TensorType],
-                      result_idx: list[int],
-                      target: str,
-                      func_or_mod: tir.PrimFunc | tvm.IRModule,
-                      host_kernel_source: str,
-                      device_kernel_source: str,
-                      kernel_lib_path: str,
-                      verbose: bool = False,
-                      pass_configs: dict[str, Any] | None = None,
-                      compile_flags: list[str] | None = None):
+    def from_database(
+        cls,
+        params: list[TensorType],
+        result_idx: list[int],
+        target: str,
+        func_or_mod: tir.PrimFunc | tvm.IRModule,
+        host_kernel_source: str,
+        device_kernel_source: str,
+        kernel_lib_path: str,
+        verbose: bool = False,
+        pass_configs: dict[str, Any] | None = None,
+        compile_flags: list[str] | None = None,
+    ):
         adapter = cls.__new__(cls)
         adapter.params = params
         adapter.result_idx = adapter._legalize_result_idx(result_idx)
...
@@ -175,15 +180,13 @@ class CtypesKernelAdapter(BaseKernelAdapter):
             if param in buffer_map:
                 buffer = buffer_map[param]
                 for j, shape in enumerate(buffer.shape):
-                    if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and
-                            (shape not in params)):
+                    if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params):
                         dynamic_symbolic_map[shape] = (0, i, j)
         for i, param in enumerate(params):
             if param in buffer_map:
                 buffer = buffer_map[param]
                 for j, stride in enumerate(buffer.strides):
-                    if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and
-                            (stride not in params)):
+                    if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params):
                         dynamic_symbolic_map[stride] = (1, i, j)
         return dynamic_symbolic_map
...
@@ -192,9 +195,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
         Converts PyTorch tensor pointers to C void pointers for ctypes interface.
         """
         ctypes_args = [
             ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args]
         ctypes_args.append(ctypes.c_void_p(stream))
         self.lib.call(*ctypes_args)
...
@@ -288,7 +289,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
     @property
     def is_dynamic(self):
         """Indicates whether the kernel handles dynamic shapes."""
-        return (self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0)
+        return self.dynamic_symbolic_map is not None and len(self.dynamic_symbolic_map) > 0
     def get_kernel_source(self, kernel_only: bool = False):
         """Returns the source code of the compiled kernel."""
...
tilelang/jit/adapter/cython/adapter.py
 """The profiler and convert to torch utils"""
 from __future__ import annotations
 import ctypes
 import logging
...
@@ -70,17 +71,19 @@ class CythonKernelAdapter(BaseKernelAdapter):
     # Pass configs for the compiler
     pass_configs: dict[str, Any] | None = None
-    def __init__(self,
-                 params: list[KernelParam],
-                 result_idx: list[int],
-                 target: str | Target,
-                 func_or_mod: tir.PrimFunc | tvm.IRModule,
-                 host_mod: tvm.IRModule | None = None,
-                 device_mod: tvm.IRModule | None = None,
-                 device_kernel_source: str | None = None,
-                 verbose: bool = False,
-                 pass_configs: dict[str, Any] | None = None,
-                 compile_flags: list[str] | None = None):
+    def __init__(
+        self,
+        params: list[KernelParam],
+        result_idx: list[int],
+        target: str | Target,
+        func_or_mod: tir.PrimFunc | tvm.IRModule,
+        host_mod: tvm.IRModule | None = None,
+        device_mod: tvm.IRModule | None = None,
+        device_kernel_source: str | None = None,
+        verbose: bool = False,
+        pass_configs: dict[str, Any] | None = None,
+        compile_flags: list[str] | None = None,
+    ):
         """Initialize the adapter with the given TIR function or module.
         Args:
...
@@ -130,7 +133,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
         self.lib.get_last_error.restype = ctypes.c_char_p
         result = self.lib.init()
         if result != 0:
-            error_msg = self.lib.get_last_error().decode('utf-8')
+            error_msg = self.lib.get_last_error().decode("utf-8")
             error_msg += f"\n{self.lib_code}"
             raise RuntimeError(f"Initialization failed: {error_msg}")
...
@@ -145,17 +148,19 @@ class CythonKernelAdapter(BaseKernelAdapter):
         self._post_init()
     @classmethod
-    def from_database(cls,
-                      params: list[TensorType],
-                      result_idx: list[int],
-                      target: str,
-                      func_or_mod: tir.PrimFunc | tvm.IRModule,
-                      host_kernel_source: str,
-                      device_kernel_source: str,
-                      kernel_lib_path: str,
-                      verbose: bool = False,
-                      pass_configs: dict[str, Any] | None = None,
-                      compile_flags: list[str] | None = None):
+    def from_database(
+        cls,
+        params: list[TensorType],
+        result_idx: list[int],
+        target: str,
+        func_or_mod: tir.PrimFunc | tvm.IRModule,
+        host_kernel_source: str,
+        device_kernel_source: str,
+        kernel_lib_path: str,
+        verbose: bool = False,
+        pass_configs: dict[str, Any] | None = None,
+        compile_flags: list[str] | None = None,
+    ):
         adapter = cls.__new__(cls)
         adapter.params = params
         adapter.result_idx = adapter._legalize_result_idx(result_idx)
...
@@ -190,11 +195,10 @@ class CythonKernelAdapter(BaseKernelAdapter):
         adapter.lib.get_last_error.restype = ctypes.c_char_p
         result = adapter.lib.init()
         if result != 0:
-            error_msg = adapter.lib.get_last_error().decode('utf-8')
+            error_msg = adapter.lib.get_last_error().decode("utf-8")
             raise RuntimeError(f"Initialization failed: {error_msg}")
         adapter.cython_wrapper = CythonKernelWrapper(adapter.result_idx, adapter.params,
                                                      adapter.lib)
         adapter.cython_wrapper.set_dynamic_symbolic_map(adapter.dynamic_symbolic_map)
         adapter.cython_wrapper.set_buffer_dtype_map(adapter.buffer_dtype_map)
         adapter.cython_wrapper.set_static_shape_map(adapter.static_shape_map)
...
@@ -221,15 +225,13 @@ class CythonKernelAdapter(BaseKernelAdapter):
             if param in buffer_map:
                 buffer = buffer_map[param]
                 for j, shape in enumerate(buffer.shape):
-                    if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and
-                            (shape not in params)):
+                    if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params):
                         dynamic_symbolic_map[shape] = (0, i, j)
         for i, param in enumerate(params):
             if param in buffer_map:
                 buffer = buffer_map[param]
                 for j, stride in enumerate(buffer.strides):
-                    if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and
-                            (stride not in params)):
+                    if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params):
                         dynamic_symbolic_map[stride] = (1, i, j)
         return dynamic_symbolic_map
...
@@ -259,14 +261,13 @@ class CythonKernelAdapter(BaseKernelAdapter):
         params = func.params
         ptr_map = {}
         for i, param in enumerate(params):
-            if param.dtype == 'handle':
+            if param.dtype == "handle":
                 ptr_map[i] = param.name
         return ptr_map
-    def _process_static_buffer_infos(self) -> \
-            tuple[dict[tir.Var, tuple[int, list[tuple[int, int]]]],
-                  dict[tir.Var, tuple[int, list[tuple[int, int]]]],
-                  list[tuple[tir.Var]]]:
+    def _process_static_buffer_infos(
+        self,
+    ) -> tuple[dict[tir.Var, tuple[int, list[tuple[int, int]]]],
+               dict[tir.Var, tuple[int, list[tuple[int, int]]]],
+               list[tuple[tir.Var]]]:
         """Extract information about static shapes from the TIR function.
         Maps buffer variables to their corresponding static shapes.
...
@@ -332,9 +333,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
         Converts PyTorch tensor pointers to C void pointers for ctypes interface.
         """
         ctypes_args = [
             ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args]
         ctypes_args.append(ctypes.c_void_p(stream))
         self.lib.call(*ctypes_args)
...
@@ -349,9 +348,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
             skip_tensor_validation: Whether to skip tensor attributes validation which
                 includes shape, dtype, device, etc.
             """
             return self.cython_wrapper.forward([*args],
                 stream=stream, skip_tensor_validation=skip_tensor_validation)
         return lambda_forward
...
tilelang/jit/adapter/libgen.py
...
@@ -55,6 +55,7 @@ class LibraryGenerator:
         verbose = self.verbose
         if is_cuda_target(target):
             from tilelang.env import CUTLASS_INCLUDE_DIR
             src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)  # noqa: SIM115
             target_arch = get_target_arch(get_target_compute_version(target))
             libpath = src.name.replace(".cu", ".so")
...
@@ -65,15 +66,12 @@ class LibraryGenerator:
                 "TL_ENABLE_FAST_MATH",
                 "0.1.7",
             )
-            enable_fast_math = not self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH,
-                                                         True)
+            enable_fast_math = not self.pass_configs.get(PassConfigKey.TL_DISABLE_FAST_MATH, True)
         else:
             enable_fast_math = self.pass_configs.get(PassConfigKey.TL_ENABLE_FAST_MATH, False)
-        ptxas_usage_level = self.pass_configs.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL,
-                                                  None)
+        ptxas_usage_level = self.pass_configs.get(PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL, None)
         verbose_ptxas_output = self.pass_configs.get(PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT, False)
         command = [
             get_nvcc_compiler(),
...
@@ -102,6 +100,7 @@ class LibraryGenerator:
         elif is_hip_target(target):
             from tilelang.env import COMPOSABLE_KERNEL_INCLUDE_DIR
             src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)  # noqa: SIM115
             libpath = src.name.replace(".cpp", ".so")
             rocm_path = find_rocm_path()
...
@@ -119,6 +118,7 @@ class LibraryGenerator:
             ]
         elif is_cpu_target(target):
             from tilelang.contrib.cc import get_cplus_compiler
             src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)  # noqa: SIM115
             libpath = src.name.replace(".cpp", ".so")
...
@@ -134,9 +134,7 @@ class LibraryGenerator:
         ]
         if self.compile_flags:
             command += [
                 item for flag in self.compile_flags for item in flag.split() if item not in command]
         command += ["-o", libpath]
...
@@ -151,8 +149,7 @@ class LibraryGenerator:
             raise RuntimeError(f"Compile kernel failed because of {e}") from e
         if ret.returncode != 0:
-            raise RuntimeError(f"Compilation Failed! {command}"
-                               f"\n{self.lib_code}")
+            raise RuntimeError(f"Compilation Failed! {command}\n{self.lib_code}")
         self.srcpath = src.name
         self.libpath = libpath
...
tilelang/jit/adapter/nvrtc/__init__.py
...
@@ -5,22 +5,22 @@ This module provides runtime compilation support using NVIDIA's NVRTC API.
 import logging
 __all__ = [
-    'NVRTCKernelAdapter', 'TLNVRTCSourceWrapper', 'NVRTCLibraryGenerator', 'is_nvrtc_available',
-    'check_nvrtc_available'
-]
+    "NVRTCKernelAdapter", "TLNVRTCSourceWrapper", "NVRTCLibraryGenerator", "is_nvrtc_available",
+    "check_nvrtc_available"
+]
 logger = logging.getLogger(__name__)
 # Check if cuda-python is available
 is_nvrtc_available = False
-NVRTC_UNAVAILABLE_MESSAGE = ("cuda-python is not available, NVRTC backend cannot be used. "
-                             "Please install cuda-python via `pip install cuda-python` "
-                             "if you want to use the NVRTC backend.")
+NVRTC_UNAVAILABLE_MESSAGE = (
+    "cuda-python is not available, NVRTC backend cannot be used. "
+    "Please install cuda-python via `pip install cuda-python` "
+    "if you want to use the NVRTC backend."
+)
 try:
     import cuda.bindings.driver as cuda  # noqa: F401
     import cuda.bindings.nvrtc as nvrtc  # noqa: F401
     is_nvrtc_available = True
 except ImportError as e:
     logger.debug(f"cuda-python import failed: {e}")
...
tilelang/jit/adapter/nvrtc/adapter.py
...
@@ -27,18 +27,19 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
     pymodule = None
     kernels = {}
-    def __init__(self,
-                 params: list[KernelParam],
-                 result_idx: list[int],
-                 target: str | Target,
-                 func_or_mod: tir.PrimFunc | tvm.IRModule,
-                 host_mod: tvm.IRModule | None = None,
-                 device_mod: tvm.IRModule | None = None,
-                 device_kernel_source: str | None = None,
-                 verbose: bool = False,
-                 pass_configs: dict[str, Any] | None = None,
-                 compile_flags: list[str] | None = None):
+    def __init__(
+        self,
+        params: list[KernelParam],
+        result_idx: list[int],
+        target: str | Target,
+        func_or_mod: tir.PrimFunc | tvm.IRModule,
+        host_mod: tvm.IRModule | None = None,
+        device_mod: tvm.IRModule | None = None,
+        device_kernel_source: str | None = None,
+        verbose: bool = False,
+        pass_configs: dict[str, Any] | None = None,
+        compile_flags: list[str] | None = None,
+    ):
         check_nvrtc_available()
         self.params = params
...
@@ -92,17 +93,19 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
         self._post_init()
     @classmethod
-    def from_database(cls,
-                      params: list[KernelParam],
-                      result_idx: list[int],
-                      target: str,
-                      func_or_mod: tir.PrimFunc | tvm.IRModule,
-                      host_kernel_source: str,
-                      device_kernel_source: str,
-                      kernel_lib_path: str,
-                      verbose: bool = False,
-                      pass_configs: dict[str, Any] | None = None,
-                      compile_flags: list[str] | None = None):
+    def from_database(
+        cls,
+        params: list[KernelParam],
+        result_idx: list[int],
+        target: str,
+        func_or_mod: tir.PrimFunc | tvm.IRModule,
+        host_kernel_source: str,
+        device_kernel_source: str,
+        kernel_lib_path: str,
+        verbose: bool = False,
+        pass_configs: dict[str, Any] | None = None,
+        compile_flags: list[str] | None = None,
+    ):
         adapter = cls.__new__(cls)
         adapter.params = params
         adapter.result_idx = adapter._legalize_result_idx(result_idx)
...
@@ -183,8 +186,7 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
         return self.host_func
     def _forward_from_prebuild_lib(self, *args, stream: int | None = None):
-        """Low-level function to call the compiled CUDA kernel.
-        """
+        """Low-level function to call the compiled CUDA kernel."""
         return self.pymodule.call(self.kernels, *args, stream=stream)
     def _wrap_forward_from_prebuild_lib(self, *ins: list[torch.Tensor], stream: int | None = None):
...
tilelang/jit/adapter/nvrtc/libgen.py
...
@@ -13,6 +13,7 @@ Key responsibilities:
 - Load compiled cubin and extract kernel handles
 - Manage library lifecycle (load/unload)
 """
 from __future__ import annotations
 import importlib
 import logging
...
@@ -56,6 +57,7 @@ class NVRTCLibraryGenerator(LibraryGenerator):
     culib: CUDA library handle (CUlibrary)
     pymodule: Imported Python module containing call() function
     """
     host_func: str | None = None
     culib: cuda.CUlibrary | None = None
     pymodule: ModuleType | None = None
...
@@ -131,10 +133,10 @@ class NVRTCLibraryGenerator(LibraryGenerator):
         ctx = cuda.cuCtxGetCurrent()[1]
         if cuda.cuCtxGetApiVersion(ctx)[0] != cuda.CUresult.CUDA_SUCCESS:
             import torch
             torch.cuda.synchronize()
-        result, self.culib = cuda.cuLibraryLoadFromFile(
-            bytes(lib_path, "utf-8"), [], [], 0, [], [], 0)
+        result, self.culib = cuda.cuLibraryLoadFromFile(bytes(lib_path, "utf-8"), [], [], 0, [], [], 0)
         if result != cuda.CUresult.CUDA_SUCCESS:
             raise RuntimeError(f"Failed to load library: {lib_path}, error: {result}")
...
@@ -164,7 +166,8 @@ class NVRTCLibraryGenerator(LibraryGenerator):
         target = self.target
         verbose = self.verbose
         if is_cuda_target(target):
-            from tilelang.env import (CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH)
+            from tilelang.env import CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH
             src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)
             libpath = src.name.replace(".cu", ".cubin")
...
@@ -195,13 +198,9 @@ class NVRTCLibraryGenerator(LibraryGenerator):
             f"-D__CUDACC_VER_MAJOR__={__CUDACC_VER_MAJOR__}",
         ]
         if self.compile_flags:
             options += [
                 item for flag in self.compile_flags for item in flag.split() if item not in options]
-        cubin_bytes = compile_cuda(
-            self.lib_code, target_format="cubin", options=options, verbose=verbose)
+        cubin_bytes = compile_cuda(self.lib_code, target_format="cubin", options=options, verbose=verbose)
         with open(libpath, "wb") as f:
             f.write(cubin_bytes)
...
@@ -212,8 +211,7 @@ class NVRTCLibraryGenerator(LibraryGenerator):
         self.libpath = libpath
         self.pypath = src.name.replace(".cu", ".py")
         if self.host_func is None:
             raise RuntimeError(
                 "Host function is not set, please call update_host_func() first.")
         with open(self.pypath, "w") as f:
             f.write(self.host_func)
         else:
...
tilelang/jit/adapter/nvrtc/wrapper.py
...
@@ -12,6 +12,7 @@ Key design:
 - Dict-based deduplication ensures TMA descriptors created only once
 - Generates pure Python using cuda.bindings.driver for zero C++ dependency
 """
 from __future__ import annotations
 from typing import Any, ClassVar
...
@@ -21,8 +22,7 @@ from tvm.tir.stmt_functor import post_order_visit
 from tilelang import tvm as tvm
 from tilelang.jit.adapter.wrapper import TLCUDASourceWrapper
-from tilelang.jit.adapter.utils import (match_declare_kernel, pythonic_expr,
-                                        parse_function_call_args, parse_tma_descriptor_args)
+from tilelang.jit.adapter.utils import match_declare_kernel, pythonic_expr, parse_function_call_args, parse_tma_descriptor_args
 PREDEF_HOST_FUNC_PY = """
 from cuda.bindings.driver import (
...
@@ -235,13 +235,15 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
     _generated_host_func: str | None = None
-    def __init__(self,
-                 scheduled_ir_module: IRModule,
-                 source: str,
-                 target: Target,
-                 device_mod: IRModule | None = None,
-                 host_mod: IRModule | None = None,
-                 pass_configs: dict[str, Any] | None = None):
+    def __init__(
+        self,
+        scheduled_ir_module: IRModule,
+        source: str,
+        target: Target,
+        device_mod: IRModule | None = None,
+        host_mod: IRModule | None = None,
+        pass_configs: dict[str, Any] | None = None,
+    ):
         """Initialize NVRTC wrapper with compiled IR modules.
         Args:
...
@@ -303,15 +305,16 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
         for param in self.prim_func.params:
             if param in self.prim_func.buffer_map:
                 buffer = self.prim_func.buffer_map[param]
-                function_args.append({
-                    "name": buffer.data.name,
-                    "type": "ctypes.c_void_p",
-                })
+                function_args.append(
+                    {
+                        "name": buffer.data.name,
+                        "type": "ctypes.c_void_p",
+                    }
+                )
             elif isinstance(param, tvm.tir.Var):
                 function_args.append({"name": param.name, "type": self._lookup_type(param.dtype)})
             else:
                 raise ValueError(
                     f"Parameter {param} is not in the buffer map of the primary function.")
         # Add dynamic symbols as integer arguments
         for dyn_sym, dyn_sym_dtype in dynamic_symbolic_set:
             if dyn_sym not in [arg["name"] for arg in function_args]:
...
@@ -359,9 +362,9 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
                 return (f"{name}.data_ptr()", arg_type)
             return (name, arg_type)
-        call_args = parse_function_call_args(declaration, function_args, function_params,
-                                             desc_name_map, desc_name_var_map,
-                                             transform_nvrtc_arg)
+        call_args = parse_function_call_args(
+            declaration, function_args, function_params, desc_name_map, desc_name_var_map,
+            transform_nvrtc_arg
+        )
         for arg_name, arg_type in call_args:
             if arg_type == "ctypes.c_void_p":
...
@@ -369,26 +372,28 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
                     break
             # Store kernel info for second pass
-            kernel_info_list.append({
-                'function_name': function_name,
-                'block_info': block_info,
-                'grid_info': grid_info,
-                'dynamic_smem_buf': dynamic_smem_buf,
-                'call_args': call_args,
-                'device_index': device_index,
-            })
+            kernel_info_list.append(
+                {
+                    "function_name": function_name,
+                    "block_info": block_info,
+                    "grid_info": grid_info,
+                    "dynamic_smem_buf": dynamic_smem_buf,
+                    "call_args": call_args,
+                    "device_index": device_index,
+                }
+            )
         # Generate TMA descriptor initialization code once for all kernels
         kernel_launch_code += self.generate_tma_descriptor_args(desc_name_map, desc_name_var_map)
         # Second pass: generate kernel launch code for each kernel
         for kernel_info in kernel_info_list:
-            function_name = kernel_info['function_name']
+            function_name = kernel_info["function_name"]
-            block_info = kernel_info['block_info']
+            block_info = kernel_info["block_info"]
-            grid_info = kernel_info['grid_info']
+            grid_info = kernel_info["grid_info"]
-            dynamic_smem_buf = kernel_info['dynamic_smem_buf']
+            dynamic_smem_buf = kernel_info["dynamic_smem_buf"]
-            call_args = kernel_info['call_args']
+            call_args = kernel_info["call_args"]
-            device_index = kernel_info['device_index']
+            device_index = kernel_info["device_index"]
             arg_names = ", ".join([arg[0] for arg in call_args])
             arg_types = ", ".join([arg[1] for arg in call_args])
...
@@ -399,23 +404,26 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
                 kernel_launch_code += init_l2_persistent_map
             # Generate kernel launch code
-            kernel_launch_code += KERNEL_LAUNCH_FUNC_PY.format(function_name,
-                                                               self._pythonic_expr(grid_info[0]),
-                                                               self._pythonic_expr(grid_info[1]),
-                                                               self._pythonic_expr(grid_info[2]),
-                                                               self._pythonic_expr(block_info[0]),
-                                                               self._pythonic_expr(block_info[1]),
-                                                               self._pythonic_expr(block_info[2]),
-                                                               smem_str, arg_names, arg_types,
-                                                               device_index)
+            kernel_launch_code += KERNEL_LAUNCH_FUNC_PY.format(
+                function_name,
+                self._pythonic_expr(grid_info[0]),
+                self._pythonic_expr(grid_info[1]),
+                self._pythonic_expr(grid_info[2]),
+                self._pythonic_expr(block_info[0]),
+                self._pythonic_expr(block_info[1]),
+                self._pythonic_expr(block_info[2]),
+                smem_str, arg_names, arg_types,
+                device_index,
+            )
         # Reset L2 persistent map after all kernel execution
         if has_l2_persistent_map:
             kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE_PY
         # Wrap the kernel dispatch logic in an external C function
         host_func = PREDEF_HOST_FUNC_PY.format(
             repr(list(function_informations.keys())), def_args, kernel_launch_code)
         return host_func
     def generate_l2_persistent_map(self, function_name: str) -> str:
...
@@ -434,23 +442,21 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
         if function_name not in self.l2_persistent_map:
             return ""
         init_l2_persistent_map = ""
-        for buffer_name, (hit_ratio,
-                          size_in_bytes) in self.l2_persistent_map[function_name].items():
+        for buffer_name, (hit_ratio, size_in_bytes) in self.l2_persistent_map[function_name].items():
             # Get persisting_l2_cache_max_size
             from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size
             persisting_l2_cache_max_size = get_persisting_l2_cache_max_size()
             try:
                 num_bytes = min(size_in_bytes, persisting_l2_cache_max_size)
(
size_in_bytes
,
persisting_l2_cache_max_size
)
except
TypeError
:
except
TypeError
:
# as size_in_bytes may be a symbolic expression
# as size_in_bytes may be a symbolic expression
num_bytes
=
persisting_l2_cache_max_size
num_bytes
=
persisting_l2_cache_max_size
init_l2_persistent_map
+=
L2_PERSISTENT_MAP_INIT_FUNC_PY
.
format
(
init_l2_persistent_map
+=
L2_PERSISTENT_MAP_INIT_FUNC_PY
.
format
(
buffer_name
,
float
(
hit_ratio
),
self
.
_pythonic_expr
(
num_bytes
))
buffer_name
,
float
(
hit_ratio
),
self
.
_pythonic_expr
(
num_bytes
))
return
init_l2_persistent_map
return
init_l2_persistent_map
def
generate_tma_descriptor_args
(
self
,
desc_name_map
:
dict
[
str
,
str
],
def
generate_tma_descriptor_args
(
self
,
desc_name_map
:
dict
[
str
,
str
],
desc_name_var_map
:
dict
[
str
,
tvm
.
tir
.
Var
])
->
str
:
desc_name_var_map
:
dict
[
str
,
tvm
.
tir
.
Var
])
->
str
:
"""Generate Python code to initialize TMA descriptors.
"""Generate Python code to initialize TMA descriptors.
TMA (Tensor Memory Accelerator) descriptors are opaque CUDA objects
TMA (Tensor Memory Accelerator) descriptors are opaque CUDA objects
...
@@ -470,28 +476,43 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
...
@@ -470,28 +476,43 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
return
tma_descriptor_init
return
tma_descriptor_init
# Parse TMA descriptor arguments using the common utility
# Parse TMA descriptor arguments using the common utility
parsed_params
=
parse_tma_descriptor_args
(
self
.
tma_descriptor_args
,
desc_name_map
,
parsed_params
=
parse_tma_descriptor_args
(
self
.
tma_descriptor_args
,
desc_name_map
,
desc_name_var_map
,
self
.
_pythonic_expr
)
desc_name_var_map
,
self
.
_pythonic_expr
)
# Generate Python code from parsed parameters
# Generate Python code from parsed parameters
for
params
in
parsed_params
:
for
params
in
parsed_params
:
if
not
params
.
is_img2col
:
if
not
params
.
is_img2col
:
tma_descriptor_init
+=
TMA_DESC_INIT_FUNC_PY
.
format
(
tma_descriptor_init
+=
TMA_DESC_INIT_FUNC_PY
.
format
(
params
.
handle_name
,
params
.
dtype
,
params
.
tensor_rank
,
params
.
global_address
,
params
.
handle_name
,
params
.
dtype
,
params
.
tensor_rank
,
params
.
global_address
,
", "
.
join
(
map
(
lambda
x
:
f
"cuuint64_t(
{
x
}
)"
,
params
.
global_dim
)),
", "
.
join
(
map
(
lambda
x
:
f
"cuuint64_t(
{
x
}
)"
,
params
.
global_dim
)),
", "
.
join
(
map
(
lambda
x
:
f
"cuuint64_t(
{
x
}
)"
,
params
.
global_stride
)),
", "
.
join
(
map
(
lambda
x
:
f
"cuuint64_t(
{
x
}
)"
,
params
.
global_stride
)),
", "
.
join
(
map
(
lambda
x
:
f
"cuuint32_t(
{
x
}
)"
,
params
.
box_dim
)),
", "
.
join
(
map
(
lambda
x
:
f
"cuuint32_t(
{
x
}
)"
,
params
.
box_dim
)),
", "
.
join
(
map
(
lambda
x
:
f
"cuuint32_t(
{
x
}
)"
,
params
.
element_strides
)),
", "
.
join
(
map
(
lambda
x
:
f
"cuuint32_t(
{
x
}
)"
,
params
.
element_strides
)),
params
.
interleave
,
params
.
swizzle
,
params
.
l2_promotion
,
params
.
oob_fill
)
params
.
interleave
,
params
.
swizzle
,
params
.
l2_promotion
,
params
.
oob_fill
,
)
else
:
else
:
tma_descriptor_init
+=
TMA_IM2COL_DESC_INIT_FUNC_PY
.
format
(
tma_descriptor_init
+=
TMA_IM2COL_DESC_INIT_FUNC_PY
.
format
(
params
.
handle_name
,
params
.
dtype
,
params
.
tensor_rank
,
params
.
global_address
,
params
.
handle_name
,
params
.
dtype
,
params
.
tensor_rank
,
params
.
global_address
,
", "
.
join
(
map
(
lambda
x
:
f
"cuuint64_t(
{
x
}
)"
,
params
.
global_dim
)),
", "
.
join
(
map
(
lambda
x
:
f
"cuuint64_t(
{
x
}
)"
,
params
.
global_dim
)),
", "
.
join
(
map
(
lambda
x
:
f
"cuuint64_t(
{
x
}
)"
,
params
.
global_stride
)),
", "
.
join
(
map
(
lambda
x
:
f
"cuuint64_t(
{
x
}
)"
,
params
.
global_stride
)),
", "
.
join
(
map
(
lambda
x
:
f
"cuuint32_t(
{
x
}
)"
,
", "
.
join
(
map
(
lambda
x
:
f
"cuuint32_t(
{
x
}
)"
,
params
.
element_strides
)),
params
.
element_strides
)),
", "
.
join
(
params
.
lower_corner
),
", "
.
join
(
params
.
lower_corner
),
", "
.
join
(
params
.
upper_corner
),
params
.
smem_box_channel
,
params
.
smem_box_pixel
,
", "
.
join
(
params
.
upper_corner
),
params
.
interleave
,
params
.
swizzle
,
params
.
l2_promotion
,
params
.
oob_fill
)
params
.
smem_box_channel
,
params
.
smem_box_pixel
,
params
.
interleave
,
params
.
swizzle
,
params
.
l2_promotion
,
params
.
oob_fill
,
)
return
tma_descriptor_init
return
tma_descriptor_init
...
@@ -527,17 +548,14 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
...
@@ -527,17 +548,14 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
def
visitor
(
node
,
fn
=
function_name
,
param_cnt
=
kernel_params_cnt
):
def
visitor
(
node
,
fn
=
function_name
,
param_cnt
=
kernel_params_cnt
):
nonlocal
function_params
nonlocal
function_params
if
isinstance
(
node
,
tvm
.
tir
.
Call
):
if
isinstance
(
node
,
tvm
.
tir
.
Call
):
if
not
(
hasattr
(
node
,
"op"
)
and
if
not
(
hasattr
(
node
,
"op"
)
and
node
.
op
==
tvm
.
ir
.
Op
.
get
(
"tir.tvm_call_packed"
)):
node
.
op
==
tvm
.
ir
.
Op
.
get
(
"tir.tvm_call_packed"
)):
return
return
args
=
node
.
args
args
=
node
.
args
if
not
args
or
args
[
0
]
!=
fn
:
if
not
args
or
args
[
0
]
!=
fn
:
return
return
if
len
(
args
)
<
1
+
param_cnt
:
if
len
(
args
)
<
1
+
param_cnt
:
raise
AssertionError
(
raise
AssertionError
(
"tvm_call_packed should have at least 1 argument and match device function parameters"
)
"tvm_call_packed should have at least 1 argument and match device function parameters"
function_params
=
args
[
1
:
1
+
param_cnt
]
)
function_params
=
args
[
1
:
1
+
param_cnt
]
post_order_visit
(
self
.
host_func
.
body
,
visitor
)
post_order_visit
(
self
.
host_func
.
body
,
visitor
)
assert
function_params
is
not
None
,
"function_params should not be None"
assert
function_params
is
not
None
,
"function_params should not be None"
...
...
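
Nearly every hunk in this commit follows the same mechanical pattern: yapf's hanging-indent style for broken calls and literals is replaced by ruff format's exploded style with trailing commas, and single-quoted strings are normalized to double quotes. The following is a minimal, self-contained sketch of that recurring transformation; the function names are hypothetical and only illustrate the before/after shapes seen in the hunks above.

    info = {}

    def record_before(name, block):
        # yapf-era style: the literal hangs under the opening parenthesis,
        # keys use single quotes
        info.update({
            'function_name': name,
            'block_info': block,
        })

    def record_after(name, block):
        # ruff format: the literal is exploded one level deeper, every element
        # gets its own line with a trailing comma, strings use double quotes
        info.update(
            {
                "function_name": name,
                "block_info": block,
            }
        )
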
tilelang/jit/adapter/torch/__init__.py
View file @ 29051439
 from .metal import MetalKernelAdapter

-__all__ = ['MetalKernelAdapter']
+__all__ = ["MetalKernelAdapter"]
tilelang/jit/adapter/torch/metal.py
View file @ 29051439
...
@@ -12,7 +12,6 @@ from tilelang.engine.param import KernelParam
 class MetalKernelAdapter(BaseKernelAdapter):
-
     def __init__(
         self,
         params: list[KernelParam],
...
@@ -28,10 +27,10 @@ class MetalKernelAdapter(BaseKernelAdapter):
     ):
         self.kernel_global_source = kernel_global_source
         if isinstance(func_or_mod, tir.PrimFunc):
-            func_name = func_or_mod.attrs['global_symbol']
+            func_name = func_or_mod.attrs["global_symbol"]
         else:
             func_name = func_or_mod.__name__
-        self.kernel_name = func_name + '_kernel'
+        self.kernel_name = func_name + "_kernel"
         self.verbose = verbose

         self.block_info = [1, 1, 1]
...
@@ -39,7 +38,7 @@ class MetalKernelAdapter(BaseKernelAdapter):
             for var, func in device_mod.functions.items():
                 assert var.name_hint == self.kernel_name
-                thread_extent = func.attrs['thread_extent']
+                thread_extent = func.attrs["thread_extent"]
                 for tag, extent in thread_extent.items():
                     if "threadIdx" in tag:
                         self.block_info["xyz".index(tag[-1])] = extent
...
@@ -47,7 +46,7 @@ class MetalKernelAdapter(BaseKernelAdapter):
                         self.grid_info["xyz".index(tag[-1])] = extent
                 break
         else:
-            raise AssertionError(f'no kernel with name {func_name}')
+            raise AssertionError(f"no kernel with name {func_name}")

         # print(self.block_info, self.grid_info)
         super().__init__(func_or_mod, result_idx=result_idx, params=params)
...
@@ -55,15 +54,12 @@ class MetalKernelAdapter(BaseKernelAdapter):
     _kernel = None

     def _convert_torch_func(self) -> Callable:
         if self._kernel is None:
             _kernel = getattr(torch.mps.compile_shader(self.kernel_global_source), self.kernel_name)
         _threads = [x * y for (x, y) in zip(self.block_info, self.grid_info)]

         @wraps(_kernel)
         def launcher(*args: torch.Tensor):
             return _kernel(
                 *args,
                 threads=_threads,
...
tilelang/jit/adapter/tvm_ffi.py
View file @ 29051439
...
@@ -5,6 +5,7 @@ via light-weight callables so that, when the wrapped function is invoked,
 the execution observes the same stream context as the active Torch code.
 On non-CUDA builds, the stream/device fall back to 0/CPU semantics.
 """
+
 from __future__ import annotations

 from typing import Callable, Any
...
@@ -31,6 +32,7 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
     - The stream pointer returned is a raw CUDA stream handle compatible with
       TVM's device API; on CPU or when CUDA is unavailable, we return 0.
     """
+
     # Class attributes to store compiled kernel information
     target: str | Target = "cuda"
     ir_module: tvm.IRModule | None = None
...
@@ -51,19 +53,21 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
     dynamic_symbolic_map: dict[tir.Var, tuple[int, int, int]] | None = None
     # Stream/device functors are inherited from BaseKernelAdapter

-    def __init__(self,
-                 params: list[KernelParam],
-                 result_idx: list[int],
-                 target: str | Target,
-                 func_or_mod: tir.PrimFunc | tvm.IRModule,
-                 host_mod: tvm.IRModule | None = None,
-                 device_mod: tvm.IRModule | None = None,
-                 rt_mod: tvm.runtime.Module | None = None,
-                 host_kernel_source: str | None = None,
-                 device_kernel_source: str | None = None,
-                 verbose: bool = False,
-                 pass_configs: dict[str, Any] | None = None,
-                 compile_flags: list[str] | None = None):
+    def __init__(
+        self,
+        params: list[KernelParam],
+        result_idx: list[int],
+        target: str | Target,
+        func_or_mod: tir.PrimFunc | tvm.IRModule,
+        host_mod: tvm.IRModule | None = None,
+        device_mod: tvm.IRModule | None = None,
+        rt_mod: tvm.runtime.Module | None = None,
+        host_kernel_source: str | None = None,
+        device_kernel_source: str | None = None,
+        verbose: bool = False,
+        pass_configs: dict[str, Any] | None = None,
+        compile_flags: list[str] | None = None,
+    ):
         """Initialize the adapter with the given TIR function or module.

         Args:
...
@@ -113,15 +117,13 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
             if param in buffer_map:
                 buffer = buffer_map[param]
                 for j, shape in enumerate(buffer.shape):
-                    if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and
-                            (shape not in params)):
+                    if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and (shape not in params):
                         dynamic_symbolic_map[shape] = (0, i, j)
         for i, param in enumerate(params):
             if param in buffer_map:
                 buffer = buffer_map[param]
                 for j, stride in enumerate(buffer.strides):
-                    if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and
-                            (stride not in params)):
+                    if isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and (stride not in params):
                         dynamic_symbolic_map[stride] = (1, i, j)
         return dynamic_symbolic_map
...
@@ -197,8 +199,7 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
         # Validate input count strictly
         expected_inputs = len(self.params) - len(self.result_idx)
         if len(inputs) != expected_inputs:
-            raise ValueError(
-                f"Kernel expected {expected_inputs} inputs, but {len(inputs)} are provided.")
+            raise ValueError(f"Kernel expected {expected_inputs} inputs, but {len(inputs)} are provided.")

         # Resolve the device used for outputs. Prefer the first tensor input's device
         # if available, otherwise use PyTorch's current device.
...
@@ -217,17 +218,14 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
                 for s in param_shapes[i]:
                     if isinstance(s, tir.Var):
                         for key in dynamic_symbolic_map:
-                            if (str(s) == str(key)):
-                                ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[
-                                    key]
+                            if str(s) == str(key):
+                                ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[key]
                                 if ref_id == 2:
                                     shape.append(inputs[ref_tensor_idx])
                                 elif ref_id == 0:
-                                    shape.append(
-                                        tensor_list[ref_tensor_idx].shape[ref_shape_idx])
+                                    shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx])
                                 elif ref_id == 1:
-                                    shape.append(
-                                        tensor_list[ref_tensor_idx].stride()[ref_shape_idx])
+                                    shape.append(tensor_list[ref_tensor_idx].stride()[ref_shape_idx])
                     else:
                         # Already converted to Python int during initialization
                         shape.append(s)
...
@@ -235,11 +233,11 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
                 out_device = current_device_functor()
                 if len(shape) == 0:
-                    param_name = self.params[i].name if hasattr(self.params[i],
-                                                                'name') else f'parameter_{i}'
+                    param_name = self.params[i].name if hasattr(self.params[i], "name") else f"parameter_{i}"
                     raise ValueError(
                         f"Cannot create output tensor (name={param_name}) - 0-dimensional tensors are not supported. "
-                        f"Expected shape: {shape}")
+                        f"Expected shape: {shape}"
+                    )
                 tensor = torch.empty(*shape, dtype=dtype, device=out_device)
             else:
                 tensor = inputs[ins_idx]
...
@@ -256,17 +254,19 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
         return func

     @classmethod
-    def from_database(cls,
-                      params: list[TensorType],
-                      result_idx: list[int],
-                      target: str,
-                      func_or_mod: tir.PrimFunc | tvm.IRModule,
-                      host_kernel_source: str,
-                      device_kernel_source: str,
-                      kernel_lib_path: str,
-                      verbose: bool = False,
-                      pass_configs: dict[str, Any] | None = None,
-                      compile_flags: list[str] | None = None):
+    def from_database(
+        cls,
+        params: list[TensorType],
+        result_idx: list[int],
+        target: str,
+        func_or_mod: tir.PrimFunc | tvm.IRModule,
+        host_kernel_source: str,
+        device_kernel_source: str,
+        kernel_lib_path: str,
+        verbose: bool = False,
+        pass_configs: dict[str, Any] | None = None,
+        compile_flags: list[str] | None = None,
+    ):
         adapter = cls.__new__(cls)
         adapter.params = params
         adapter.result_idx = adapter._legalize_result_idx(result_idx)
...
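
The tvm_ffi module docstring shown above says the adapter resolves the active Torch stream lazily, at call time, and falls back to 0/CPU on non-CUDA builds. A minimal sketch of how such a stream functor can be built from public PyTorch APIs follows; it is an illustration of the idea only, not the adapter's actual implementation.

    import torch

    def current_stream_functor():
        """Return a callable that yields the raw CUDA stream handle at call time.

        Illustrative sketch: on non-CUDA builds it returns 0, matching the
        0/CPU semantics described in the module docstring above.
        """
        def _get_stream() -> int:
            if torch.cuda.is_available():
                # .cuda_stream exposes the underlying cudaStream_t as an int
                return torch.cuda.current_stream().cuda_stream
            return 0
        return _get_stream

    # Usage: resolve the stream when the kernel is invoked, not when it is wrapped.
    get_stream = current_stream_functor()
    stream_ptr = get_stream()
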
tilelang/jit/adapter/utils.py
View file @ 29051439
...
@@ -70,7 +70,6 @@ def get_annotated_mod(
     target_host: str | Target | None = None,
     model_type: Literal["device", "host", "all"] = "all",
 ) -> IRModule | tuple[IRModule, IRModule]:
-
     # Validate model_type early
     if model_type not in {"device", "host", "all"}:
         raise ValueError(f"Invalid model type: {model_type}")
...
@@ -95,21 +94,15 @@ def get_annotated_mod(
     # Define dispatch dictionary for different model types
     dispatch = {
-        "device":
-            lambda m: tir.transform.Filter(_is_device_call)(m),
-        "host":
-            lambda m: tir.transform.Filter(_is_host_call)(m),
-        "all":
-            lambda m: (tir.transform.Filter(_is_device_call)(m), tir.transform.Filter(_is_host_call)
-                       (m)),
+        "device": lambda m: tir.transform.Filter(_is_device_call)(m),
+        "host": lambda m: tir.transform.Filter(_is_host_call)(m),
+        "all": lambda m: (tir.transform.Filter(_is_device_call)(m), tir.transform.Filter(_is_host_call)(m)),
     }
     return dispatch[model_type](mod)


-def pythonic_expr(expr: tvm.tir.PrimExpr,
-                  dtype_map: dict[str, str] | None = None,
-                  ignore_cast: bool = False) -> str:
+def pythonic_expr(
+    expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None, ignore_cast: bool = False
+) -> str:
     """
     Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.
...
@@ -168,9 +161,23 @@ def pythonic_expr(expr: tvm.tir.PrimExpr,
             s = f"({type_str}){value_str}"
             p = PRECEDENCE.get(type(node), ATOMIC_PRECEDENCE)
-        elif isinstance(
-                node,
-            (tvm.tir.Mul, tvm.tir.FloorDiv, tvm.tir.Add, tvm.tir.Sub, tvm.tir.FloorMod, tvm.tir.LT,
-             tvm.tir.LE, tvm.tir.GT, tvm.tir.GE, tvm.tir.EQ, tvm.tir.NE, tvm.tir.And, tvm.tir.Or)):
+        elif isinstance(
+            node,
+            (
+                tvm.tir.Mul,
+                tvm.tir.FloorDiv,
+                tvm.tir.Add,
+                tvm.tir.Sub,
+                tvm.tir.FloorMod,
+                tvm.tir.LT,
+                tvm.tir.LE,
+                tvm.tir.GT,
+                tvm.tir.GE,
+                tvm.tir.EQ,
+                tvm.tir.NE,
+                tvm.tir.And,
+                tvm.tir.Or,
+            ),
+        ):
             op_map = {
                 tvm.tir.Mul: "*",
                 tvm.tir.FloorDiv: "/",
...
@@ -222,10 +229,7 @@ def pythonic_expr(expr: tvm.tir.PrimExpr,
     return next(iter(node_to_result_map[expr]), "")


-def maybe_desc_name(name: str,
-                    matches: list[str],
-                    i: int,
-                    desc_name_map: dict[str, str] | None = None) -> bool:
+def maybe_desc_name(name: str, matches: list[str], i: int, desc_name_map: dict[str, str] | None = None) -> bool:
     """
     Check if a parameter name corresponds to a TMA descriptor.
...
@@ -290,8 +294,7 @@ def parse_function_call_args(
         else:
             call_args.append(match)
         if desc_name_var_map is not None and function_params is not None:
-            assert len(call_args) <= len(function_params), \
-                f"Too many arguments: {len(call_args)} > {len(function_params)}"
+            assert len(call_args) <= len(function_params), f"Too many arguments: {len(call_args)} > {len(function_params)}"
             desc_name_var_map[match] = function_params[len(call_args) - 1]
     return call_args
...
@@ -300,12 +303,7 @@ def parse_function_call_args(
 class TMADescriptorParams:
     """Parsed TMA descriptor parameters."""

-    def __init__(self,
-                 handle_name: str,
-                 dtype: str,
-                 tensor_rank: int,
-                 global_address: Any,
-                 is_img2col: bool = False):
+    def __init__(self, handle_name: str, dtype: str, tensor_rank: int, global_address: Any, is_img2col: bool = False):
         self.handle_name = handle_name
         self.dtype = dtype
         self.tensor_rank = tensor_rank
...
@@ -355,22 +353,19 @@ def parse_tma_descriptor_args(
     results = []
     for handle_name, _ in desc_name_map.items():
-        assert handle_name in desc_name_var_map, \
-            f"Handle name {handle_name} not found in desc_name_var_map"
+        assert handle_name in desc_name_var_map, f"Handle name {handle_name} not found in desc_name_var_map"
         desc_var = desc_name_var_map[handle_name]
-        assert desc_var in tma_descriptor_args, \
-            f"TMA descriptor {desc_var} not found in {tma_descriptor_args}"
+        assert desc_var in tma_descriptor_args, f"TMA descriptor {desc_var} not found in {tma_descriptor_args}"
         args = tma_descriptor_args[desc_var]

         # Skip __tvm_tensormap_create_tiled and second element (like CUDA version)
         if len(args) < 3:
             raise ValueError(
                 f"TMA descriptor args too short: {len(args)} elements, expected at least 3")
         tma_create_str, _, dtype, tensor_rank, global_address, *remaining_args = args
-        is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col")
+        is_img2col = tma_create_str.value == "__tvm_tensormap_create_im2col"

         # Convert basic fields
         dtype = pythonic_expr_func(dtype)
...
@@ -386,60 +381,45 @@ def parse_tma_descriptor_args(
             # Tiled mode
             expected_args_len = 4 * tensor_rank + 4
             if len(remaining_args) < expected_args_len:
-                raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
-                                 f"expected {expected_args_len} for tensor_rank {tensor_rank}")
+                raise ValueError(
+                    f"Insufficient remaining args: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}"
+                )

             # Extract dimensions and strides
             params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]]
-            params.global_stride = [
-                pythonic_expr_func(i) for i in remaining_args[tensor_rank:2 * tensor_rank]]
-            params.box_dim = [
-                pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank:3 * tensor_rank]]
-            params.element_strides = [
-                pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank:4 * tensor_rank]]
+            params.global_stride = [
+                pythonic_expr_func(i) for i in remaining_args[tensor_rank : 2 * tensor_rank]
+            ]
+            params.box_dim = [
+                pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank : 3 * tensor_rank]
+            ]
+            params.element_strides = [
+                pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank : 4 * tensor_rank]
+            ]

             # Extract remaining parameters
             try:
-                interleave, swizzle, l2_promotion, oob_fill = remaining_args[4 * tensor_rank:4 *
-                                                                             tensor_rank + 4]
+                interleave, swizzle, l2_promotion, oob_fill = remaining_args[4 * tensor_rank : 4 * tensor_rank + 4]
                 params.interleave = pythonic_expr_func(interleave)
                 params.swizzle = pythonic_expr_func(swizzle)
                 params.l2_promotion = pythonic_expr_func(l2_promotion)
                 params.oob_fill = pythonic_expr_func(oob_fill)
             except ValueError as e:
                 raise ValueError(
                     "Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
                 ) from e
         else:
             # Im2col mode
             expected_args_len = 5 * tensor_rank + 2
             if len(remaining_args) < expected_args_len:
-                raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
-                                 f"expected {expected_args_len} for tensor_rank {tensor_rank}")
+                raise ValueError(
+                    f"Insufficient remaining args: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}"
+                )

             # Extract dimensions and strides
             params.global_dim = [pythonic_expr_func(i) for i in remaining_args[:tensor_rank]]
-            params.global_stride = [
-                pythonic_expr_func(i) for i in remaining_args[tensor_rank:2 * tensor_rank]]
-            params.element_strides = [
-                pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank:3 * tensor_rank]]
-            params.lower_corner = [
-                pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank:4 * tensor_rank - 2]]
-            params.upper_corner = [
-                pythonic_expr_func(i) for i in remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4]]
+            params.global_stride = [
+                pythonic_expr_func(i) for i in remaining_args[tensor_rank : 2 * tensor_rank]
+            ]
+            params.element_strides = [
+                pythonic_expr_func(i) for i in remaining_args[2 * tensor_rank : 3 * tensor_rank]
+            ]
+            params.lower_corner = [
+                pythonic_expr_func(i) for i in remaining_args[3 * tensor_rank : 4 * tensor_rank - 2]
+            ]
+            params.upper_corner = [
+                pythonic_expr_func(i) for i in remaining_args[4 * tensor_rank - 2 : 5 * tensor_rank - 4]
+            ]

             # Extract remaining parameters
             try:
-                smem_box_pixel, smem_box_channel, interleave, swizzle, l2_promotion, oob_fill = \
-                    remaining_args[5 * tensor_rank - 4:5 * tensor_rank + 2]
+                smem_box_pixel, smem_box_channel, interleave, swizzle, l2_promotion, oob_fill = remaining_args[
+                    5 * tensor_rank - 4 : 5 * tensor_rank + 2
+                ]
                 params.smem_box_pixel = pythonic_expr_func(smem_box_pixel)
                 params.smem_box_channel = pythonic_expr_func(smem_box_channel)
                 params.interleave = pythonic_expr_func(interleave)
...
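
The pythonic_expr helper whose signature is reformatted above converts a TVM PrimExpr into a Python-style string with correct operator precedence. A small usage sketch follows, based only on the signature and docstring visible in this diff; the printed rendering is indicative, and the exact spelling depends on the implementation.

    import tvm
    from tilelang.jit.adapter.utils import pythonic_expr  # module path as shown in this diff

    # Build a symbolic TIR expression: n is a dynamic int32 variable
    n = tvm.tir.Var("n", "int32")
    expr = n * 4 + 1

    # Expected to produce a Python-style string such as "n * 4 + 1";
    # dtype_map and ignore_cast keep their defaults from the signature above.
    print(pythonic_expr(expr))
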
tilelang/jit/adapter/wrapper.py
View file @
29051439
...
@@ -4,9 +4,18 @@ from tilelang import tvm as tvm
...
@@ -4,9 +4,18 @@ from tilelang import tvm as tvm
from
typing
import
Any
from
typing
import
Any
from
tvm
import
IRModule
from
tvm
import
IRModule
from
tvm.target
import
Target
from
tvm.target
import
Target
from
.utils
import
(
is_metal_target
,
match_declare_kernel
,
match_declare_kernel_cpu
,
is_cuda_target
,
from
.utils
import
(
is_hip_target
,
is_cpu_target
,
get_annotated_mod
,
pythonic_expr
,
is_metal_target
,
parse_function_call_args
,
parse_tma_descriptor_args
)
match_declare_kernel
,
match_declare_kernel_cpu
,
is_cuda_target
,
is_hip_target
,
is_cpu_target
,
get_annotated_mod
,
pythonic_expr
,
parse_function_call_args
,
parse_tma_descriptor_args
,
)
import
re
import
re
import
logging
import
logging
import
textwrap
import
textwrap
...
@@ -129,7 +138,6 @@ TMA_IM2COL_DESC_INIT_FUNC = """
...
@@ -129,7 +138,6 @@ TMA_IM2COL_DESC_INIT_FUNC = """
class
BaseWrapper
(
ABC
):
class
BaseWrapper
(
ABC
):
@
abstractmethod
@
abstractmethod
def
wrap
(
self
,
*
args
,
**
kwargs
):
def
wrap
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
raise
NotImplementedError
...
@@ -163,13 +171,15 @@ class TLCUDASourceWrapper:
...
@@ -163,13 +171,15 @@ class TLCUDASourceWrapper:
host_mod
:
IRModule
|
None
=
None
host_mod
:
IRModule
|
None
=
None
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
def
__init__
(
self
,
def
__init__
(
scheduled_ir_module
:
IRModule
,
self
,
source
:
str
,
scheduled_ir_module
:
IRModule
,
target
:
Target
,
source
:
str
,
device_mod
:
IRModule
|
None
=
None
,
target
:
Target
,
host_mod
:
IRModule
|
None
=
None
,
device_mod
:
IRModule
|
None
=
None
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
):
host_mod
:
IRModule
|
None
=
None
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
):
self
.
mod
=
scheduled_ir_module
self
.
mod
=
scheduled_ir_module
self
.
target
=
target
self
.
target
=
target
self
.
source
=
source
self
.
source
=
source
...
@@ -211,15 +221,16 @@ class TLCUDASourceWrapper:
...
@@ -211,15 +221,16 @@ class TLCUDASourceWrapper:
for
param
in
self
.
prim_func
.
params
:
for
param
in
self
.
prim_func
.
params
:
if
param
in
self
.
prim_func
.
buffer_map
:
if
param
in
self
.
prim_func
.
buffer_map
:
buffer
=
self
.
prim_func
.
buffer_map
[
param
]
buffer
=
self
.
prim_func
.
buffer_map
[
param
]
function_args
.
append
({
function_args
.
append
(
"name"
:
buffer
.
data
.
name
,
{
"type"
:
self
.
_lookup_type
(
buffer
.
dtype
)
+
"* __restrict__"
,
"name"
:
buffer
.
data
.
name
,
})
"type"
:
self
.
_lookup_type
(
buffer
.
dtype
)
+
"* __restrict__"
,
}
)
elif
isinstance
(
param
,
tvm
.
tir
.
Var
):
elif
isinstance
(
param
,
tvm
.
tir
.
Var
):
function_args
.
append
({
"name"
:
param
.
name
,
"type"
:
self
.
_lookup_type
(
param
.
dtype
)})
function_args
.
append
({
"name"
:
param
.
name
,
"type"
:
self
.
_lookup_type
(
param
.
dtype
)})
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Parameter
{
param
}
is not in the buffer map of the primary function."
)
f
"Parameter
{
param
}
is not in the buffer map of the primary function."
)
# Add dynamic symbols as integer arguments
# Add dynamic symbols as integer arguments
for
dyn_sym
,
dyn_sym_dtype
in
dynamic_symbolic_set
:
for
dyn_sym
,
dyn_sym_dtype
in
dynamic_symbolic_set
:
if
dyn_sym
not
in
[
arg
[
"name"
]
for
arg
in
function_args
]:
if
dyn_sym
not
in
[
arg
[
"name"
]
for
arg
in
function_args
]:
...
@@ -256,38 +267,40 @@ class TLCUDASourceWrapper:
...
@@ -256,38 +267,40 @@ class TLCUDASourceWrapper:
# Identify the start of the function body to insert arguments
# Identify the start of the function body to insert arguments
index
=
code
.
index
(
"{"
,
index
)
index
=
code
.
index
(
"{"
,
index
)
block_str
=
f
"dim3(
{
self
.
_pythonic_expr
(
block_info
[
0
])
}
,
{
self
.
_pythonic_expr
(
block_info
[
1
])
}
,
{
self
.
_pythonic_expr
(
block_info
[
2
])
}
)"
block_str
=
(
grid_str
=
f
"dim3(
{
self
.
_pythonic_expr
(
grid_info
[
0
])
}
,
{
self
.
_pythonic_expr
(
grid_info
[
1
])
}
,
{
self
.
_pythonic_expr
(
grid_info
[
2
])
}
)"
f
"dim3(
{
self
.
_pythonic_expr
(
block_info
[
0
])
}
,
{
self
.
_pythonic_expr
(
block_info
[
1
])
}
,
{
self
.
_pythonic_expr
(
block_info
[
2
])
}
)"
)
grid_str
=
(
f
"dim3(
{
self
.
_pythonic_expr
(
grid_info
[
0
])
}
,
{
self
.
_pythonic_expr
(
grid_info
[
1
])
}
,
{
self
.
_pythonic_expr
(
grid_info
[
2
])
}
)"
)
smem_str
=
0
if
dynamic_smem_buf
is
None
else
dynamic_smem_buf
smem_str
=
0
if
dynamic_smem_buf
is
None
else
dynamic_smem_buf
init_l2_persistent_map
=
self
.
generate_l2_persistent_map
(
function_name
)
init_l2_persistent_map
=
self
.
generate_l2_persistent_map
(
function_name
)
kernel_launch_code
+=
init_l2_persistent_map
kernel_launch_code
+=
init_l2_persistent_map
if
self
.
use_cooperative_groups
[
function_name
]:
if
self
.
use_cooperative_groups
[
function_name
]:
args_list
=
parse_function_call_args
(
declaration
,
function_args
,
function_params
,
args_list
=
parse_function_call_args
(
declaration
,
function_args
,
function_params
,
desc_name_map
,
desc_name_var_map
)
desc_name_map
,
desc_name_var_map
)
assert
len
(
function_params
)
==
len
(
args_list
),
(
assert
len
(
function_params
)
==
len
(
f
"Function
{
function_name
}
has
{
len
(
function_params
)
}
parameters, but
{
len
(
args_list
)
}
arguments"
args_list
)
),
f
"Function
{
function_name
}
has
{
len
(
function_params
)
}
parameters, but
{
len
(
args_list
)
}
arguments"
args_array
=
[
f
"(void*)&
{
arg
}
"
for
arg
in
args_list
]
args_array
=
[
f
"(void*)&
{
arg
}
"
for
arg
in
args_list
]
call_args
=
f
"
\t
void*
{
function_name
}
_args[] = {{
{
', '
.
join
(
args_array
)
}
}};
\n
"
call_args
=
f
"
\t
void*
{
function_name
}
_args[] = {{
{
', '
.
join
(
args_array
)
}
}};
\n
"
kernel_launch_code
+=
call_args
kernel_launch_code
+=
call_args
# Using cudaLaunchCooperativeKernel to launch the kernel
# Using cudaLaunchCooperativeKernel to launch the kernel
kernel_launch_code
+=
"
\t
TILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));
\n
"
.
format
(
kernel_launch_code
+=
"
\t
TILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));
\n
"
.
format
(
function_name
,
grid_str
,
block_str
,
function_name
+
"_args"
,
smem_str
)
function_name
,
grid_str
,
block_str
,
function_name
+
"_args"
,
smem_str
)
else
:
else
:
args_list
=
parse_function_call_args
(
declaration
,
function_args
,
function_params
,
args_list
=
parse_function_call_args
(
declaration
,
function_args
,
function_params
,
desc_name_map
,
desc_name_var_map
)
desc_name_map
,
desc_name_var_map
)
assert
len
(
function_params
)
==
len
(
args_list
),
(
assert
len
(
function_params
)
==
len
(
f
"Function
{
function_name
}
has
{
len
(
function_params
)
}
parameters, but
{
len
(
args_list
)
}
arguments"
args_list
)
),
f
"Function
{
function_name
}
has
{
len
(
function_params
)
}
parameters, but
{
len
(
args_list
)
}
arguments"
call_args
=
", "
.
join
(
args_list
)
call_args
=
", "
.
join
(
args_list
)
kernel_launch_code
+=
f
"
\t
{
function_name
}
<<<
{
grid_str
}
,
{
block_str
}
,
{
smem_str
}
, stream>>>(
{
call_args
}
);
\n
"
kernel_launch_code
+=
f
"
\t
{
function_name
}
<<<
{
grid_str
}
,
{
block_str
}
,
{
smem_str
}
, stream>>>(
{
call_args
}
);
\n
"
kernel_launch_code
+=
f
"
\t
TILELANG_CHECK_LAST_ERROR(
\
"
{
function_name
}
\
"
);
\n
"
kernel_launch_code
+=
f
'
\t
TILELANG_CHECK_LAST_ERROR("
{
function_name
}
");
\n
'
if
has_l2_persistent_map
:
if
has_l2_persistent_map
:
kernel_launch_code
+=
L2_PERSISTENT_MAP_RESET_HANDLE
kernel_launch_code
+=
L2_PERSISTENT_MAP_RESET_HANDLE
init_tma_descriptor_args
=
self
.
generate_tma_descriptor_args
(
desc_name_map
,
init_tma_descriptor_args
=
self
.
generate_tma_descriptor_args
(
desc_name_map
,
desc_name_var_map
)
desc_name_var_map
)
kernel_launch_code
=
init_tma_descriptor_args
+
kernel_launch_code
kernel_launch_code
=
init_tma_descriptor_args
+
kernel_launch_code
# Wrap the kernel dispatch logic in an external C function
# Wrap the kernel dispatch logic in an external C function
...
@@ -298,46 +311,63 @@ class TLCUDASourceWrapper:
...
@@ -298,46 +311,63 @@ class TLCUDASourceWrapper:
if
function_name
not
in
self
.
l2_persistent_map
:
if
function_name
not
in
self
.
l2_persistent_map
:
return
""
return
""
init_l2_persistent_map
=
""
init_l2_persistent_map
=
""
for
buffer_name
,
(
hit_ratio
,
for
buffer_name
,
(
hit_ratio
,
size_in_bytes
)
in
self
.
l2_persistent_map
[
function_name
].
items
():
size_in_bytes
)
in
self
.
l2_persistent_map
[
function_name
].
items
():
# get persisting_l2_cache_max_size
# get persisting_l2_cache_max_size
from
tilelang.carver.arch.driver
import
get_persisting_l2_cache_max_size
from
tilelang.carver.arch.driver
import
get_persisting_l2_cache_max_size
persisting_l2_cache_max_size
=
get_persisting_l2_cache_max_size
()
persisting_l2_cache_max_size
=
get_persisting_l2_cache_max_size
()
try
:
try
:
num_bytes
=
min
(
size_in_bytes
,
persisting_l2_cache_max_size
)
num_bytes
=
min
(
size_in_bytes
,
persisting_l2_cache_max_size
)
except
Exception
:
except
Exception
:
# as size_in_bytes maybe a symbolic expression
# as size_in_bytes maybe a symbolic expression
num_bytes
=
persisting_l2_cache_max_size
num_bytes
=
persisting_l2_cache_max_size
init_l2_persistent_map
+=
L2_PERSISTENT_MAP_INIT_FUNC
.
format
(
init_l2_persistent_map
+=
L2_PERSISTENT_MAP_INIT_FUNC
.
format
(
buffer_name
,
float
(
hit_ratio
),
self
.
_pythonic_expr
(
num_bytes
))
buffer_name
,
float
(
hit_ratio
),
self
.
_pythonic_expr
(
num_bytes
))
return
init_l2_persistent_map
return
init_l2_persistent_map
def
generate_tma_descriptor_args
(
self
,
desc_name_map
:
dict
[
str
,
str
],
def
generate_tma_descriptor_args
(
self
,
desc_name_map
:
dict
[
str
,
str
],
desc_name_var_map
:
dict
[
str
,
tvm
.
tir
.
Var
])
->
str
:
desc_name_var_map
:
dict
[
str
,
tvm
.
tir
.
Var
])
->
str
:
tma_descripter_init
=
""
tma_descripter_init
=
""
if
self
.
tma_descriptor_args
is
None
:
if
self
.
tma_descriptor_args
is
None
:
return
tma_descripter_init
return
tma_descripter_init
# Parse TMA descriptor arguments using the common utility
# Parse TMA descriptor arguments using the common utility
parsed_params
=
parse_tma_descriptor_args
(
self
.
tma_descriptor_args
,
desc_name_map
,
parsed_params
=
parse_tma_descriptor_args
(
self
.
tma_descriptor_args
,
desc_name_map
,
desc_name_var_map
,
self
.
_pythonic_expr
)
desc_name_var_map
,
self
.
_pythonic_expr
)
# Generate C++ code from parsed parameters
# Generate C++ code from parsed parameters
for
params
in
parsed_params
:
for
params
in
parsed_params
:
if
not
params
.
is_img2col
:
if
not
params
.
is_img2col
:
tma_descripter_init
+=
TMA_DESC_INIT_FUNC
.
format
(
tma_descripter_init
+=
TMA_DESC_INIT_FUNC
.
format
(
params
.
handle_name
,
params
.
dtype
,
params
.
tensor_rank
,
params
.
global_address
,
params
.
handle_name
,
","
.
join
(
params
.
global_dim
),
","
.
join
(
params
.
global_stride
),
params
.
dtype
,
","
.
join
(
params
.
box_dim
),
","
.
join
(
params
.
element_strides
),
params
.
interleave
,
params
.
tensor_rank
,
params
.
swizzle
,
params
.
l2_promotion
,
params
.
oob_fill
)
params
.
global_address
,
","
.
join
(
params
.
global_dim
),
","
.
join
(
params
.
global_stride
),
","
.
join
(
params
.
box_dim
),
","
.
join
(
params
.
element_strides
),
params
.
interleave
,
params
.
swizzle
,
params
.
l2_promotion
,
params
.
oob_fill
,
)
else
:
else
:
tma_descripter_init
+=
TMA_IM2COL_DESC_INIT_FUNC
.
format
(
tma_descripter_init
+=
TMA_IM2COL_DESC_INIT_FUNC
.
format
(
params
.
handle_name
,
params
.
dtype
,
params
.
tensor_rank
,
params
.
global_address
,
params
.
handle_name
,
","
.
join
(
params
.
global_dim
),
","
.
join
(
params
.
global_stride
),
params
.
dtype
,
","
.
join
(
params
.
element_strides
),
","
.
join
(
params
.
lower_corner
),
params
.
tensor_rank
,
","
.
join
(
params
.
upper_corner
),
params
.
smem_box_channel
,
params
.
smem_box_pixel
,
params
.
global_address
,
params
.
interleave
,
params
.
swizzle
,
params
.
l2_promotion
,
params
.
oob_fill
)
","
.
join
(
params
.
global_dim
),
","
.
join
(
params
.
global_stride
),
","
.
join
(
params
.
element_strides
),
","
.
join
(
params
.
lower_corner
),
","
.
join
(
params
.
upper_corner
),
params
.
smem_box_channel
,
params
.
smem_box_pixel
,
params
.
interleave
,
params
.
swizzle
,
params
.
l2_promotion
,
params
.
oob_fill
,
)
return
tma_descripter_init
return
tma_descripter_init
...
@@ -347,9 +377,8 @@ class TLCUDASourceWrapper:
...
@@ -347,9 +377,8 @@ class TLCUDASourceWrapper:
device_mod
,
host_mod
=
get_annotated_mod
(
self
.
mod
,
self
.
target
)
device_mod
,
host_mod
=
get_annotated_mod
(
self
.
mod
,
self
.
target
)
self
.
device_mod
=
device_mod
self
.
device_mod
=
device_mod
self
.
host_mod
=
host_mod
self
.
host_mod
=
host_mod
assert
(
len
(
self
.
device_mod
.
functions
)
assert
len
(
self
.
device_mod
.
functions
)
>=
1
,
"Device module should have at least one function."
>=
1
),
"Device module should have at least one function."
assert
len
(
self
.
host_mod
.
functions
)
==
1
,
"Only support one function in host module."
assert
(
len
(
self
.
host_mod
.
functions
)
==
1
),
"Only support one function in host module."
block_info_map
=
{}
block_info_map
=
{}
grid_info_map
=
{}
grid_info_map
=
{}
...
@@ -438,8 +467,7 @@ class TLCUDASourceWrapper:
...
@@ -438,8 +467,7 @@ class TLCUDASourceWrapper:
for
function_name
,
dynamic_smem_buf
in
self
.
dynamic_smem_buf
.
items
():
for
function_name
,
dynamic_smem_buf
in
self
.
dynamic_smem_buf
.
items
():
if
dynamic_smem_buf
is
not
None
:
if
dynamic_smem_buf
is
not
None
:
# Format the cudaFuncSetAttribute call for dynamic shared memory
# Format the cudaFuncSetAttribute call for dynamic shared memory
call_str
+=
PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY
.
format
(
call_str
+=
PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY
.
format
(
function_name
,
dynamic_smem_buf
)
function_name
,
dynamic_smem_buf
)
# Format the initialization function using the call_str
# Format the initialization function using the call_str
init_funcs
=
PREDEF_INIT_FUNC
.
format
(
call_str
)
init_funcs
=
PREDEF_INIT_FUNC
.
format
(
call_str
)
return
init_funcs
return
init_funcs
...
@@ -466,17 +494,14 @@ class TLCUDASourceWrapper:
...
@@ -466,17 +494,14 @@ class TLCUDASourceWrapper:
def
visitor
(
node
,
fn
=
function_name
,
param_cnt
=
kernel_params_cnt
):
def
visitor
(
node
,
fn
=
function_name
,
param_cnt
=
kernel_params_cnt
):
nonlocal
function_params
nonlocal
function_params
if
isinstance
(
node
,
tvm
.
tir
.
Call
):
if
isinstance
(
node
,
tvm
.
tir
.
Call
):
if
not
(
hasattr
(
node
,
"op"
)
and
if
not
(
hasattr
(
node
,
"op"
)
and
node
.
op
==
tvm
.
ir
.
Op
.
get
(
"tir.tvm_call_packed"
)):
node
.
op
==
tvm
.
ir
.
Op
.
get
(
"tir.tvm_call_packed"
)):
return
return
args
=
node
.
args
args
=
node
.
args
if
not
args
or
args
[
0
]
!=
fn
:
if
not
args
or
args
[
0
]
!=
fn
:
return
return
if
len
(
args
)
<
1
+
param_cnt
:
if
len
(
args
)
<
1
+
param_cnt
:
raise
AssertionError
(
raise
AssertionError
(
"tvm_call_packed should have at least 1 argument and match device function parameters"
)
"tvm_call_packed should have at least 1 argument and match device function parameters"
function_params
=
args
[
1
:
1
+
param_cnt
]
)
function_params
=
args
[
1
:
1
+
param_cnt
]
post_order_visit
(
self
.
host_func
.
body
,
visitor
)
post_order_visit
(
self
.
host_func
.
body
,
visitor
)
assert
function_params
is
not
None
,
"function_params should not be None"
assert
function_params
is
not
None
,
"function_params should not be None"
...
@@ -564,13 +589,15 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
...
@@ -564,13 +589,15 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
"uchar"
:
"uint8_t"
,
"uchar"
:
"uint8_t"
,
}
}
def
__init__
(
self
,
def
__init__
(
scheduled_ir_module
:
IRModule
,
self
,
source
:
str
,
scheduled_ir_module
:
IRModule
,
target
:
Target
,
source
:
str
,
device_mod
:
IRModule
|
None
=
None
,
target
:
Target
,
host_mod
:
IRModule
|
None
=
None
,
device_mod
:
IRModule
|
None
=
None
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
):
host_mod
:
IRModule
|
None
=
None
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
):
super
().
__init__
(
scheduled_ir_module
,
source
,
target
,
device_mod
,
host_mod
,
pass_configs
)
super
().
__init__
(
scheduled_ir_module
,
source
,
target
,
device_mod
,
host_mod
,
pass_configs
)
def
get_init_func
(
self
):
def
get_init_func
(
self
):
...
@@ -580,8 +607,7 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
...
@@ -580,8 +607,7 @@ class TLHIPSourceWrapper(TLCUDASourceWrapper):
for
function_name
,
dynamic_smem_buf
in
self
.
dynamic_smem_buf
.
items
():
for
function_name
,
dynamic_smem_buf
in
self
.
dynamic_smem_buf
.
items
():
if
dynamic_smem_buf
is
not
None
:
if
dynamic_smem_buf
is
not
None
:
# Format the cudaFuncSetAttribute call for dynamic shared memory
# Format the cudaFuncSetAttribute call for dynamic shared memory
call_str
+=
PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP
.
format
(
call_str
+=
PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY_HIP
.
format
(
function_name
,
dynamic_smem_buf
)
function_name
,
dynamic_smem_buf
)
# Format the initialization function using the call_str
# Format the initialization function using the call_str
init_funcs
=
PREDEF_INIT_FUNC
.
format
(
call_str
)
init_funcs
=
PREDEF_INIT_FUNC
.
format
(
call_str
)
return
init_funcs
return
init_funcs
...
@@ -623,13 +649,15 @@ class TLCPUSourceWrapper:
...
@@ -623,13 +649,15 @@ class TLCPUSourceWrapper:
host_mod
:
IRModule
|
None
=
None
host_mod
:
IRModule
|
None
=
None
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
def
__init__
(
self
,
def
__init__
(
scheduled_ir_module
:
IRModule
,
self
,
source
:
str
,
scheduled_ir_module
:
IRModule
,
target
:
Target
,
source
:
str
,
device_mod
:
IRModule
|
None
=
None
,
target
:
Target
,
host_mod
:
IRModule
|
None
=
None
,
device_mod
:
IRModule
|
None
=
None
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
):
host_mod
:
IRModule
|
None
=
None
,
pass_configs
:
dict
[
str
,
Any
]
|
None
=
None
,
):
self
.
mod
=
scheduled_ir_module
self
.
mod
=
scheduled_ir_module
self
.
target
=
target
self
.
target
=
target
self
.
source
=
source
self
.
source
=
source
...
@@ -658,15 +686,16 @@ class TLCPUSourceWrapper:
...
@@ -658,15 +686,16 @@ class TLCPUSourceWrapper:
for
param
in
self
.
prim_func
.
params
:
for
param
in
self
.
prim_func
.
params
:
if
param
in
self
.
prim_func
.
buffer_map
:
if
param
in
self
.
prim_func
.
buffer_map
:
buffer
=
self
.
prim_func
.
buffer_map
[
param
]
buffer
=
self
.
prim_func
.
buffer_map
[
param
]
function_args
.
append
({
function_args
.
append
(
"name"
:
buffer
.
name
,
{
"type"
:
self
.
_lookup_type
(
buffer
.
dtype
)
+
"*"
,
"name"
:
buffer
.
name
,
})
"type"
:
self
.
_lookup_type
(
buffer
.
dtype
)
+
"*"
,
}
)
elif
isinstance
(
param
,
tvm
.
tir
.
Var
):
elif
isinstance
(
param
,
tvm
.
tir
.
Var
):
function_args
.
append
({
"name"
:
param
.
name
,
"type"
:
self
.
_lookup_type
(
param
.
dtype
)})
function_args
.
append
({
"name"
:
param
.
name
,
"type"
:
self
.
_lookup_type
(
param
.
dtype
)})
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Parameter
{
param
}
is not in the buffer map of the primary function."
)
f
"Parameter
{
param
}
is not in the buffer map of the primary function."
)
# Add dynamic symbols as integer arguments
# Add dynamic symbols as integer arguments
for
dyn_sym
,
dyn_sym_dtype
in
dynamic_symbolic_set
:
for
dyn_sym
,
dyn_sym_dtype
in
dynamic_symbolic_set
:
function_args
.
append
({
"name"
:
dyn_sym
,
"type"
:
self
.
_lookup_type
(
dyn_sym_dtype
)})
function_args
.
append
({
"name"
:
dyn_sym
,
"type"
:
self
.
_lookup_type
(
dyn_sym_dtype
)})
...
@@ -686,7 +715,6 @@ class TLCPUSourceWrapper:
...
@@ -686,7 +715,6 @@ class TLCPUSourceWrapper:
_call_str
=
""""""
_call_str
=
""""""
for
function_name
,
_
in
function_informations
.
items
():
for
function_name
,
_
in
function_informations
.
items
():
# Find the location of the global kernel function in the code
# Find the location of the global kernel function in the code
index
=
match_declare_kernel_cpu
(
-                                         code, function_name + "(")
+        index = match_declare_kernel_cpu(code, function_name + "(")
...
@@ -706,8 +734,8 @@ class TLCPUSourceWrapper:
     def parse_source_information(self):
         with tvm.transform.PassContext(opt_level=3, config=self.pass_configs):
             device_mod, host_mod = get_annotated_mod(self.mod, self.target)
-            assert (len(device_mod.functions) >= 1), "Device module should have at least one function."
-            assert (len(host_mod.functions) == 1), "Only support one function in host module."
+            assert len(device_mod.functions) >= 1, "Device module should have at least one function."
+            assert len(host_mod.functions) == 1, "Only support one function in host module."
             function_names = []
             for g_var, _ in device_mod.functions.items():
...
@@ -767,14 +795,15 @@ class TLCPUSourceWrapper:
 class TLMetalSourceWrapper:
     def __init__(
         self,
         scheduled_ir_module: IRModule,
         source: str,
         target: Target,
         device_mod: IRModule | None = None,
         host_mod: IRModule | None = None,
-        pass_configs: dict[str, Any] | None = None):
+        pass_configs: dict[str, Any] | None = None,
+    ):
         self.mod = scheduled_ir_module
         self.target = target
         self.source = source
...
@@ -792,6 +821,7 @@ class TLWrapper(BaseWrapper):
     """
     A wrapper class for the TileLang backend.
     """
     device_mod: IRModule | None = None
     host_mod: IRModule | None = None
     pass_configs: dict[str, Any] | None = None
...
@@ -836,12 +866,12 @@ class TLWrapper(BaseWrapper):
             target=self.target,
             device_mod=self.device_mod,
             host_mod=self.host_mod,
-            pass_configs=self.pass_configs)
+            pass_configs=self.pass_configs,
+        )
         return wrapper.lib_code


 class TLPyWrapper(TLWrapper):
     def __init__(self, target: Target):
         super().__init__(target)
...
@@ -849,6 +879,7 @@ class TLPyWrapper(TLWrapper):
         # assert self.scheduled_ir_module is not None, "Please assign optimized module first."
         if is_cuda_target(self.target):
             from tilelang.jit.adapter.nvrtc import TLNVRTCSourceWrapper
             wrapper_class = TLNVRTCSourceWrapper
         else:
             raise ValueError(f"Unsupported target for NVRTC backend: {self.target}")
...
@@ -858,5 +889,6 @@ class TLPyWrapper(TLWrapper):
             target=self.target,
             device_mod=self.device_mod,
             host_mod=self.host_mod,
-            pass_configs=self.pass_configs)
+            pass_configs=self.pass_configs,
+        )
         return wrapper.host_func, wrapper.function_names
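For orientation, the constructor shown in the TLMetalSourceWrapper hunk keeps a plain keyword interface after the reformat. Below is a minimal sketch of driving such a source wrapper directly; note that these wrappers are normally instantiated by TLWrapper/TLPyWrapper internally, the import path and the `lib_code` attribute are assumptions based on the surrounding hunks, and the IRModule/source/target values would come from an earlier TileLang lowering step.

from typing import Any

from tvm.ir import IRModule
from tvm.target import Target


# Hypothetical helper, for illustration only.
def wrap_metal_source(scheduled_ir_module: IRModule, source: str, target: Target,
                      pass_configs: dict[str, Any] | None = None) -> str:
    # Import path assumed from the hunks above; adjust if the module layout differs.
    from tilelang.jit.adapter.wrapper import TLMetalSourceWrapper

    wrapper = TLMetalSourceWrapper(
        scheduled_ir_module=scheduled_ir_module,
        source=source,
        target=target,
        device_mod=None,
        host_mod=None,
        pass_configs=pass_configs,
    )
    # `lib_code` is what TLWrapper.wrap() returns for a source wrapper (assumed attribute name).
    return wrapper.lib_code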
tilelang/jit/execution_backend.py View file @ 29051439
...
@@ -46,6 +46,7 @@ def allowed_backends_for_target(target: Target, *, include_unavailable: bool = T
     # Drop NVRTC if not importable
     try:
         from tilelang.jit.adapter.nvrtc import is_nvrtc_available  # lazy
         if not is_nvrtc_available and "nvrtc" in allowed:
             allowed = [b for b in allowed if b != "nvrtc"]
     except Exception:
...
@@ -89,12 +90,14 @@ def resolve_execution_backend(requested: str | None, target: Target) -> str:
     if req not in allowed_all:
         raise ValueError(
             f"Invalid execution backend '{requested}' for target '{_target_kind(target)}'. "
-            f"Allowed: {_format_options(allowed_all)}. Tip: use execution_backend='auto'.")
+            f"Allowed: {_format_options(allowed_all)}. Tip: use execution_backend='auto'."
+        )
     # Promote to availability-aware set for nicer errors (e.g., nvrtc not installed)
     if req not in allowed_avail:
         raise ValueError(
             f"Execution backend '{requested}' requires extra dependencies and is not available now. "
-            f"Try one of: {_format_options(allowed_avail)}.")
+            f"Try one of: {_format_options(allowed_avail)}."
+        )
     return req
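The messages interpolated above are what callers see when a backend request cannot be honored. A hedged sketch of exercising this from user code follows; the `"cuda"` target is only an example, and the exact behavior of the "auto" request is handled by code above this hunk and is not shown here.

from tvm.target import Target

from tilelang.jit.execution_backend import resolve_execution_backend

target = Target("cuda")  # assumed example target

# A valid request resolves to a concrete backend name; "auto" lets the module pick one
# appropriate for the target (per the surrounding code, not shown in this hunk).
backend = resolve_execution_backend("auto", target)
print(backend)

# An invalid or unavailable request raises ValueError with the hints formatted above.
try:
    resolve_execution_backend("not-a-backend", target)
except ValueError as err:
    print(err)  # e.g. "Invalid execution backend 'not-a-backend' for target 'cuda'. Allowed: ..."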
tilelang/jit/kernel.py View file @ 29051439
 from __future__ import annotations

 from typing import Any, Callable, Generic, Literal, TypeVar

 # Python 3.9 compatibility for ParamSpec
 try:
     from typing import ParamSpec
...
@@ -14,8 +15,7 @@ import tilelang
 from tilelang import tvm
 from tilelang import env
 from tilelang.engine.param import CompiledArtifact, KernelParam
-from tilelang.jit.adapter import (BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter,
-                                  TVMFFIKernelAdapter, MetalKernelAdapter)
+from tilelang.jit.adapter import BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter, TVMFFIKernelAdapter, MetalKernelAdapter
 from tilelang.profiler import Profiler, TensorSupplyType
 from tilelang.utils.target import determine_target
 from tilelang.contrib import nvcc as tl_nvcc
...
@@ -24,8 +24,8 @@ import os
 logger = logging.getLogger(__name__)

-_P = ParamSpec('_P')
-_T = TypeVar('_T')
+_P = ParamSpec("_P")
+_T = TypeVar("_T")


 class JITKernel(Generic[_P, _T]):
...
@@ -41,6 +41,7 @@ class JITKernel(Generic[_P, _T]):
     torch_function : Callable
         The compiled function that can be invoked as a PyTorch-compatible function.
     """
     prim_func: PrimFunc = None
     artifact: CompiledArtifact = None
     adapter: BaseKernelAdapter = None
...
@@ -111,9 +112,7 @@ class JITKernel(Generic[_P, _T]):
         if execution_backend == "cython":
             from tilelang.contrib.cc import get_cplus_compiler
-            assert (
-                get_cplus_compiler() is not None
-            ), "Cython backend requires a C++ compiler, please install or use other backends."
+            assert get_cplus_compiler() is not None, "Cython backend requires a C++ compiler, please install or use other backends."
         if from_database:
             return
...
@@ -200,8 +199,7 @@ class JITKernel(Generic[_P, _T]):
         """
         return self.torch_function(*args, **kwds)

     def _compile_and_create_adapter(self, tilelang_func: PrimFunc,
                                     out_idx: list[int]) -> BaseKernelAdapter:
         """
         Compiles the given TileLang PrimFunc using TVM and creates a kernel adapter.
...
@@ -233,7 +231,8 @@ class JITKernel(Generic[_P, _T]):
             target=target,
             target_host=target_host,
             enable_host_codegen=enable_host_codegen,
-            enable_device_compile=enable_device_compile)
+            enable_device_compile=enable_device_compile,
+        )
         self.artifact = artifact
...
@@ -241,7 +240,7 @@ class JITKernel(Generic[_P, _T]):
         if execution_backend == "tvm_ffi":
             # Use TVMFFIKernelAdapter for interoperability with PyTorch via DLPack.
             # But we need to ensure that the runtime is enabled and the runtime module is not None.
-            assert (artifact.rt_mod is not None), "tvm_ffi backend requires a runtime module."
+            assert artifact.rt_mod is not None, "tvm_ffi backend requires a runtime module."
             adapter = TVMFFIKernelAdapter(
                 params=artifact.params,
                 result_idx=out_idx,
...
@@ -283,6 +282,7 @@ class JITKernel(Generic[_P, _T]):
             )
         elif execution_backend == "nvrtc":
             from tilelang.jit.adapter import NVRTCKernelAdapter
             adapter = NVRTCKernelAdapter(
                 params=artifact.params,
                 result_idx=out_idx,
...
@@ -315,16 +315,18 @@ class JITKernel(Generic[_P, _T]):
         return adapter

-    def _create_adapter_from_database(self,
-                                      params: list[KernelParam],
-                                      result_idx: list[int] | int,
-                                      target: str | Target,
-                                      func_or_mod: PrimFunc | tvm.runtime.Module,
-                                      host_kernel_source: str,
-                                      device_kernel_source: str,
-                                      kernel_lib_path: str,
-                                      pass_configs: dict[str, Any] | None = None,
-                                      compile_flags: list[str] | None = None) -> BaseKernelAdapter:
+    def _create_adapter_from_database(
+        self,
+        params: list[KernelParam],
+        result_idx: list[int] | int,
+        target: str | Target,
+        func_or_mod: PrimFunc | tvm.runtime.Module,
+        host_kernel_source: str,
+        device_kernel_source: str,
+        kernel_lib_path: str,
+        pass_configs: dict[str, Any] | None = None,
+        compile_flags: list[str] | None = None,
+    ) -> BaseKernelAdapter:
         target = self.target
         execution_backend = self.execution_backend
...
@@ -366,6 +368,7 @@ class JITKernel(Generic[_P, _T]):
             )
         elif execution_backend == "nvrtc":
             from tilelang.jit.adapter import NVRTCKernelAdapter
             adapter = NVRTCKernelAdapter.from_database(
                 params=params,
                 result_idx=result_idx,
...
@@ -402,8 +405,7 @@ class JITKernel(Generic[_P, _T]):
         """
         return cls(func=tilelang_func, **kwargs)

     def get_profiler(self,
                      tensor_supply_type: TensorSupplyType = TensorSupplyType.Auto) -> Profiler:
         """
         Creates a profiler to benchmark the compiled runtime module.
...
@@ -417,8 +419,7 @@ class JITKernel(Generic[_P, _T]):
         Profiler
             A Profiler instance for benchmarking the runtime module.
         """
         return Profiler(self.params, self.out_idx,
                         tensor_supply_type).with_default_adapter(self.adapter)

     def get_kernel_source(self, kernel_only: bool = True) -> str:
         """
...
@@ -507,21 +508,19 @@ class JITKernel(Generic[_P, _T]):
                 dir_path = os.path.dirname(kernel_path)
                 if dir_path:
                     os.makedirs(dir_path, exist_ok=True)
-                with open(kernel_path, 'w') as f:
+                with open(kernel_path, "w") as f:
                     f.write(self.get_kernel_source())
             if host_path is not None:
                 dir_path = os.path.dirname(host_path)
                 if dir_path:
                     os.makedirs(dir_path, exist_ok=True)
-                with open(host_path, 'w') as f:
+                with open(host_path, "w") as f:
                     f.write(self.get_host_source())
         except Exception as e:
             logger.error(f"Failed to export sources: {e}")

     # Backward compatibility alias (deprecated)
     def print_source_code(self,
                           which: Literal["kernel", "host", "both"] = "kernel", file: str | None = None) -> None:
         """
         Deprecated: use show_source() or export_sources() instead.
...
@@ -541,16 +540,14 @@ class JITKernel(Generic[_P, _T]):
         >>> # Old API (still works but deprecated)
         >>> jit_kernel.print_source_code(file="/tmp/kernel.cu")
         """
         logger.warning(
             "print_source_code is deprecated; use show_source() or export_sources() instead.")
         if file is not None:
             # Historical behavior wrote only kernel source when file provided
             self.export_sources(kernel_path=file)
         else:
             self.show_source(which=which)

     def update_tuner_result(self, latency: float, config: dict[str, Any],
                             ref_latency: float) -> JITKernel:
         """
         Updates the tuning results for this kernel.
...
@@ -651,8 +648,7 @@ class JITKernel(Generic[_P, _T]):
             verbose = self.verbose
         # Ensure target is set so nvcc picks correct arch via Target.current()
         with self.target:
             return tl_nvcc.get_ptx_from_source(
                 code, compile_flags=self.compile_flags, verbose=verbose)

     def show_ptx(self) -> None:
         """
...
@@ -714,8 +710,7 @@ class JITKernel(Generic[_P, _T]):
         if verbose is None:
             verbose = self.verbose
         with self.target:
             return tl_nvcc.get_sass_from_source(
                 code, compile_flags=self.compile_flags, verbose=verbose)

     def show_sass(self) -> None:
         """
...
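Beyond the formatting churn, the kernel.py hunks expose the JITKernel inspection surface. A hedged sketch of using it is below; `jit_kernel` is assumed to be an already-compiled JITKernel instance, the file paths are illustrative, and only methods visible in the hunks above are called.

# Sketch only: `jit_kernel` is assumed to be an existing, already-compiled JITKernel.
def inspect_kernel(jit_kernel) -> None:
    # Export the generated kernel/host sources; parent directories are created on demand.
    jit_kernel.export_sources(kernel_path="/tmp/kernel.cu", host_path="/tmp/host.cpp")

    # Deprecated alias kept for backward compatibility: logs a warning, then falls back
    # to export_sources() when a file is given, otherwise to show_source().
    jit_kernel.print_source_code(file="/tmp/kernel_legacy.cu")

    # Returns a Profiler wired to this kernel's params, out_idx and adapter.
    profiler = jit_kernel.get_profiler()
    print(type(profiler))

    # On CUDA targets, these dump the compiled PTX/SASS derived from the kernel source.
    jit_kernel.show_ptx()
    jit_kernel.show_sass()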
tilelang/language/__init__.py View file @ 29051439
 """The language interface for tl programs."""

 from __future__ import annotations

 # from .parser import *
...
@@ -102,7 +103,10 @@ from .utils import index_to_coordinates # noqa: F401
 from .symbolics import dynamic, symbolic  # noqa: F401
 from .annotations import (  # noqa: F401
-    use_swizzle, annotate_layout, annotate_safe_value, annotate_l2_hit_ratio,
+    use_swizzle,
+    annotate_layout,
+    annotate_safe_value,
+    annotate_l2_hit_ratio,
 )
...
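Because the annotation helpers are re-exported at the package root, user imports are unaffected by the reformat. A minimal sketch (the `T` alias is a common convention, assumed here):

# The helpers re-exported above remain importable from the package root.
from tilelang.language import annotate_layout, annotate_l2_hit_ratio, use_swizzle  # noqa: F401
import tilelang.language as T

# e.g. T.use_swizzle(...) or T.annotate_layout(...) inside a TileLang kernel body (usage assumed).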
tilelang/language/allocate.py View file @ 29051439
...
@@ -13,8 +13,10 @@ Available allocation functions:
 Each function takes shape and dtype parameters and returns a TVM buffer object
 with the appropriate memory scope.
 """
 from __future__ import annotations

 from typing import TypeVar, overload, Literal, Callable

 # Python 3.9 compatibility for advanced typing features (PEP 646)
 try:
     from typing import TypeVarTuple, Unpack  # type: ignore[attr-defined]
...
@@ -30,13 +32,11 @@ from .v2.dtypes import dtype as tl_dtype
 from .v2.builder import OutTensor
 from .v2.annot import Tensor, SharedBuffer, LocalBuffer, FragmentBuffer

-_Shapes = TypeVarTuple('_Shapes')
-_DType = TypeVar('_DType')
+_Shapes = TypeVarTuple("_Shapes")
+_DType = TypeVar("_DType")


 def alloc_shared(shape: tuple[Unpack[_Shapes]],
                  dtype: _DType, scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]:
     """Allocate a shared memory buffer for inter-thread communication.

     Args:
...
@@ -54,9 +54,7 @@ def alloc_shared(shape: tuple[Unpack[_Shapes]],
     return T.alloc_buffer(shape, dtype, scope=scope)


 def alloc_local(shape: tuple[Unpack[_Shapes]],
                 dtype: _DType, scope="local") -> LocalBuffer[Callable[[Unpack[_Shapes]]], _DType]:
     """Allocate a local memory buffer for thread-private storage.

     Args:
...
@@ -70,9 +68,9 @@ def alloc_local(shape: tuple[Unpack[_Shapes]],
     return T.alloc_buffer(shape, dtype, scope=scope)


-def alloc_fragment(shape: tuple[Unpack[_Shapes]],
-                   dtype: _DType,
-                   scope="local.fragment"
-                   ) -> FragmentBuffer[Callable[[Unpack[_Shapes]]], _DType]:
+def alloc_fragment(
+    shape: tuple[Unpack[_Shapes]],
+    dtype: _DType,
+    scope="local.fragment"
+) -> FragmentBuffer[Callable[[Unpack[_Shapes]]], _DType]:
     """Allocate a fragment memory buffer for specialized operations.

     Args:
...
@@ -87,16 +85,11 @@ def alloc_fragment(shape: tuple[Unpack[_Shapes]],
 @overload
-def alloc_var(dtype: str, init: PrimExpr | int | float, scope: str = 'local.var') -> Buffer:
-    ...
+def alloc_var(dtype: str, init: PrimExpr | int | float, scope: str = "local.var") -> Buffer: ...


 @overload
-def alloc_var(dtype: str,
-              scope: str = 'local.var', *, init: PrimExpr | int | float | None = None) -> Buffer:
-    ...
+def alloc_var(dtype: str, scope: str = "local.var", *, init: PrimExpr | int | float | None = None) -> Buffer: ...


 def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
...
@@ -142,8 +135,7 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
             raise TypeError("Scope must be provided as a string in alloc_var.")
         parsed_scope = parsed_scope_arg
     elif len(args) > 2:
         raise TypeError(
             f"alloc_var expected at most 3 positional arguments but got {len(args) + 1}.")
     if not isinstance(parsed_scope, str):
         raise TypeError("Scope must be a string in alloc_var.")
...
@@ -274,13 +266,10 @@ def alloc_tcgen05_instr_desc(dtype: str = "uint32"):
 @overload
-def empty(shape: tuple[Unpack[_Shapes]],
-          dtype: str = 'float32') -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
-    ...
+def empty(shape: tuple[Unpack[_Shapes]], dtype: str = "float32") -> Tensor[Callable[[Unpack[_Shapes]]], _DType]: ...


-def empty(*shape: Unpack[_Shapes],
-          dtype: str = 'float32') -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
+def empty(*shape: Unpack[_Shapes], dtype: str = "float32") -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
     if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
         return OutTensor(shape[0], dtype)
     elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str):
...
@@ -288,4 +277,4 @@ def empty(*shape: Unpack[_Shapes],
     elif all([isinstance(x, (int, PrimExpr)) for x in shape]):
         return OutTensor(shape, dtype)
     else:
-        raise RuntimeError(f'Invalid shape {shape}')
+        raise RuntimeError(f"Invalid shape {shape}")
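These signatures are what kernel authors actually call. The sketch below shows the allocation helpers as they would appear inside a TileLang kernel body; it assumes the helpers are re-exported on `tilelang.language` (a common convention), and the surrounding kernel scaffolding is omitted, so the function is not runnable on its own outside a kernel-building context.

import tilelang.language as T


# Sketch only: allocation calls as they would appear inside a TileLang kernel body.
def _allocations_sketch():
    A_shared = T.alloc_shared((128, 32), "float16")   # shared.dyn scope by default
    acc = T.alloc_fragment((128, 128), "float32")     # register fragment (local.fragment scope)
    tmp = T.alloc_local((8,), "float32")              # thread-private scratch (local scope)
    flag = T.alloc_var("int32", init=0)               # scalar variable with an initializer
    out = T.empty((128, 128), "float16")              # declares an output tensor (OutTensor)
    return A_shared, acc, tmp, flag, out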
tilelang/language/annotations.py View file @ 29051439
 """Annotation helpers exposed on the TileLang language surface."""

 from typing import Callable

 from tilelang.layout import Layout
...
tilelang/language/ast/__init__.py View file @ 29051439
...
@@ -17,6 +17,7 @@
 # This file is modified from the original version,
 # which is part of the TVM project (https://tvm.apache.org/).
 """Package tvm.script.ir_builder.tir"""
 from .ir import *  # noqa: F401
 from .ir import boolean as bool  # noqa: F401
 from .ir import buffer as Buffer  # noqa: F401
...