"vscode:/vscode.git/clone" did not exist on "e335f05fb15fec92b523f28fc4d9f019a35b7e75"
Unverified Commit 75235419 authored by Liangsheng Yin, committed by GitHub

`model_rpc` style improvement (#293)

parent 64ee9c03
@@ -9,8 +9,9 @@ from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 
 try:
+    import openai
     import tiktoken
-    import openai
 except ImportError as e:
     openai = tiktoken = e
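The reordered `try` block above relies on a small pattern worth calling out: on failure, the `ImportError` itself is bound to the module names, so importing sglang succeeds without the optional OpenAI dependencies and the error only surfaces when the backend is actually used. A minimal re-creation of the pattern with a hypothetical optional dependency:

```python
try:
    import some_optional_dep  # hypothetical optional dependency
except ImportError as e:
    # Bind the error to the name: importing this module never fails,
    # only using the optional backend does.
    some_optional_dep = e


def use_optional_backend():
    if isinstance(some_optional_dep, Exception):
        raise RuntimeError(
            "optional dependency is not installed"
        ) from some_optional_dep
    return some_optional_dep


print(isinstance(some_optional_dep, Exception))  # True when not installed
```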
...
@@ -6,7 +6,6 @@ import warnings
 from concurrent.futures import ThreadPoolExecutor
 from typing import List
 
-import numpy as np
 import rpyc
 import torch
 from rpyc.utils.classic import obtain
@@ -36,8 +35,8 @@ from vllm.logger import _default_handler as vllm_default_handler
 logger = logging.getLogger("model_rpc")
 
 
-class ModelRpcServer(rpyc.Service):
-    def exposed_init_model(
+class ModelRpcServer:
+    def __init__(
         self,
         tp_rank: int,
         server_args: ServerArgs,
@@ -608,14 +607,19 @@ class ModelRpcServer(rpyc.Service):
         batch.reqs = []
 
 
+class ModelRpcService(rpyc.Service):
+    exposed_ModelRpcServer = ModelRpcServer
+
+
 class ModelRpcClient:
     def __init__(self, server_args: ServerArgs, port_args: PortArgs):
         tp_size = server_args.tp_size
 
         if tp_size == 1:
             # Init model
-            self.model_server = ModelRpcServer()
-            self.model_server.exposed_init_model(0, server_args, port_args)
+            self.model_server = ModelRpcService().exposed_ModelRpcServer(
+                0, server_args, port_args
+            )
 
             # Wrap functions
             def async_wrap(f):
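The new `ModelRpcService` wrapper is the core of this refactor: rpyc exposes any `Service` attribute prefixed with `exposed_`, so binding the worker class to `exposed_ModelRpcServer` lets remote clients construct instances through `conn.root`, while the `tp_size == 1` branch constructs the same class directly with no RPC involved. A minimal sketch of the local path, using hypothetical stand-in names rather than the real sglang classes:

```python
import rpyc


class Worker:  # hypothetical stand-in for ModelRpcServer
    def __init__(self, tp_rank: int):
        self.tp_rank = tp_rank

    # Methods meant for remote callers keep the `exposed_` prefix;
    # rpyc resolves a remote `worker.ping()` to `exposed_ping`.
    def exposed_ping(self):
        return f"pong from rank {self.tp_rank}"


class WorkerService(rpyc.Service):
    # `exposed_` attributes are reachable through conn.root; exposing
    # the class itself lets clients instantiate it on the server.
    exposed_ModelRpcServer = Worker


# Local path (tp_size == 1): plain construction, no server, no proxies.
worker = WorkerService().exposed_ModelRpcServer(0)
print(worker.exposed_ping())  # pong from rank 0
```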
@@ -629,14 +633,16 @@ class ModelRpcClient:
             with ThreadPoolExecutor(tp_size) as executor:
                 # Launch model processes
                 rets = executor.map(start_model_process, port_args.model_rpc_ports)
-                self.model_servers = [x[0] for x in rets]
+                self.remote_services = [x[0] for x in rets]
                 self.procs = [x[1] for x in rets]
 
                 # Init model
                 def init_model(i):
-                    return self.model_servers[i].init_model(i, server_args, port_args)
+                    return self.remote_services[i].ModelRpcServer(
+                        i, server_args, port_args
+                    )
 
-                rets = [obtain(x) for x in executor.map(init_model, range(tp_size))]
+                self.model_servers = executor.map(init_model, range(tp_size))
 
                 # Wrap functions
                 def async_wrap(func_name):
@@ -654,7 +660,7 @@ class ModelRpcClient:
 
 def _init_service(port):
     t = ThreadedServer(
-        ModelRpcServer(),
+        ModelRpcService(),
         port=port,
         protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
     )
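In the multi-GPU path, `_init_service` now serves `ModelRpcService`, and `remote_services[i].ModelRpcServer(...)` constructs each worker inside its server process, returning a netref proxy rather than a local copy, which is also why the `obtain(...)` call disappears. A self-contained sketch of that round trip under the same protocol config (port number and class names are arbitrary for the demo):

```python
import threading
import time

import rpyc
from rpyc.utils.server import ThreadedServer


class Worker:  # hypothetical stand-in for ModelRpcServer
    def __init__(self, tp_rank: int):
        self.tp_rank = tp_rank

    def exposed_ping(self):
        return f"pong from rank {self.tp_rank}"


class WorkerService(rpyc.Service):
    exposed_ModelRpcServer = Worker


server = ThreadedServer(
    WorkerService(),
    port=18861,  # arbitrary free port for the demo
    protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
)
threading.Thread(target=server.start, daemon=True).start()
time.sleep(0.5)  # crude wait for the listener to come up

conn = rpyc.connect(
    "localhost", 18861, config={"allow_pickle": True, "sync_request_timeout": 1800}
)
# Calling the exposed class through conn.root runs the constructor in the
# server process and returns a netref proxy, not a local copy; every
# attribute access on `remote_worker` goes over the connection.
remote_worker = conn.root.ModelRpcServer(0)
print(remote_worker.ping())  # resolved to exposed_ping on the server
```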
...
 import importlib
-import logging
+import importlib.resources
 import inspect
+import logging
+import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from pathlib import Path
-import importlib.resources
 
 import numpy as np
 import torch
@@ -18,11 +18,6 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
-import importlib
-import pkgutil
-
-import sglang
-
 QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
 
 logger = logging.getLogger("model_runner")
@@ -37,7 +32,7 @@ def import_model_classes():
     model_arch_name_to_cls = {}
     package_name = "sglang.srt.models"
     package = importlib.import_module(package_name)
-    for finder, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + '.'):
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
             module = importlib.import_module(name)
             if hasattr(module, "EntryClass"):
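`import_model_classes` is a conventional pkgutil discovery loop: import every non-package module under `sglang.srt.models` and register the class each one exports as `EntryClass`. A generic re-implementation of the idea (the `EntryClass` convention is sglang's; the helper name here is ours):

```python
import importlib
import pkgutil


def collect_entry_classes(package_name: str, attr: str = "EntryClass"):
    """Import every non-package module under `package_name` and gather
    the class each module exports under `attr`."""
    found = {}
    package = importlib.import_module(package_name)
    # iter_modules yields (finder, fullname, ispkg); the finder is unused,
    # hence the `_` in the diff above.
    for _, fullname, ispkg in pkgutil.iter_modules(
        package.__path__, package_name + "."
    ):
        if ispkg:
            continue
        module = importlib.import_module(fullname)
        cls = getattr(module, attr, None)
        if cls is not None:
            found[cls.__name__] = cls
    return found


# e.g. collect_entry_classes("sglang.srt.models") maps model class names
# to their implementations without hardcoding an import list.
```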
@@ -144,9 +139,12 @@ class InputMetadata:
             # flashinfer >= 0.0.3
             # FIXME: Drop this when flashinfer updates to 0.0.4
-            if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7:
+            if (
+                len(inspect.signature(self.prefill_wrapper.begin_forward).parameters)
+                == 7
+            ):
                 args.append(self.model_runner.model_config.head_dim)
             self.prefill_wrapper.begin_forward(*args)
         else:
             self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
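The reformatted condition is a signature-sniffing compatibility shim: it counts the parameters of `begin_forward` to decide whether the installed flashinfer accepts the extra `head_dim` argument. A generic sketch of the technique, with toy stand-ins for the two API versions:

```python
import inspect


def call_with_optional_arg(fn, args, extra):
    """Pass `extra` only if `fn`'s signature has room for it."""
    n_params = len(inspect.signature(fn).parameters)
    if n_params == len(args) + 1:
        return fn(*args, extra)
    return fn(*args)


# Hypothetical stand-ins for the old and new flashinfer signatures:
def begin_forward_old(a, b, c):  # older API: 3 parameters
    return (a, b, c)


def begin_forward_new(a, b, c, head_dim):  # newer API: takes head_dim too
    return (a, b, c, head_dim)


print(call_with_optional_arg(begin_forward_old, (1, 2, 3), 128))  # (1, 2, 3)
print(call_with_optional_arg(begin_forward_new, (1, 2, 3), 128))  # (1, 2, 3, 128)
```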
@@ -307,9 +305,11 @@ class ModelRunner:
             hf_quant_method = hf_quant_config["quant_method"]
 
             # compat: autogptq uses is_marlin_format within quant config
-            if (hf_quant_method == "gptq"
-                    and "is_marlin_format" in hf_quant_config
-                    and hf_quant_config["is_marlin_format"]):
+            if (
+                hf_quant_method == "gptq"
+                and "is_marlin_format" in hf_quant_config
+                and hf_quant_config["is_marlin_format"]
+            ):
                 hf_quant_method = "marlin"
 
             quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
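The condition feeds a plain dict dispatch from Hugging Face `quant_method` strings to vLLM quantization-config classes. A condensed, self-contained sketch of that logic, with empty stand-ins for vLLM's config classes and `dict.get` folding the two-step membership test:

```python
class AWQConfig: ...
class GPTQConfig: ...
class MarlinConfig: ...


QUANT_CONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}


def resolve_quant_config_class(hf_quant_config: dict):
    hf_quant_method = hf_quant_config["quant_method"]
    # compat: autogptq marks marlin-formatted checkpoints inside a gptq config
    if hf_quant_method == "gptq" and hf_quant_config.get("is_marlin_format"):
        hf_quant_method = "marlin"
    quant_config_class = QUANT_CONFIG_MAPPING.get(hf_quant_method)
    if quant_config_class is None:
        raise ValueError(f"Unsupported quantization method: {hf_quant_method}")
    return quant_config_class


print(resolve_quant_config_class({"quant_method": "gptq", "is_marlin_format": True}))
```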
...