Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a879811c
Unverified
Commit
a879811c
authored
Apr 10, 2025
by
Richard Zou
Committed by
GitHub
Apr 10, 2025
Browse files
Fix torch.compile cacheing (#5259)
Co-authored-by:
zhyncs
<
me@zhyncs.com
>
parent
a222945d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
1 deletion
+17
-1
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+6
-1
python/sglang/srt/patch_torch.py
python/sglang/srt/patch_torch.py
+11
-0
No files found.
python/sglang/srt/model_executor/model_runner.py
View file @
a879811c
...
...
@@ -64,7 +64,10 @@ from sglang.srt.model_loader.loader import (
)
from
sglang.srt.model_loader.utils
import
set_default_torch_dtype
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.patch_torch
import
monkey_patch_torch_reductions
from
sglang.srt.patch_torch
import
(
monkey_patch_torch_compile
,
monkey_patch_torch_reductions
,
)
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
...
...
@@ -88,6 +91,8 @@ logger = logging.getLogger(__name__)
SGLANG_CI_SMALL_KV_SIZE
=
os
.
getenv
(
"SGLANG_CI_SMALL_KV_SIZE"
,
None
)
UNBALANCED_MODEL_LOADING_TIMEOUT_S
=
300
monkey_patch_torch_compile
()
class
ModelRunner
:
"""ModelRunner runs the forward passes of the models."""
...
...
python/sglang/srt/patch_torch.py
View file @
a879811c
...
...
@@ -14,6 +14,7 @@
from
typing
import
Callable
,
Union
import
torch
from
packaging
import
version
from
torch.multiprocessing
import
reductions
...
...
@@ -69,3 +70,13 @@ def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int:
def
_modify_tuple
(
t
,
index
:
int
,
modifier
:
Callable
):
return
*
t
[:
index
],
modifier
(
t
[
index
]),
*
t
[
index
+
1
:]
def
monkey_patch_torch_compile
():
if
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
"2.8.0"
):
# These things are cacheable by torch.compile. torch.compile just doesn't know it.
# This was fixed in PyTorch 2.8, but until then, we monkey patch.
import
torch._higher_order_ops.auto_functionalize
as
af
af
.
auto_functionalized_v2
.
_cacheable
=
True
af
.
auto_functionalized
.
_cacheable
=
True
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment