sglang / Commits / 0212d2e2

[Fix] use `torch.inference_mode()` instead of `torch.no_grad()` (#4372)

Unverified commit 0212d2e2, authored Mar 17, 2025 by JieXin Liang, committed by GitHub on Mar 16, 2025.
Parent: 8cc300f5

Showing 4 changed files with 120 additions and 4 deletions (+120, -4):
python/sglang/srt/managers/scheduler.py                  +3  -2
python/sglang/srt/managers/tp_worker_overlap_thread.py   +2  -2
python/sglang/srt/utils.py                               +58 -0
python/sglang/test/test_dynamic_grad_mode.py             +57 -0
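For context, the two decorators this commit switches between differ in what autograd records: `torch.no_grad()` only disables gradient tracking, while `torch.inference_mode()` additionally skips view tracking and version-counter bumps, producing faster "inference tensors" that cannot later participate in autograd. A minimal sketch of the observable difference (plain PyTorch behavior, independent of this commit):

```python
import torch

# no_grad: gradients are off, but the result is an ordinary tensor.
with torch.no_grad():
    a = torch.ones(2)
assert not a.requires_grad and not a.is_inference()

# inference_mode: gradients are off AND the result is an inference
# tensor, which autograd refuses to record in later operations.
with torch.inference_mode():
    b = torch.ones(2)
assert not b.requires_grad and b.is_inference()
```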
python/sglang/srt/managers/scheduler.py

```diff
@@ -101,6 +101,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
+    DynamicGradMode,
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
@@ -487,7 +488,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
         },
     )

-    @torch.no_grad()
+    @DynamicGradMode()
     def event_loop_normal(self):
         """A normal scheduler loop."""
         while True:
@@ -507,7 +508,7 @@ class Scheduler(SchedulerOutputProcessorMixin):

         self.last_batch = batch

-    @torch.no_grad()
+    @DynamicGradMode()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         self.result_queue = deque()
```
python/sglang/srt/managers/tp_worker_overlap_thread.py

```diff
@@ -33,7 +33,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import DynamicGradMode, get_compiler_backend
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -115,7 +115,7 @@ class TpModelWorkerClient:
             logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
             self.parent_process.send_signal(signal.SIGQUIT)

-    @torch.no_grad()
+    @DynamicGradMode()
     def forward_thread_func_(self):
         batch_pt = 0
         batch_lists = [None] * 2
```
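The decorator swaps above (the two scheduler loops and the worker forward thread) all replace `@torch.no_grad()` with `@DynamicGradMode()`. As the utils.py hunk below shows, `_ENABLE_TORCH_INFERENCE_MODE` defaults to false, so the swap preserves no_grad semantics unless the operator opts in. A minimal opt-in sketch (assuming, per the diff below, that the flag is read once when `sglang.srt.utils` is imported):

```python
import os

# Set before sglang.srt.utils is imported; the module reads
# SGLANG_ENABLE_TORCH_INFERENCE_MODE once at import time.
os.environ["SGLANG_ENABLE_TORCH_INFERENCE_MODE"] = "1"

from sglang.srt.utils import DynamicGradMode

# Alternatively, flip the mode at runtime:
DynamicGradMode.set_inference_mode(True)
```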
python/sglang/srt/utils.py

```diff
@@ -61,6 +61,7 @@ from torch import nn
 from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
+from torch.utils._contextlib import _DecoratorContextManager
 from torch.utils.cpp_extension import CUDA_HOME
 from triton.runtime.cache import (
     FileCacheManager,
@@ -127,6 +128,63 @@ def is_cuda_available():
     return is_cuda()


+_ENABLE_TORCH_INFERENCE_MODE = os.getenv(
+    "SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
+).lower() in ("true", "1")
+
+
+class DynamicGradMode(_DecoratorContextManager):
+    """
+    A combination of torch.no_grad and torch.inference_mode, with the active
+    behavior selected by _ENABLE_TORCH_INFERENCE_MODE. Refer to those two
+    PyTorch APIs for the underlying semantics.
+    """
+
+    @staticmethod
+    def set_inference_mode(mode: bool):
+        if isinstance(mode, bool):
+            global _ENABLE_TORCH_INFERENCE_MODE
+            _ENABLE_TORCH_INFERENCE_MODE = mode
+        else:
+            logger.warning("mode is not a boolean object")
+
+    def __init__(self, mode=True):
+        if not torch._jit_internal.is_scripting():
+            super().__init__()
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            self.mode = mode
+        else:
+            self.prev = False
+
+    def __new__(cls, mode_or_orig_func=True if _ENABLE_TORCH_INFERENCE_MODE else None):
+        if mode_or_orig_func is None or isinstance(mode_or_orig_func, bool):
+            return super().__new__(cls)
+        # Called as a bare decorator (@DynamicGradMode): wrap the function.
+        return cls()(mode_or_orig_func)
+
+    def __enter__(self) -> None:
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            self._inference_mode_context = torch._C._InferenceMode(self.mode)
+            self._inference_mode_context.__enter__()
+        else:
+            self.prev = torch.is_grad_enabled()
+            torch.set_grad_enabled(False)
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            self._inference_mode_context.__exit__(exc_type, exc_value, traceback)
+        else:
+            torch.set_grad_enabled(self.prev)
+
+    def clone(self) -> "DynamicGradMode":
+        r"""
+        Create a copy of this class
+        """
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            return self.__class__(self.mode)
+        else:
+            return self.__class__()
+
+
 def enable_show_time_cost():
     global show_time_cost
     show_time_cost = True
```
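A usage sketch mirroring the test file below: the same `DynamicGradMode()` decorator behaves like `torch.inference_mode()` or `torch.no_grad()` depending on the flag, and thanks to the `__new__` overload the bare-decorator form `@DynamicGradMode` (no parentheses) also works:

```python
import torch
from sglang.srt.utils import DynamicGradMode

DynamicGradMode.set_inference_mode(True)  # behave like torch.inference_mode()

@DynamicGradMode  # __new__ receives the function and wraps it via cls()(func)
def forward():
    return torch.empty(0)

x = forward()
assert not x.requires_grad and x.is_inference()

# It also works as a plain context manager:
with DynamicGradMode():
    y = torch.empty(0)
assert y.is_inference()
```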
python/sglang/test/test_dynamic_grad_mode.py (new file, mode 100644)

```python
import unittest

import torch

from sglang.srt.utils import DynamicGradMode


class TestDynamicGradMode(unittest.TestCase):
    def test_inference(self):
        # Test inference_mode
        DynamicGradMode.set_inference_mode(True)

        @DynamicGradMode()
        def create_tensor_x():
            return torch.empty(0)

        X = create_tensor_x()
        self.assertTrue(not X.requires_grad and X.is_inference())

    def test_no_grad(self):
        # Test no_grad
        DynamicGradMode.set_inference_mode(False)

        @DynamicGradMode()
        def create_tensor_y():
            return torch.empty(0)

        Y = create_tensor_y()
        self.assertTrue(not Y.requires_grad and not Y.is_inference())

    def test_nested_inference(self):
        # Test inference_mode nested inside no_grad; inference_mode should take priority
        DynamicGradMode.set_inference_mode(False)

        @DynamicGradMode()
        def create_tensor_z():
            with torch.inference_mode():
                return torch.empty(0)

        Z = create_tensor_z()
        self.assertTrue(not Z.requires_grad and Z.is_inference())

    def test_nested_no_grad(self):
        # Test no_grad nested inside inference_mode; inference_mode should take priority
        DynamicGradMode.set_inference_mode(True)

        @DynamicGradMode()
        def create_tensor_w():
            with torch.no_grad():
                return torch.empty(0)

        W = create_tensor_w()
        self.assertTrue(not W.requires_grad and W.is_inference())


if __name__ == "__main__":
    unittest.main(verbosity=2)
```
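Since the file ends with a `unittest.main` entry point, it can be run directly (`python python/sglang/test/test_dynamic_grad_mode.py`) or through `python -m unittest`. Note that the tests flip the shared module-level flag with `set_inference_mode`, so both branches of `DynamicGradMode` are exercised within a single process.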