Unverified Commit 3a581e99 authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

Dynamic model class loading (#101)

parent 0147f940
...@@ -20,7 +20,7 @@ dependencies = [ ...@@ -20,7 +20,7 @@ dependencies = [
[project.optional-dependencies] [project.optional-dependencies]
srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
"zmq", "vllm>=0.2.5", "interegular", "lark", "numba", "zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
"pydantic", "diskcache", "cloudpickle"] "pydantic", "diskcache", "cloudpickle", "pillow"]
openai = ["openai>=1.0", "numpy"] openai = ["openai>=1.0", "numpy"]
anthropic = ["anthropic", "numpy"] anthropic = ["anthropic", "numpy"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"] all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
......
import importlib
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from functools import lru_cache
from pathlib import Path
from typing import List from typing import List
import numpy as np import numpy as np
import torch import torch
import sglang
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.utils import is_multimodal_model from sglang.srt.utils import is_multimodal_model
...@@ -20,6 +23,32 @@ logger = logging.getLogger("model_runner") ...@@ -20,6 +23,32 @@ logger = logging.getLogger("model_runner")
global_model_mode: List[str] = [] global_model_mode: List[str] = []
@lru_cache()
def import_model_classes():
    """Discover model implementations under ``sglang.srt.models``.

    Every ``*.py`` module in that package that exposes an ``EntryClass``
    attribute is registered, keyed by the class's own name (which matches
    the HuggingFace ``architectures`` entry).  Cached so the filesystem
    scan and module imports happen only once per process.

    Returns:
        dict mapping architecture name (str) -> model class.
    """
    registry = {}
    models_dir = Path(sglang.__file__).parent / "srt" / "models"
    for module_file in models_dir.glob("*.py"):
        module = importlib.import_module(f"sglang.srt.models.{module_file.stem}")
        if hasattr(module, "EntryClass"):
            entry = module.EntryClass
            registry[entry.__name__] = entry
    return registry
def get_model_cls_by_arch_name(model_arch_names):
    """Return the model class for the first supported architecture name.

    Args:
        model_arch_names: iterable of architecture names (as found in a
            HuggingFace config's ``architectures`` field), tried in order.

    Returns:
        The registered model class for the first matching name.

    Raises:
        ValueError: if no name matches a registered class (including when
            ``model_arch_names`` is empty).
    """
    model_arch_name_to_cls = import_model_classes()
    for arch in model_arch_names:
        if arch in model_arch_name_to_cls:
            return model_arch_name_to_cls[arch]
    # Report the full list the caller asked for, not just the last name
    # tried; the original f-string referenced the loop variable, which is
    # unbound (NameError) when model_arch_names is empty.
    raise ValueError(
        f"Unsupported architectures: {list(model_arch_names)}. "
        f"Supported list: {list(model_arch_name_to_cls.keys())}"
    )
@dataclass @dataclass
class InputMetadata: class InputMetadata:
model_runner: "ModelRunner" model_runner: "ModelRunner"
...@@ -237,34 +266,9 @@ class ModelRunner: ...@@ -237,34 +266,9 @@ class ModelRunner:
def load_model(self): def load_model(self):
"""See also vllm/model_executor/model_loader.py::get_model""" """See also vllm/model_executor/model_loader.py::get_model"""
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.models.llava import LlavaLlamaForCausalLM
from sglang.srt.models.mixtral import MixtralForCausalLM
from sglang.srt.models.qwen import QWenLMHeadModel
# Select model class # Select model class
architectures = getattr(self.model_config.hf_config, "architectures", []) architectures = getattr(self.model_config.hf_config, "architectures", [])
model_class = get_model_cls_by_arch_name(architectures)
model_class = None
for arch in architectures:
if arch == "LlamaForCausalLM":
model_class = LlamaForCausalLM
break
if arch == "MistralForCausalLM":
model_class = LlamaForCausalLM
break
if arch == "LlavaLlamaForCausalLM":
model_class = LlavaLlamaForCausalLM
break
if arch == "MixtralForCausalLM":
model_class = MixtralForCausalLM
break
if arch == "QWenLMHeadModel":
model_class = QWenLMHeadModel
break
if model_class is None:
raise ValueError(f"Unsupported architectures: {architectures}")
logger.info(f"Rank {self.tp_rank}: load weight begin.") logger.info(f"Rank {self.tp_rank}: load weight begin.")
# Load weights # Load weights
......
...@@ -318,3 +318,5 @@ class LlamaForCausalLM(nn.Module): ...@@ -318,3 +318,5 @@ class LlamaForCausalLM(nn.Module):
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
EntryClass = LlamaForCausalLM
...@@ -330,3 +330,5 @@ def monkey_path_clip_vision_embed_forward(): ...@@ -330,3 +330,5 @@ def monkey_path_clip_vision_embed_forward():
"forward", "forward",
clip_vision_embed_forward, clip_vision_embed_forward,
) )
EntryClass = LlavaLlamaForCausalLM
...@@ -376,3 +376,5 @@ class MixtralForCausalLM(nn.Module): ...@@ -376,3 +376,5 @@ class MixtralForCausalLM(nn.Module):
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
EntryClass = MixtralForCausalLM
...@@ -258,3 +258,5 @@ class QWenLMHeadModel(nn.Module): ...@@ -258,3 +258,5 @@ class QWenLMHeadModel(nn.Module):
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
EntryClass = QWenLMHeadModel
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment