Unverified commit 75235419, authored by Liangsheng Yin, committed by GitHub

`model_rpc` style improvement (#293)

parent 64ee9c03
@@ -9,8 +9,9 @@ from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 
 try:
-    import openai
     import tiktoken
+
+    import openai
 except ImportError as e:
     openai = tiktoken = e
 
@@ -6,7 +6,6 @@ import warnings
 from concurrent.futures import ThreadPoolExecutor
 from typing import List
 
-import numpy as np
 import rpyc
 import torch
 from rpyc.utils.classic import obtain
@@ -36,8 +35,8 @@ from vllm.logger import _default_handler as vllm_default_handler
 logger = logging.getLogger("model_rpc")
 
 
-class ModelRpcServer(rpyc.Service):
-    def exposed_init_model(
+class ModelRpcServer:
+    def __init__(
         self,
         tp_rank: int,
         server_args: ServerArgs,
@@ -608,14 +607,19 @@ class ModelRpcServer(rpyc.Service):
         batch.reqs = []
 
 
+class ModelRpcService(rpyc.Service):
+    exposed_ModelRpcServer = ModelRpcServer
+
+
 class ModelRpcClient:
     def __init__(self, server_args: ServerArgs, port_args: PortArgs):
        tp_size = server_args.tp_size
 
         if tp_size == 1:
             # Init model
-            self.model_server = ModelRpcServer()
-            self.model_server.exposed_init_model(0, server_args, port_args)
+            self.model_server = ModelRpcService().exposed_ModelRpcServer(
+                0, server_args, port_args
+            )
 
             # Wrap functions
             def async_wrap(f):
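The hunk above is the heart of the commit: ModelRpcServer is now a plain class, and a thin ModelRpcService subclass of rpyc.Service publishes it through rpyc's `exposed_` naming convention, so the same class can be constructed in-process (the tp_size == 1 branch) or on a remote worker. A minimal sketch of that pattern, using toy EchoWorker/EchoService names that are not part of the commit:

import rpyc

class EchoWorker:
    """Toy stand-in for ModelRpcServer: a plain class with no rpyc base."""

    def __init__(self, rank: int):
        self.rank = rank

    def exposed_ping(self) -> str:
        # The exposed_ prefix also makes this method reachable through rpyc.
        return f"pong from rank {self.rank}"

class EchoService(rpyc.Service):
    # Attributes named exposed_* are visible to rpyc clients; publishing the
    # class lets a client call it and construct instances on the server.
    exposed_EchoWorker = EchoWorker

# In-process use (the tp_size == 1 branch): no sockets involved, the
# attribute lookup simply returns the plain class.
worker = EchoService().exposed_EchoWorker(0)
print(worker.exposed_ping())  # pong from rank 0

Keeping the worker class free of any rpyc base lets the single-GPU path skip the RPC layer entirely while the multi-GPU path reuses the identical class remotely.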
@@ -629,14 +633,16 @@ class ModelRpcClient:
         with ThreadPoolExecutor(tp_size) as executor:
             # Launch model processes
             rets = executor.map(start_model_process, port_args.model_rpc_ports)
-            self.model_servers = [x[0] for x in rets]
+            self.remote_services = [x[0] for x in rets]
             self.procs = [x[1] for x in rets]
 
             # Init model
             def init_model(i):
-                return self.model_servers[i].init_model(i, server_args, port_args)
+                return self.remote_services[i].ModelRpcServer(
+                    i, server_args, port_args
+                )
 
-            rets = [obtain(x) for x in executor.map(init_model, range(tp_size))]
+            self.model_servers = executor.map(init_model, range(tp_size))
 
             # Wrap functions
             def async_wrap(func_name):
@@ -654,7 +660,7 @@ class ModelRpcClient:
 
 def _init_service(port):
     t = ThreadedServer(
-        ModelRpcServer(),
+        ModelRpcService(),
         port=port,
         protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
     )
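In the multi-process path, _init_service now serves ModelRpcService, and each client reaches the worker class through conn.root, where rpyc strips the `exposed_` prefix. A self-contained round trip under the same protocol_config; this is a sketch with hypothetical Worker/WorkerService names and port 18861, not the commit's code:

import threading
import time

import rpyc
from rpyc.utils.server import ThreadedServer

class Worker:
    def __init__(self, rank):
        self.rank = rank

    def exposed_ping(self):
        return f"pong from rank {self.rank}"

class WorkerService(rpyc.Service):
    exposed_Worker = Worker

# Serve the Service subclass, as _init_service now does.
server = ThreadedServer(
    WorkerService(),
    port=18861,
    protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
)
threading.Thread(target=server.start, daemon=True).start()
time.sleep(0.5)  # crude wait for the listening socket to come up

# Client side, mirroring `self.remote_services[i].ModelRpcServer(...)`:
# conn.root strips the exposed_ prefix, and calling the class netref
# constructs the instance inside the server process.
conn = rpyc.connect("localhost", 18861, config={"allow_pickle": True})
worker = conn.root.Worker(0)
print(worker.ping())  # pong from rank 0

Each method call on the returned netref is a network round trip, which is why each tensor-parallel rank owns its own remotely constructed ModelRpcServer.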
@@ -1,11 +1,11 @@
 import importlib
-import logging
+import importlib.resources
 import inspect
+import logging
 import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
 from pathlib import Path
-import importlib.resources
 
 import numpy as np
 import torch
@@ -18,11 +18,6 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
-import importlib
-import pkgutil
-
 import sglang
 
 QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
 
 logger = logging.getLogger("model_runner")
@@ -37,7 +32,7 @@ def import_model_classes():
     model_arch_name_to_cls = {}
     package_name = "sglang.srt.models"
     package = importlib.import_module(package_name)
-    for finder, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + '.'):
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
             module = importlib.import_module(name)
             if hasattr(module, "EntryClass"):
@@ -144,9 +139,12 @@ class InputMetadata:
 
             # flashinfer >= 0.0.3
             # FIXME: Drop this when flashinfer updates to 0.0.4
-            if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7:
+            if (
+                len(inspect.signature(self.prefill_wrapper.begin_forward).parameters)
+                == 7
+            ):
                 args.append(self.model_runner.model_config.head_dim)
 
             self.prefill_wrapper.begin_forward(*args)
         else:
             self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
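The reformatted condition is a version shim: it counts how many parameters the installed flashinfer's begin_forward accepts and appends head_dim only for the 7-parameter variant. The same arity probe in isolation, with toy stand-in functions (begin_forward_v3/v4 are illustrative, not the flashinfer API):

import inspect

def begin_forward_v3(a, b, c, d, e, f, head_dim):  # 7 parameters
    return ("v3", head_dim)

def begin_forward_v4(a, b, c, d, e, f):  # head_dim dropped in the newer API
    return ("v4", None)

def call_begin_forward(fn, args, head_dim):
    # Append the extra argument only when the installed signature still
    # wants it, so one code path covers both library versions.
    if len(inspect.signature(fn).parameters) == 7:
        args = args + [head_dim]
    return fn(*args)

print(call_begin_forward(begin_forward_v3, [1, 2, 3, 4, 5, 6], 128))  # ('v3', 128)
print(call_begin_forward(begin_forward_v4, [1, 2, 3, 4, 5, 6], 128))  # ('v4', None)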
@@ -307,9 +305,11 @@ class ModelRunner:
 
             hf_quant_method = hf_quant_config["quant_method"]
 
             # compat: autogptq uses is_marlin_format within quant config
-            if (hf_quant_method == "gptq"
-                    and "is_marlin_format" in hf_quant_config
-                    and hf_quant_config["is_marlin_format"]):
+            if (
+                hf_quant_method == "gptq"
+                and "is_marlin_format" in hf_quant_config
+                and hf_quant_config["is_marlin_format"]
+            ):
                 hf_quant_method = "marlin"
             quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
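This last hunk is pure black-style reformatting of the compat check that reroutes autogptq's marlin-format checkpoints from the gptq config class to the marlin one. The dict-dispatch-plus-override pattern in isolation, with dummy config classes and an illustrative resolve_quant_config helper (the real mapping lives in QUANTIONCONFIG_MAPPING above):

class AWQConfig: ...
class GPTQConfig: ...
class MarlinConfig: ...

QUANT_CONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}

def resolve_quant_config(hf_quant_config: dict):
    method = hf_quant_config["quant_method"]
    # compat: autogptq marks marlin-format checkpoints inside a gptq config;
    # .get() covers both the missing-key and falsy cases in one step.
    if method == "gptq" and hf_quant_config.get("is_marlin_format"):
        method = "marlin"
    cls = QUANT_CONFIG_MAPPING.get(method)
    if cls is None:
        raise ValueError(f"Unsupported quantization method: {method}")
    return cls

print(resolve_quant_config({"quant_method": "gptq", "is_marlin_format": True}))
# <class 'MarlinConfig'>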