Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
9241f4fd
Unverified
Commit
9241f4fd
authored
Sep 22, 2025
by
Lifu Huang
Committed by
GitHub
Sep 22, 2025
Browse files
Move cached kernel to srt.utils (#10776)
parent
063c3791
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
120 additions
and
122 deletions
+120
-122
python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py
python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py
+1
-1
python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py
python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py
+1
-1
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+118
-0
python/sglang/utils.py
python/sglang/utils.py
+0
-120
No files found.
python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py
View file @
9241f4fd
...
@@ -5,7 +5,7 @@ import triton
...
@@ -5,7 +5,7 @@ import triton
import
triton.language
as
tl
import
triton.language
as
tl
from
sglang.srt.lora.utils
import
LoRABatchInfo
from
sglang.srt.lora.utils
import
LoRABatchInfo
from
sglang.utils
import
cached_triton_kernel
from
sglang.
srt.
utils
import
cached_triton_kernel
@
cached_triton_kernel
(
lambda
_
,
kwargs
:
(
kwargs
[
"NUM_SLICES"
],
kwargs
[
"BLOCK_M"
]))
@
cached_triton_kernel
(
lambda
_
,
kwargs
:
(
kwargs
[
"NUM_SLICES"
],
kwargs
[
"BLOCK_M"
]))
...
...
python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py
View file @
9241f4fd
...
@@ -3,7 +3,7 @@ import triton
...
@@ -3,7 +3,7 @@ import triton
import
triton.language
as
tl
import
triton.language
as
tl
from
sglang.srt.lora.utils
import
LoRABatchInfo
from
sglang.srt.lora.utils
import
LoRABatchInfo
from
sglang.utils
import
cached_triton_kernel
from
sglang.
srt.
utils
import
cached_triton_kernel
@
cached_triton_kernel
(
lambda
_
,
kwargs
:
(
kwargs
[
"NUM_SLICES"
],
kwargs
[
"BLOCK_M"
]))
@
cached_triton_kernel
(
lambda
_
,
kwargs
:
(
kwargs
[
"NUM_SLICES"
],
kwargs
[
"BLOCK_M"
]))
...
...
python/sglang/srt/utils.py
View file @
9241f4fd
...
@@ -22,6 +22,7 @@ import ctypes
...
@@ -22,6 +22,7 @@ import ctypes
import
dataclasses
import
dataclasses
import
functools
import
functools
import
importlib
import
importlib
import
inspect
import
io
import
io
import
ipaddress
import
ipaddress
import
itertools
import
itertools
...
@@ -3224,3 +3225,120 @@ def get_extend_input_len_swa_limit(
...
@@ -3224,3 +3225,120 @@ def get_extend_input_len_swa_limit(
# and we can only free out-of-sliding-window kv indices after each prefill.
# and we can only free out-of-sliding-window kv indices after each prefill.
# 3. page_size is because we want to have 1 token extra for generated tokens.
# 3. page_size is because we want to have 1 token extra for generated tokens.
return
page_size
+
2
*
max
(
sliding_window_size
,
chunked_prefill_size
)
return
page_size
+
2
*
max
(
sliding_window_size
,
chunked_prefill_size
)
class CachedKernel:
    """
    Wrapper that allows kernel[grid](...) syntax with caching based on a key function.

    Whatever ``fn[grid](*args, **kwargs)`` returns on the first launch for a
    given key is stored in ``kernel_cache`` under ``key_fn(args, kwargs)``.
    Later launches with the same key skip the JIT function and invoke the
    stored object directly as ``cached[grid](*positional_args)``.

    When no key function is supplied there is nothing to key the cache on, so
    caching is disabled and every launch is delegated to the wrapped JIT
    function unchanged.
    """

    def __init__(self, fn, key_fn=None):
        """
        Wrap a Triton JIT function.

        Args:
            fn: The ``triton.runtime.jit.JITFunction`` to wrap. None of the
                parameters of the underlying Python function may declare a
                default value.
            key_fn: Optional callable ``key_fn(args, kwargs)`` returning a
                hashable cache key. ``None`` disables caching.

        Raises:
            AssertionError: If ``fn`` is not a ``JITFunction`` or one of its
                parameters has a default value.
        """
        self.fn = fn
        assert isinstance(fn, triton.runtime.jit.JITFunction)

        original_fn = fn.fn
        self.signature = inspect.signature(original_fn)
        self.param_names = tuple(self.signature.parameters.keys())
        self.num_args = len(self.param_names)

        # Check that no parameters have default values: the cached launch path
        # rebuilds a purely positional argument list from the caller's
        # args/kwargs and has no way to fill in an omitted argument.
        for name, param in self.signature.parameters.items():
            assert (
                param.default is inspect.Parameter.empty
            ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."

        functools.update_wrapper(self, original_fn)
        self.kernel_cache = {}

        # Store the key function (may be None, see __getitem__)
        self.key_fn = key_fn

    def __getitem__(self, grid):
        """
        Index with grid to get a launcher function.

        Args:
            grid: Tuple of at most 3 launch dimensions; shorter tuples are
                padded with 1s.

        Returns:
            A callable ``launcher(*args, **kwargs)``. With a key function it
            consults ``kernel_cache`` first; without one it is simply
            ``self.fn[grid]``.

        Raises:
            AssertionError: If ``grid`` is not a tuple of length <= 3.
        """
        assert (
            isinstance(grid, tuple) and len(grid) <= 3
        ), "Grid must be a tuple with at most 3 dimensions."

        # Normalize grid once so every path below sees a 3-tuple
        if len(grid) < 3:
            grid = grid + (1,) * (3 - len(grid))

        if self.key_fn is None:
            # No key function: calling it would raise TypeError, so fall back
            # to the wrapped JIT function's own launcher (no caching).
            return self.fn[grid]

        def launcher(*args, **kwargs):
            cache_key = self.key_fn(args, kwargs)
            cached_kernel = self.kernel_cache.get(cache_key)

            if cached_kernel is None:
                # First time: launch through the JIT function and remember
                # what it returns for this key.
                cached_kernel = self.fn[grid](*args, **kwargs)
                self.kernel_cache[cache_key] = cached_kernel
            else:
                # Use cached kernel; it is invoked with positional arguments only
                cached_kernel[grid](*self._build_args(args, kwargs))

            return cached_kernel

        return launcher

    def _build_args(self, args, kwargs):
        """
        Build the complete argument list for kernel invocation.

        Positional ``args`` are kept as given; every remaining parameter, in
        signature order, is looked up by name in ``kwargs``. Keyword arguments
        that do not name a kernel parameter are ignored.

        Raises:
            ValueError: If a remaining parameter is absent from ``kwargs``.
        """
        complete_args = list(args)

        for name in self.param_names[len(args) : self.num_args]:
            if name not in kwargs:
                raise ValueError(f"Missing argument: {name}")
            complete_args.append(kwargs[name])

        return complete_args

    def _clear_cache(self):
        """
        Clear the kernel cache for testing purposes.
        """
        self.kernel_cache.clear()
def cached_triton_kernel(key_fn=None):
    """
    Build a decorator that wraps a Triton kernel in a CachedKernel.

    The wrapped kernel is cached by the key that ``key_fn`` computes from the
    launch arguments, which bypasses Triton's built-in caching mechanism and
    lets the caller choose the caching strategy. This reduces the heavy
    overhead of a Triton kernel launch when the kernel specialization
    dispatch is simple.

    Usage:
        @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
        @triton.jit
        def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
            ...

        # Invoke normally
        my_kernel[grid](x, y, BLOCK_SIZE=1024)

    Args:
        key_fn: A function that takes (args, kwargs) and returns the cache key(s).
                The key can be a single value or a tuple of values. It is
                passed through to CachedKernel unchanged.

    Returns:
        A decorator that takes the kernel and returns ``CachedKernel(kernel, key_fn)``.

    Note: Kernels with default parameter values are not supported and will raise an assertion error.
    """

    def wrap_kernel(jit_fn):
        # key_fn is captured from the enclosing call
        wrapped = CachedKernel(jit_fn, key_fn)
        return wrapped

    return wrap_kernel
python/sglang/utils.py
View file @
9241f4fd
"""Common utilities"""
"""Common utilities"""
import
functools
import
importlib
import
importlib
import
inspect
import
json
import
json
import
logging
import
logging
import
os
import
os
...
@@ -24,7 +22,6 @@ from typing import Any, Callable, List, Optional, Tuple, Type, Union
...
@@ -24,7 +22,6 @@ from typing import Any, Callable, List, Optional, Tuple, Type, Union
import
numpy
as
np
import
numpy
as
np
import
pybase64
import
pybase64
import
requests
import
requests
import
triton
from
IPython.display
import
HTML
,
display
from
IPython.display
import
HTML
,
display
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
@@ -552,120 +549,3 @@ def resolve_obj_by_qualname(qualname: str) -> Any:
...
@@ -552,120 +549,3 @@ def resolve_obj_by_qualname(qualname: str) -> Any:
module_name
,
obj_name
=
qualname
.
rsplit
(
"."
,
1
)
module_name
,
obj_name
=
qualname
.
rsplit
(
"."
,
1
)
module
=
importlib
.
import_module
(
module_name
)
module
=
importlib
.
import_module
(
module_name
)
return
getattr
(
module
,
obj_name
)
return
getattr
(
module
,
obj_name
)
class CachedKernel:
    """
    Wrapper that allows kernel[grid](...) syntax with caching based on a key function.

    Whatever ``fn[grid](*args, **kwargs)`` returns on the first launch for a
    given key is stored in ``kernel_cache`` under ``key_fn(args, kwargs)``.
    Later launches with the same key skip the JIT function and invoke the
    stored object directly as ``cached[grid](*positional_args)``.

    When no key function is supplied there is nothing to key the cache on, so
    caching is disabled and every launch is delegated to the wrapped JIT
    function unchanged.
    """

    def __init__(self, fn, key_fn=None):
        """
        Wrap a Triton JIT function.

        Args:
            fn: The ``triton.runtime.jit.JITFunction`` to wrap. None of the
                parameters of the underlying Python function may declare a
                default value.
            key_fn: Optional callable ``key_fn(args, kwargs)`` returning a
                hashable cache key. ``None`` disables caching.

        Raises:
            AssertionError: If ``fn`` is not a ``JITFunction`` or one of its
                parameters has a default value.
        """
        self.fn = fn
        assert isinstance(fn, triton.runtime.jit.JITFunction)

        original_fn = fn.fn
        self.signature = inspect.signature(original_fn)
        self.param_names = tuple(self.signature.parameters.keys())
        self.num_args = len(self.param_names)

        # Check that no parameters have default values: the cached launch path
        # rebuilds a purely positional argument list from the caller's
        # args/kwargs and has no way to fill in an omitted argument.
        for name, param in self.signature.parameters.items():
            assert (
                param.default is inspect.Parameter.empty
            ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."

        functools.update_wrapper(self, original_fn)
        self.kernel_cache = {}

        # Store the key function (may be None, see __getitem__)
        self.key_fn = key_fn

    def __getitem__(self, grid):
        """
        Index with grid to get a launcher function.

        Args:
            grid: Tuple of at most 3 launch dimensions; shorter tuples are
                padded with 1s.

        Returns:
            A callable ``launcher(*args, **kwargs)``. With a key function it
            consults ``kernel_cache`` first; without one it is simply
            ``self.fn[grid]``.

        Raises:
            AssertionError: If ``grid`` is not a tuple of length <= 3.
        """
        assert (
            isinstance(grid, tuple) and len(grid) <= 3
        ), "Grid must be a tuple with at most 3 dimensions."

        # Normalize grid once so every path below sees a 3-tuple
        if len(grid) < 3:
            grid = grid + (1,) * (3 - len(grid))

        if self.key_fn is None:
            # No key function: calling it would raise TypeError, so fall back
            # to the wrapped JIT function's own launcher (no caching).
            return self.fn[grid]

        def launcher(*args, **kwargs):
            cache_key = self.key_fn(args, kwargs)
            cached_kernel = self.kernel_cache.get(cache_key)

            if cached_kernel is None:
                # First time: launch through the JIT function and remember
                # what it returns for this key.
                cached_kernel = self.fn[grid](*args, **kwargs)
                self.kernel_cache[cache_key] = cached_kernel
            else:
                # Use cached kernel; it is invoked with positional arguments only
                cached_kernel[grid](*self._build_args(args, kwargs))

            return cached_kernel

        return launcher

    def _build_args(self, args, kwargs):
        """
        Build the complete argument list for kernel invocation.

        Positional ``args`` are kept as given; every remaining parameter, in
        signature order, is looked up by name in ``kwargs``. Keyword arguments
        that do not name a kernel parameter are ignored.

        Raises:
            ValueError: If a remaining parameter is absent from ``kwargs``.
        """
        complete_args = list(args)

        for name in self.param_names[len(args) : self.num_args]:
            if name not in kwargs:
                raise ValueError(f"Missing argument: {name}")
            complete_args.append(kwargs[name])

        return complete_args

    def _clear_cache(self):
        """
        Clear the kernel cache for testing purposes.
        """
        self.kernel_cache.clear()
def cached_triton_kernel(key_fn=None):
    """
    Build a decorator that wraps a Triton kernel in a CachedKernel.

    The wrapped kernel is cached by the key that ``key_fn`` computes from the
    launch arguments, which bypasses Triton's built-in caching mechanism and
    lets the caller choose the caching strategy. This reduces the heavy
    overhead of a Triton kernel launch when the kernel specialization
    dispatch is simple.

    Usage:
        @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
        @triton.jit
        def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
            ...

        # Invoke normally
        my_kernel[grid](x, y, BLOCK_SIZE=1024)

    Args:
        key_fn: A function that takes (args, kwargs) and returns the cache key(s).
                The key can be a single value or a tuple of values. It is
                passed through to CachedKernel unchanged.

    Returns:
        A decorator that takes the kernel and returns ``CachedKernel(kernel, key_fn)``.

    Note: Kernels with default parameter values are not supported and will raise an assertion error.
    """

    def wrap_kernel(jit_fn):
        # key_fn is captured from the enclosing call
        wrapped = CachedKernel(jit_fn, key_fn)
        return wrapped

    return wrap_kernel
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment