Unverified commit 0769b14b, authored by Lianmin Zheng and committed by GitHub
Browse files

[Minor] Move torch.compile patch to a better place (#5397)

parent b64b88e7
...@@ -34,6 +34,7 @@ from sglang.srt.model_executor.forward_batch_info import ( ...@@ -34,6 +34,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch, ForwardBatch,
ForwardMode, ForwardMode,
) )
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.utils import get_available_gpu_memory, is_hip from sglang.srt.utils import get_available_gpu_memory, is_hip
_is_hip = is_hip() _is_hip = is_hip()
...@@ -108,6 +109,8 @@ def set_torch_compile_config(): ...@@ -108,6 +109,8 @@ def set_torch_compile_config():
if hasattr(torch._dynamo.config, "cache_size_limit"): if hasattr(torch._dynamo.config, "cache_size_limit"):
torch._dynamo.config.cache_size_limit = 1024 torch._dynamo.config.cache_size_limit = 1024
monkey_patch_torch_compile()
def get_batch_sizes_to_capture(model_runner: ModelRunner): def get_batch_sizes_to_capture(model_runner: ModelRunner):
server_args = model_runner.server_args server_args = model_runner.server_args
......
...@@ -64,10 +64,7 @@ from sglang.srt.model_loader.loader import ( ...@@ -64,10 +64,7 @@ from sglang.srt.model_loader.loader import (
) )
from sglang.srt.model_loader.utils import set_default_torch_dtype from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.patch_torch import ( from sglang.srt.patch_torch import monkey_patch_torch_reductions
monkey_patch_torch_compile,
monkey_patch_torch_reductions,
)
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
...@@ -94,8 +91,6 @@ logger = logging.getLogger(__name__) ...@@ -94,8 +91,6 @@ logger = logging.getLogger(__name__)
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None) SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300 UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
monkey_patch_torch_compile()
class ModelRunner: class ModelRunner:
"""ModelRunner runs the forward passes of the models.""" """ModelRunner runs the forward passes of the models."""
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment