Unverified commit 76619261 authored by Yineng Zhang, committed by GitHub

feat: update torch 2.5.1 (#2069)
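Summary of the change, as visible in the diff below: the vllm pin is relaxed from ==0.6.3.post1 to >=0.6.3.post1 so wheels built against torch 2.5.1 can be installed, with compatibility shims for the vllm API changes across that range: a register_custom_op helper that no-ops when CustomOp.register is absent, a VllmConfig-based load path in ModelRunner with a fallback to the older get_model signature, signature probing for the new "task" parameter of ModelConfig, and a _resolve_task monkey patch that pins the task to "generate". A repository Makefile with formatting targets and a benchmark warmup option ride along.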

parent 2a3992b6
.PHONY: check-deps install-deps format

# Install the formatters only if they are missing from PATH.
check-deps:
	@command -v isort >/dev/null 2>&1 || (echo "Installing isort..." && pip install isort)
	@command -v black >/dev/null 2>&1 || (echo "Installing black..." && pip install black)

install-deps:
	pip install isort black

# Re-format only the Python files modified in the working tree.
format: check-deps
	@echo "Formatting modified Python files..."
	git diff --name-only --diff-filter=M | grep '\.py$$' | xargs -I {} sh -c 'isort {} && black {}'
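With these targets, make format installs isort and black on demand and re-formats only the Python files modified in the working tree, while make install-deps installs both formatters unconditionally.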
......@@ -20,7 +20,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu
"orjson", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart",
"torchao", "uvicorn", "uvloop", "pyzmq>=25.1.2",
"outlines>=0.0.44,<0.1.0", "modelscope"]
-srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]
+srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1"]
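Relaxing the pin from == to >= lets the resolver pick newer vllm releases, which is what unblocks the move to torch 2.5.1; the shims below absorb the API differences across that version range.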
# HIP (Heterogeneous-computing Interface for Portability) for AMD
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
......
......@@ -32,12 +32,14 @@ from vllm.distributed import (
)
from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.utils import set_weight_attrs
logger = logging.getLogger(__name__)
@register_custom_op("sglang_silu_and_mul")
class SiluAndMul(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
......@@ -51,6 +53,7 @@ class SiluAndMul(CustomOp):
return out
@register_custom_op("sglang_gelu_and_mul")
class GeluAndMul(CustomOp):
def __init__(self, approximate="tanh"):
super().__init__()
......
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from vllm.model_executor.custom_op import CustomOp
def register_custom_op(op_name):
    """Register the decorated class under `op_name` when vllm supports it.

    Newer vllm exposes a `CustomOp.register` decorator; on older versions
    registration is skipped and the class is returned unchanged.
    """

    def decorator(cls):
        if hasattr(CustomOp, "register"):
            return CustomOp.register(op_name)(cls)
        else:
            return cls

    return decorator
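The shim can be exercised without any particular vllm version installed. A minimal standalone sketch, where the Modern/Legacy classes are illustrative stand-ins for newer and older CustomOp and only the hasattr dispatch mirrors the helper above:

class LegacyCustomOp:  # stand-in for older vllm: no register() classmethod
    pass


class ModernCustomOp:  # stand-in for newer vllm: has a registration decorator
    _registry = {}

    @classmethod
    def register(cls, name):
        def decorator(subcls):
            cls._registry[name] = subcls
            return subcls

        return decorator


def make_register(custom_op_base):
    # Same dispatch as register_custom_op, parameterized for the demo.
    def register_custom_op(op_name):
        def decorator(cls):
            if hasattr(custom_op_base, "register"):
                return custom_op_base.register(op_name)(cls)
            return cls

        return decorator

    return register_custom_op


register_modern = make_register(ModernCustomOp)
register_legacy = make_register(LegacyCustomOp)


@register_modern("demo_op")
class DemoOp:
    pass


@register_legacy("demo_op")
class DemoOpLegacy:
    pass  # returned unchanged: registration is silently skipped


assert ModernCustomOp._registry["demo_op"] is DemoOp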
......@@ -33,9 +33,12 @@ if is_flashinfer_available():
from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.custom_op_util import register_custom_op
logger = logging.getLogger(__name__)
@register_custom_op("sglang_rmsnorm")
class RMSNorm(CustomOp):
def __init__(
self,
......@@ -78,6 +81,7 @@ class RMSNorm(CustomOp):
return x, residual
@register_custom_op("sglang_gemma_rmsnorm")
class GemmaRMSNorm(CustomOp):
def __init__(
self,
......
......@@ -90,6 +90,8 @@ def set_torch_compile_config():
# FIXME: tmp workaround
torch._dynamo.config.accumulated_cache_size_limit = 1024
    # Dynamo config fields vary across torch versions; set the knob only
    # when it exists.
    if hasattr(torch._dynamo.config, "cache_size_limit"):
        torch._dynamo.config.cache_size_limit = 1024
@maybe_torch_compile(dynamic=True)
......
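For context, maybe_torch_compile is sglang's conditional-compile decorator. A plausible minimal shape, assuming it simply forwards its arguments to torch.compile when that API exists (a sketch under that assumption, not the actual sglang implementation):

import torch


def maybe_torch_compile_sketch(*compile_args, **compile_kwargs):
    def decorator(func):
        if hasattr(torch, "compile"):
            # torch >= 2.0: compile with the requested options.
            return torch.compile(func, *compile_args, **compile_kwargs)
        return func  # older torch: fall back to eager execution

    return decorator


@maybe_torch_compile_sketch(dynamic=True)
def fused_scale(x: torch.Tensor) -> torch.Tensor:
    return x * 0.5 + 1.0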
......@@ -18,9 +18,9 @@ limitations under the License.
import gc
import importlib
import importlib.resources
import inspect
import json
import logging
import os
import pkgutil
from functools import lru_cache
from typing import Optional, Type
......@@ -60,6 +60,7 @@ from sglang.srt.utils import (
crash_on_warnings,
enable_show_time_cost,
get_available_gpu_memory,
monkey_patch_vllm_model_config,
monkey_patch_vllm_p2p_access_check,
)
......@@ -226,6 +227,47 @@ class ModelRunner:
return min_per_gpu_memory
    def setup_model(self):
        try:
            # Newer vllm bundles every sub-config into a single VllmConfig
            # object that get_model consumes directly.
            from vllm.config import VllmConfig

            vllm_config = VllmConfig()
            vllm_config.model_config = self.vllm_model_config
            vllm_config.load_config = self.load_config
            vllm_config.device_config = DeviceConfig(self.device)
            vllm_config.quant_config = VllmConfig._get_quantization_config(
                vllm_config.model_config, vllm_config.load_config
            )
            return get_model(vllm_config=vllm_config)
        except ImportError:
            # Older vllm: get_model still takes the individual config objects.
            return get_model(
                model_config=self.vllm_model_config,
                load_config=self.load_config,
                device_config=DeviceConfig(self.device),
                parallel_config=None,
                scheduler_config=None,
                lora_config=None,
                cache_config=None,
            )
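The try/except ImportError here is a general version-dispatch idiom: attempt the newer API surface first and fall back when the import fails. The same shape in a standalone sketch, using the optional orjson package from the dependency list above (illustrative, not part of this diff):

def resolve_json_loads():
    try:
        import orjson  # ImportError on installs without orjson
        return orjson.loads
    except ImportError:
        import json
        return json.loads


loads = resolve_json_loads()
assert loads('{"torch": "2.5.1"}') == {"torch": "2.5.1"}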
def get_model_config_params(self):
sig = inspect.signature(VllmModelConfig.__init__)
params = {
"model": self.server_args.model_path,
"quantization": self.server_args.quantization,
"tokenizer": None,
"tokenizer_mode": None,
"trust_remote_code": self.server_args.trust_remote_code,
"dtype": self.server_args.dtype,
"seed": self.server_args.random_seed,
"skip_tokenizer_init": True,
}
        # Newer vllm adds a "task" parameter to ModelConfig; include it only
        # when the installed version's signature accepts it.
        if "task" in sig.parameters:
            params["task"] = ""
        return params
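Probing the constructor with inspect.signature keeps one call site compatible with vllm versions from before and after the task parameter was added. The idiom in isolation, with toy config classes standing in for the two vllm generations:

import inspect


class ConfigV1:  # older signature: no "task"
    def __init__(self, model, seed=0):
        self.model, self.seed = model, seed


class ConfigV2:  # newer signature: adds "task"
    def __init__(self, model, seed=0, task="generate"):
        self.model, self.seed, self.task = model, seed, task


def build_params(config_cls, model):
    params = {"model": model, "seed": 42}
    if "task" in inspect.signature(config_cls.__init__).parameters:
        params["task"] = ""
    return params


assert "task" not in build_params(ConfigV1, "llama")
assert "task" in build_params(ConfigV2, "llama")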
def load_model(self):
logger.info(
f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
......@@ -247,31 +289,15 @@ class ModelRunner:
load_format=self.server_args.load_format,
download_dir=self.server_args.download_dir,
)
-        self.vllm_model_config = VllmModelConfig(
-            model=self.server_args.model_path,
-            quantization=self.server_args.quantization,
-            tokenizer=None,
-            tokenizer_mode=None,
-            trust_remote_code=self.server_args.trust_remote_code,
-            dtype=self.server_args.dtype,
-            seed=self.server_args.random_seed,
-            skip_tokenizer_init=True,
-        )
+        monkey_patch_vllm_model_config()
+        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
if self.model_config.model_override_args is not None:
self.vllm_model_config.hf_config.update(
self.model_config.model_override_args
)
# Load the model
-        self.model = get_model(
-            model_config=self.vllm_model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-            parallel_config=None,
-            scheduler_config=None,
-            lora_config=None,
-            cache_config=None,
-        )
+        self.model = self.setup_model()
self.sliding_window_size = (
self.model.get_attention_sliding_window_size()
if hasattr(self.model, "get_attention_sliding_window_size")
......@@ -303,17 +329,9 @@ class ModelRunner:
target_device = torch.device(self.device)
try:
# TODO: Use a better method to check this
-            vllm_model_config = VllmModelConfig(
-                model=model_path,
-                quantization=self.server_args.quantization,
-                tokenizer=None,
-                tokenizer_mode=None,
-                trust_remote_code=self.server_args.trust_remote_code,
-                dtype=self.server_args.dtype,
-                seed=self.server_args.random_seed,
-                skip_tokenizer_init=True,
-            )
+            model_config_params = self.get_model_config_params()
+            model_config_params["model"] = model_path
+            vllm_model_config = VllmModelConfig(**model_config_params)
except Exception as e:
message = f"Failed to load model config: {e}."
return False, message
......
......@@ -332,6 +332,7 @@ def suppress_other_loggers():
)
logging.getLogger("vllm.selector").setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.ERROR)
logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(logging.ERROR)
warnings.filterwarnings(
"ignore", category=UserWarning, message="The given NumPy array is not writable"
......@@ -396,6 +397,27 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
pass
def monkey_patch_vllm_model_config():
    from vllm.config import ModelConfig

    # Older vllm has no task resolution; nothing to patch.
    if not hasattr(ModelConfig, "_resolve_task"):
        return

    def _resolve_task(
        self,
        task_option,
        hf_config,
    ):
        # sglang only serves generation, so bypass vllm's task inference
        # and always report/select "generate".
        supported_tasks = {
            "generate": True,
            "embedding": False,
        }
        selected_task = "generate"
        return supported_tasks, selected_task

    setattr(ModelConfig, "_resolve_task", _resolve_task)
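A standalone sketch of what the patch accomplishes, with FakeModelConfig as a stand-in for vllm.config.ModelConfig (the real patch must run before the config object is constructed, which load_model above ensures by calling monkey_patch_vllm_model_config first):

class FakeModelConfig:
    def _resolve_task(self, task_option, hf_config):
        raise RuntimeError("would inspect hf_config and could reject the model")


def apply_patch(cls):
    if not hasattr(cls, "_resolve_task"):
        return  # nothing to patch on this version

    def _resolve_task(self, task_option, hf_config):
        return {"generate": True, "embedding": False}, "generate"

    setattr(cls, "_resolve_task", _resolve_task)


apply_patch(FakeModelConfig)
supported, selected = FakeModelConfig()._resolve_task(None, None)
assert selected == "generate"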
def monkey_patch_vllm_p2p_access_check(gpu_id: int):
"""
Monkey patch the slow p2p access check in vllm.
......
......@@ -2,6 +2,7 @@
import argparse
import asyncio
import copy
import os
import random
import subprocess
......@@ -529,6 +530,7 @@ def run_bench_serving(
random_input_len=4096,
random_output_len=2048,
disable_stream=False,
need_warmup=False,
):
# Launch the server
base_url = DEFAULT_URL_FOR_TEST
......@@ -565,6 +567,10 @@ def run_bench_serving(
)
try:
        if need_warmup:
            # Run a short warmup pass first so one-time startup costs do not
            # skew the measured benchmark run.
            warmup_args = copy.deepcopy(args)
            warmup_args.num_prompts = 16
            run_benchmark(warmup_args)
res = run_benchmark(args)
finally:
kill_child_process(process.pid, include_self=True)
......
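The deep copy keeps the warmup override local: argparse-style namespaces are shared by reference, so mutating args directly would also shrink the measured run. The pattern in isolation:

import copy
from argparse import Namespace

args = Namespace(num_prompts=512, random_input_len=4096)
warmup_args = copy.deepcopy(args)  # clone before overriding
warmup_args.num_prompts = 16
assert args.num_prompts == 512 and warmup_args.num_prompts == 16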
......@@ -32,6 +32,7 @@ class TestBenchServing(unittest.TestCase):
random_input_len=None,
random_output_len=None,
disable_stream=True,
need_warmup=True,
)
if is_in_ci():
......