Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
635ccda6
"docs/vscode:/vscode.git/clone" did not exist on "229dde7eb5f8a1bf054ee0d3bc711744f0b34c0b"
Unverified
Commit
635ccda6
authored
Sep 21, 2025
by
Lifu Huang
Committed by
GitHub
Sep 21, 2025
Browse files
[4/4] Introduce CachedKernel to reduce CSGMV kernel launch overheads by 60% (#10709)
parent
1c3dbad8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
118 additions
and
0 deletions
+118
-0
python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py
python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py
+2
-0
python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py
python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py
+2
-0
python/sglang/utils.py
python/sglang/utils.py
+114
-0
No files found.
python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py
View file @
635ccda6
...
...
@@ -5,8 +5,10 @@ import triton
import
triton.language
as
tl
from
sglang.srt.lora.utils
import
LoRABatchInfo
from
sglang.utils
import
cached_triton_kernel
@
cached_triton_kernel
(
lambda
_
,
kwargs
:
(
kwargs
[
"NUM_SLICES"
],
kwargs
[
"BLOCK_M"
]))
@
triton
.
jit
def
_chunked_lora_expand_kernel
(
# Pointers to matrices
...
...
python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py
View file @
635ccda6
...
...
@@ -3,8 +3,10 @@ import triton
import
triton.language
as
tl
from
sglang.srt.lora.utils
import
LoRABatchInfo
from
sglang.utils
import
cached_triton_kernel
@
cached_triton_kernel
(
lambda
_
,
kwargs
:
(
kwargs
[
"NUM_SLICES"
],
kwargs
[
"BLOCK_M"
]))
@
triton
.
jit
def
_chunked_lora_shrink_kernel
(
# Pointers to matrices
...
...
python/sglang/utils.py
View file @
635ccda6
"""Common utilities"""
import
functools
import
importlib
import
inspect
import
json
import
logging
import
os
...
...
@@ -21,6 +23,7 @@ from typing import Any, Callable, List, Optional, Tuple, Type, Union
import
numpy
as
np
import
pybase64
import
requests
import
triton
from
IPython.display
import
HTML
,
display
from
pydantic
import
BaseModel
from
tqdm
import
tqdm
...
...
@@ -540,3 +543,114 @@ def resolve_obj_by_qualname(qualname: str) -> Any:
module_name
,
obj_name
=
qualname
.
rsplit
(
"."
,
1
)
module
=
importlib
.
import_module
(
module_name
)
return
getattr
(
module
,
obj_name
)
class CachedKernel:
    """
    Wrapper that allows kernel[grid](...) syntax with caching based on a key function.

    This wrapper caches compiled Triton kernels based on keys extracted by a
    user-provided key function to avoid redundant compilations.

    The first launch for a given cache key goes through the wrapped JIT function
    and stores whatever that call returns. Later launches with the same key skip
    the JIT function and invoke the stored object directly with a fully
    positional argument list.

    Args:
        fn: The ``triton.runtime.jit.JITFunction`` to wrap.
        key_fn: Callable taking ``(args, kwargs)`` of a launch and returning a
            hashable cache key. Although the parameter defaults to ``None`` for
            signature compatibility, a callable is required.

    Raises:
        TypeError: If ``key_fn`` is not callable.
        AssertionError: If ``fn`` is not a JITFunction, or if any kernel
            parameter declares a default value.
    """

    def __init__(self, fn, key_fn=None):
        # Fail fast: the launcher calls key_fn on every launch, so a missing key
        # function would otherwise only surface at the first launch as an opaque
        # "'NoneType' object is not callable" error.
        if not callable(key_fn):
            raise TypeError(
                "CachedKernel requires a callable key_fn(args, kwargs) that "
                f"returns a hashable cache key, got {key_fn!r}."
            )

        self.fn = fn
        assert isinstance(fn, triton.runtime.jit.JITFunction)

        # Introspect the undecorated Python function to learn parameter order.
        original_fn = fn.fn
        self.signature = inspect.signature(original_fn)
        self.param_names = tuple(self.signature.parameters.keys())
        self.num_args = len(self.param_names)

        # Check that no parameters have default values: _build_args only fills
        # parameters from the caller's args/kwargs and never applies defaults.
        for name, param in self.signature.parameters.items():
            assert (
                param.default is inspect.Parameter.empty
            ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."

        functools.update_wrapper(self, original_fn)
        # Maps key_fn(args, kwargs) -> object returned by the first launch.
        self.kernel_cache = {}

        # Store the key function
        self.key_fn = key_fn

    def __getitem__(self, grid):
        """
        Index with grid to get a launcher function.

        Args:
            grid: Tuple of at most 3 launch dimensions; shorter tuples are
                padded with 1 up to length 3.

        Returns:
            A ``launcher(*args, **kwargs)`` callable that launches the kernel,
            caching per ``key_fn(args, kwargs)``, and returns the cached object.
        """
        assert (
            isinstance(grid, tuple) and len(grid) <= 3
        ), "Grid must be a tuple with at most 3 dimensions."

        # Normalize grid once
        if len(grid) < 3:
            grid = grid + (1,) * (3 - len(grid))

        def launcher(*args, **kwargs):
            cache_key = self.key_fn(args, kwargs)
            cached_kernel = self.kernel_cache.get(cache_key)

            if cached_kernel is None:
                # First time: compile and cache the kernel. The JIT call itself
                # performs this launch, so nothing else needs to run here.
                cached_kernel = self.fn[grid](*args, **kwargs)
                self.kernel_cache[cache_key] = cached_kernel
                return cached_kernel

            # Use cached kernel: it is invoked positionally, so keyword
            # arguments must first be placed in signature order.
            all_args = self._build_args(args, kwargs)
            cached_kernel[grid](*all_args)
            return cached_kernel

        return launcher

    def _build_args(self, args, kwargs):
        """
        Build the complete argument list for kernel invocation.

        Positional ``args`` are kept as given; every remaining parameter, in
        signature order, is looked up by name in ``kwargs``. Keys in ``kwargs``
        that are not kernel parameters are ignored.

        Returns:
            A list with one value per kernel parameter, in signature order.

        Raises:
            ValueError: If a parameter not covered by ``args`` is missing from
                ``kwargs``.
        """
        complete_args = list(args)
        for name in self.param_names[len(args) :]:
            try:
                complete_args.append(kwargs[name])
            except KeyError:
                raise ValueError(f"Missing argument: {name}") from None
        return complete_args
def cached_triton_kernel(key_fn=None):
    """
    Decorator factory that adds key-based caching to a Triton kernel.

    The decorated kernel is wrapped in a CachedKernel, which sidesteps Triton's
    built-in caching mechanism in favour of a caller-defined strategy based on
    kernel parameters. This helps reduce the heavy overheads of Triton kernel
    launch when the kernel specialization dispatch is simple.

    Usage:
        @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
        @triton.jit
        def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
            ...

        # Invoke normally
        my_kernel[grid](x, y, BLOCK_SIZE=1024)

    Args:
        key_fn: A function that takes (args, kwargs) and returns the cache key(s).
            The key can be a single value or a tuple of values. The wrapper calls
            it on every launch, so a callable must be supplied even though the
            parameter defaults to None.

    Returns:
        A decorator that wraps the kernel with caching functionality.

    Note: Kernels with default parameter values are not supported and will raise an assertion error.
    """

    def wrap_with_cache(kernel):
        # Hand the jitted kernel and the captured key function to the wrapper.
        cached = CachedKernel(kernel, key_fn)
        return cached

    return wrap_with_cache
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment