"src/vscode:/vscode.git/clone" did not exist on "bf5ca036fa7fbd6b46dc67df76d782eb90a860ca"
Unverified commit 76619261, authored by Yineng Zhang, committed via GitHub

feat: update torch 2.5.1 (#2069)

parent 2a3992b6
.PHONY: check-deps install-deps format

check-deps:
	@command -v isort >/dev/null 2>&1 || (echo "Installing isort..." && pip install isort)
	@command -v black >/dev/null 2>&1 || (echo "Installing black..." && pip install black)

install-deps:
	pip install isort black

format: check-deps
	@echo "Formatting modified Python files..."
	git diff --name-only --diff-filter=M | grep '\.py$$' | xargs -I {} sh -c 'isort {} && black {}'
@@ -20,7 +20,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu
     "orjson", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart",
     "torchao", "uvicorn", "uvloop", "pyzmq>=25.1.2",
     "outlines>=0.0.44,<0.1.0", "modelscope"]
-srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]
+srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1"]

 # HIP (Heterogeneous-computing Interface for Portability) for AMD
 # => base docker rocm/vllm-dev:20241022, not from public vllm whl
......
@@ -32,12 +32,14 @@ from vllm.distributed import (
 )
 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs

 logger = logging.getLogger(__name__)


+@register_custom_op("sglang_silu_and_mul")
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -51,6 +53,7 @@ class SiluAndMul(CustomOp):
         return out


+@register_custom_op("sglang_gelu_and_mul")
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
......
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from vllm.model_executor.custom_op import CustomOp


def register_custom_op(op_name):
    def decorator(cls):
        if hasattr(CustomOp, "register"):
            return CustomOp.register(op_name)(cls)
        else:
            return cls

    return decorator
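For orientation, here is a minimal usage sketch of this helper; the op name and class below are illustrative and not part of the commit. On vLLM builds that expose CustomOp.register the class is registered under the given name, while on older builds the decorator simply returns the class unchanged.

from vllm.model_executor.custom_op import CustomOp

from sglang.srt.layers.custom_op_util import register_custom_op


@register_custom_op("sglang_demo_op")  # illustrative op name
class DemoOp(CustomOp):
    def forward_native(self, x):
        # Pure-PyTorch fallback path; CustomOp dispatches to the device-specific
        # forward_* implementations when they are defined.
        return x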
@@ -33,9 +33,12 @@ if is_flashinfer_available():

 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.layers.custom_op_util import register_custom_op
+
 logger = logging.getLogger(__name__)


+@register_custom_op("sglang_rmsnorm")
 class RMSNorm(CustomOp):
     def __init__(
         self,
@@ -78,6 +81,7 @@ class RMSNorm(CustomOp):
         return x, residual


+@register_custom_op("sglang_gemma_rmsnorm")
 class GemmaRMSNorm(CustomOp):
     def __init__(
         self,
......
@@ -90,6 +90,8 @@ def set_torch_compile_config():
     # FIXME: tmp workaround
     torch._dynamo.config.accumulated_cache_size_limit = 1024
+    if hasattr(torch._dynamo.config, "cache_size_limit"):
+        torch._dynamo.config.cache_size_limit = 1024


 @maybe_torch_compile(dynamic=True)
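For reference, the same guard can be written over both dynamo knobs at once. This is only a sketch, assuming these two attributes are the ones of interest; their availability varies across torch releases.

import torch

# Raise dynamo's recompile caches only when the installed torch exposes the knob.
for knob in ("accumulated_cache_size_limit", "cache_size_limit"):
    if hasattr(torch._dynamo.config, knob):
        setattr(torch._dynamo.config, knob, 1024)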
......
@@ -18,9 +18,9 @@ limitations under the License.
 import gc
 import importlib
 import importlib.resources
+import inspect
 import json
 import logging
-import os
 import pkgutil
 from functools import lru_cache
 from typing import Optional, Type
@@ -60,6 +60,7 @@ from sglang.srt.utils import (
     crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
+    monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
 )
@@ -226,6 +227,47 @@ class ModelRunner:

         return min_per_gpu_memory

+    def setup_model(self):
+        try:
+            from vllm.config import VllmConfig
+
+            vllm_config = VllmConfig()
+            vllm_config.model_config = self.vllm_model_config
+            vllm_config.load_config = self.load_config
+            vllm_config.device_config = DeviceConfig(self.device)
+            vllm_config.quant_config = VllmConfig._get_quantization_config(
+                vllm_config.model_config, vllm_config.load_config
+            )
+            return get_model(vllm_config=vllm_config)
+        except ImportError:
+            return get_model(
+                model_config=self.vllm_model_config,
+                load_config=self.load_config,
+                device_config=DeviceConfig(self.device),
+                parallel_config=None,
+                scheduler_config=None,
+                lora_config=None,
+                cache_config=None,
+            )
+
+    def get_model_config_params(self):
+        sig = inspect.signature(VllmModelConfig.__init__)
+        params = {
+            "model": self.server_args.model_path,
+            "quantization": self.server_args.quantization,
+            "tokenizer": None,
+            "tokenizer_mode": None,
+            "trust_remote_code": self.server_args.trust_remote_code,
+            "dtype": self.server_args.dtype,
+            "seed": self.server_args.random_seed,
+            "skip_tokenizer_init": True,
+        }
+        if "task" in sig.parameters:
+            params["task"] = ""
+        return params
+
     def load_model(self):
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -247,31 +289,15 @@ class ModelRunner:
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-        self.vllm_model_config = VllmModelConfig(
-            model=self.server_args.model_path,
-            quantization=self.server_args.quantization,
-            tokenizer=None,
-            tokenizer_mode=None,
-            trust_remote_code=self.server_args.trust_remote_code,
-            dtype=self.server_args.dtype,
-            seed=self.server_args.random_seed,
-            skip_tokenizer_init=True,
-        )
+        monkey_patch_vllm_model_config()
+        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )

-        # Load the model
-        self.model = get_model(
-            model_config=self.vllm_model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-            parallel_config=None,
-            scheduler_config=None,
-            lora_config=None,
-            cache_config=None,
-        )
+        self.model = self.setup_model()
+
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
@@ -303,17 +329,9 @@ class ModelRunner:

         target_device = torch.device(self.device)
         try:
-            # TODO: Use a better method to check this
-            vllm_model_config = VllmModelConfig(
-                model=model_path,
-                quantization=self.server_args.quantization,
-                tokenizer=None,
-                tokenizer_mode=None,
-                trust_remote_code=self.server_args.trust_remote_code,
-                dtype=self.server_args.dtype,
-                seed=self.server_args.random_seed,
-                skip_tokenizer_init=True,
-            )
+            model_config_params = self.get_model_config_params()
+            model_config_params["model"] = model_path
+            vllm_model_config = VllmModelConfig(**model_config_params)
         except Exception as e:
             message = f"Failed to load model config: {e}."
             return False, message
......
@@ -332,6 +332,7 @@ def suppress_other_loggers():
     )
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
     logging.getLogger("vllm.utils").setLevel(logging.ERROR)
+    logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(logging.ERROR)

     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -396,6 +397,27 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
         pass


+def monkey_patch_vllm_model_config():
+    from vllm.config import ModelConfig
+
+    if not hasattr(ModelConfig, "_resolve_task"):
+        return
+
+    def _resolve_task(
+        self,
+        task_option,
+        hf_config,
+    ):
+        supported_tasks = {
+            "generate": True,
+            "embedding": False,
+        }
+        selected_task = "generate"
+        return supported_tasks, selected_task
+
+    setattr(ModelConfig, "_resolve_task", _resolve_task)
+
+
 def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     """
     Monkey patch the slow p2p access check in vllm.
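The new patch above is only applied when the installed vLLM defines ModelConfig._resolve_task, so older versions are left untouched. A generic sketch of that guarded monkey-patch pattern follows; the names are illustrative and not sglang or vLLM APIs.

def patch_if_present(cls, attr_name, replacement):
    # Replace an attribute only if the installed library version defines it;
    # on older versions this degrades to a no-op.
    if not hasattr(cls, attr_name):
        return False
    setattr(cls, attr_name, replacement)
    return True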
......
@@ -2,6 +2,7 @@

 import argparse
 import asyncio
+import copy
 import os
 import random
 import subprocess
@@ -529,6 +530,7 @@ def run_bench_serving(
     random_input_len=4096,
     random_output_len=2048,
     disable_stream=False,
+    need_warmup=False,
 ):
     # Launch the server
     base_url = DEFAULT_URL_FOR_TEST
@@ -565,6 +567,10 @@
     )

     try:
+        if need_warmup:
+            warmup_args = copy.deepcopy(args)
+            warmup_args.num_prompts = 16
+            run_benchmark(warmup_args)
         res = run_benchmark(args)
     finally:
         kill_child_process(process.pid, include_self=True)
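The new need_warmup flag follows a common benchmarking practice: run a small, untimed pass first so one-time costs (compilation, CUDA graph capture, cache warm-up) do not distort the measured run, and deep-copy the parsed args so the warmup's reduced num_prompts does not leak into the real benchmark. A generic sketch of the pattern, with illustrative function names:

import copy


def run_with_optional_warmup(run_fn, args, need_warmup=False, warmup_prompts=16):
    # Optionally run a small untimed pass before the measured benchmark.
    if need_warmup:
        warmup_args = copy.deepcopy(args)
        warmup_args.num_prompts = warmup_prompts
        run_fn(warmup_args)
    return run_fn(args)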
......
@@ -32,6 +32,7 @@ class TestBenchServing(unittest.TestCase):
             random_input_len=None,
             random_output_len=None,
             disable_stream=True,
+            need_warmup=True,
         )

         if is_in_ci():
......