Unverified Commit 0212d2e2 authored by JieXin Liang, committed by GitHub

[Fix] use `torch.inference_mode()` instead of `torch.no_grad()` (#4372)

parent 8cc300f5
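
For context, `torch.no_grad()` only disables gradient tracking, while `torch.inference_mode()` additionally skips autograd bookkeeping such as version counters and view tracking, so tensors created under it are inference tensors. A minimal sketch of the difference (illustration only, not part of this diff):

```python
import torch

with torch.no_grad():
    a = torch.ones(2)   # ordinary tensor; gradient recording is merely disabled here

with torch.inference_mode():
    b = torch.ones(2)   # inference tensor: no version counter or autograd metadata

print(a.is_inference(), b.is_inference())  # False True
```

Because inference tensors cannot later participate in autograd, the change below gates the switch behind the `SGLANG_ENABLE_TORCH_INFERENCE_MODE` environment variable (default `false`) instead of replacing `torch.no_grad()` unconditionally.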
@@ -101,6 +101,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
+    DynamicGradMode,
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
@@ -487,7 +488,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
            },
        )

-    @torch.no_grad()
+    @DynamicGradMode()
    def event_loop_normal(self):
        """A normal scheduler loop."""
        while True:
@@ -507,7 +508,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
        self.last_batch = batch

-    @torch.no_grad()
+    @DynamicGradMode()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation."""
        self.result_queue = deque()
...
@@ -33,7 +33,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import DynamicGradMode, get_compiler_backend
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -115,7 +115,7 @@ class TpModelWorkerClient:
            logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
            self.parent_process.send_signal(signal.SIGQUIT)

-    @torch.no_grad()
+    @DynamicGradMode()
    def forward_thread_func_(self):
        batch_pt = 0
        batch_lists = [None] * 2
...
@@ -61,6 +61,7 @@ from torch import nn
 from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
+from torch.utils._contextlib import _DecoratorContextManager
 from torch.utils.cpp_extension import CUDA_HOME
 from triton.runtime.cache import (
     FileCacheManager,
@@ -127,6 +128,63 @@ def is_cuda_available():
     return is_cuda()


+_ENABLE_TORCH_INFERENCE_MODE = os.getenv(
+    "SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
+).lower() in ("true", "1")
+
+
+class DynamicGradMode(_DecoratorContextManager):
+    """
+    A combination of torch.no_grad and torch.inference_mode, with the choice
+    between them controlled by an environment variable; refer to those two for details.
+    """
+
+    @staticmethod
+    def set_inference_mode(mode: bool):
+        if isinstance(mode, bool):
+            global _ENABLE_TORCH_INFERENCE_MODE
+            _ENABLE_TORCH_INFERENCE_MODE = mode
+        else:
+            logger.warning("mode is not a boolean object")
+
+    def __init__(self, mode=True):
+        if not torch._jit_internal.is_scripting():
+            super().__init__()
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            self.mode = mode
+        else:
+            self.prev = False
+
+    def __new__(cls, mode_or_orig_func=True if _ENABLE_TORCH_INFERENCE_MODE else None):
+        if mode_or_orig_func is None or isinstance(mode_or_orig_func, bool):
+            return super().__new__(cls)
+        return cls()(mode_or_orig_func)
+
+    def __enter__(self) -> None:
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            self._inference_mode_context = torch._C._InferenceMode(self.mode)
+            self._inference_mode_context.__enter__()
+        else:
+            self.prev = torch.is_grad_enabled()
+            torch.set_grad_enabled(False)
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            self._inference_mode_context.__exit__(exc_type, exc_value, traceback)
+        else:
+            torch.set_grad_enabled(self.prev)
+
+    def clone(self) -> "DynamicGradMode":
+        r"""
+        Create a copy of this class
+        """
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            return self.__class__(self.mode)
+        else:
+            return self.__class__()
+
+
 def enable_show_time_cost():
     global show_time_cost
     show_time_cost = True
...
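
A minimal usage sketch of the `DynamicGradMode` helper added above (the decorated function names and the runtime toggling are illustrative, not part of the diff): with `SGLANG_ENABLE_TORCH_INFERENCE_MODE` unset or `false` it behaves like `torch.no_grad()`, and when enabled it behaves like `torch.inference_mode()`:

```python
import torch

from sglang.srt.utils import DynamicGradMode

# Same effect as launching with SGLANG_ENABLE_TORCH_INFERENCE_MODE=true.
DynamicGradMode.set_inference_mode(True)

@DynamicGradMode()               # hypothetical function; sglang decorates its scheduler/worker loops
def decode_step():
    return torch.zeros(1)

x = decode_step()
print(x.requires_grad, x.is_inference())   # False True  (inference_mode path)

# Fall back to torch.no_grad semantics.
DynamicGradMode.set_inference_mode(False)

@DynamicGradMode()
def prefill_step():
    return torch.zeros(1)

y = prefill_step()
print(y.requires_grad, y.is_inference())   # False False (no_grad path)
```

Note that `set_inference_mode` is called before applying the decorator, mirroring the tests below; toggling the mode after a function has already been decorated is not exercised by this change.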
import unittest

import torch

from sglang.srt.utils import DynamicGradMode


class TestDynamicGradMode(unittest.TestCase):
    def test_inference(self):
        # Test inference_mode
        DynamicGradMode.set_inference_mode(True)

        @DynamicGradMode()
        def create_tensor_x():
            return torch.empty(0)

        X = create_tensor_x()
        self.assertTrue(not X.requires_grad and X.is_inference())

    def test_no_grad(self):
        # Test no_grad
        DynamicGradMode.set_inference_mode(False)

        @DynamicGradMode()
        def create_tensor_y():
            return torch.empty(0)

        Y = create_tensor_y()
        self.assertTrue(not Y.requires_grad and not Y.is_inference())

    def test_nested_inference(self):
        # Test inference_mode nested inside no_grad; inference_mode should have higher priority
        DynamicGradMode.set_inference_mode(False)

        @DynamicGradMode()
        def create_tensor_z():
            with torch.inference_mode():
                return torch.empty(0)

        Z = create_tensor_z()
        self.assertTrue(not Z.requires_grad and Z.is_inference())

    def test_nested_no_grad(self):
        # Test no_grad nested inside inference_mode; inference_mode should have higher priority
        DynamicGradMode.set_inference_mode(True)

        @DynamicGradMode()
        def create_tensor_w():
            with torch.no_grad():
                return torch.empty(0)

        W = create_tensor_w()
        self.assertTrue(not W.requires_grad and W.is_inference())


if __name__ == "__main__":
    unittest.main(verbosity=2)
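
One caveat worth noting: `_ENABLE_TORCH_INFERENCE_MODE` is read from the environment when `sglang.srt.utils` is imported, so the variable must be set before that import for the default to take effect; otherwise use `set_inference_mode` at runtime as above. A hedged sketch:

```python
import os

# Must be set before sglang.srt.utils is imported, since the flag is read at import time.
os.environ["SGLANG_ENABLE_TORCH_INFERENCE_MODE"] = "true"

from sglang.srt.utils import DynamicGradMode  # noqa: E402
```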