sglang · Commit a29dd950 (Unverified)

Add GemLite caching after each capture (#2669)

Authored by mobicham on Dec 30, 2024; committed by GitHub on Dec 30, 2024
Parent: 9c6ba248
Showing 2 changed files with 21 additions and 3 deletions (+21, -3):

  python/sglang/srt/layers/torchao_utils.py              +17  -3
  python/sglang/srt/model_executor/cuda_graph_runner.py    +4  -0
python/sglang/srt/layers/torchao_utils.py

@@ -11,6 +11,22 @@ import torch

 logger = logging.getLogger(__name__)


+def get_gemlite_cache_path() -> str:
+    return f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
+
+
+def save_gemlite_cache(print_error: bool = False) -> bool:
+    try:
+        from gemlite.core import GemLiteLinearTriton
+
+        GemLiteLinearTriton.cache_config(get_gemlite_cache_path())
+    except Exception:
+        if print_error:
+            logger.error("Failed to save the GemLite cache.")
+        return False
+    return True
+
+
 def apply_torchao_config_to_model(
     model: torch.nn.Module, torchao_config: str, filter_fn=None
 ):
...

@@ -74,9 +90,7 @@ def apply_torchao_config_to_model(
         )
         # try to load gemlite kernel config
-        GemLiteLinearTriton.load_config(
-            f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
-        )
+        GemLiteLinearTriton.load_config(get_gemlite_cache_path())
     elif "fp8wo" in torchao_config:
         # this requires newer hardware
...
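For context, a minimal sketch of how the new helpers are meant to fit together: previously tuned GemLite kernel configs are loaded when the torchao config is applied, and written back out once autotuning has run. Only get_gemlite_cache_path, save_gemlite_cache, and the GemLiteLinearTriton.load_config / cache_config calls come from this commit; the surrounding script is illustrative, not sglang code.

# Illustrative round trip of the GemLite kernel-config cache (assumes GemLite
# is installed and sglang is importable).
from gemlite.core import GemLiteLinearTriton

from sglang.srt.layers.torchao_utils import (
    get_gemlite_cache_path,
    save_gemlite_cache,
)

# On startup: reuse kernel configs tuned in a previous run (this is what
# apply_torchao_config_to_model does for a gemlite torchao_config).
GemLiteLinearTriton.load_config(get_gemlite_cache_path())

# ... warmup / CUDA graph capture runs here and may autotune new shapes ...

# Afterwards: persist the (possibly updated) configs back to /tmp.
if not save_gemlite_cache(print_error=True):
    print("GemLite cache not saved; kernels will be re-autotuned on the next run.")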
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -31,6 +31,7 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessorOutput,
 )
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
+from sglang.srt.layers.torchao_utils import save_gemlite_cache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
...

@@ -276,6 +277,9 @@ class CudaGraphRunner:
                     self.graphs[bs] = graph
                     self.output_buffers[bs] = output_buffers

+                # Save gemlite cache after each capture
+                save_gemlite_cache()
+
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
...
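And a hedged sketch of what the cuda_graph_runner.py change achieves: the GemLite cache is persisted after every per-batch-size capture, so configs tuned so far survive even if a later capture is interrupted. capture_all and its arguments below are hypothetical stand-ins for CudaGraphRunner.capture, not the real method.

# Simplified stand-in for CudaGraphRunner.capture; only save_gemlite_cache()
# is real sglang code, the rest is an assumption for illustration.
from typing import Callable, Dict, Iterable, Tuple

import torch

from sglang.srt.layers.torchao_utils import save_gemlite_cache


def capture_all(
    batch_sizes: Iterable[int],
    capture_one: Callable[[int], Tuple[torch.cuda.CUDAGraph, dict]],
):
    graphs: Dict[int, torch.cuda.CUDAGraph] = {}
    output_buffers: Dict[int, dict] = {}
    for bs in batch_sizes:
        graph, out = capture_one(bs)
        graphs[bs] = graph
        output_buffers[bs] = out

        # Save gemlite cache after each capture; save_gemlite_cache() returns
        # False instead of raising if GemLite is not installed.
        save_gemlite_cache()
    return graphs, output_buffers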