Unverified commit 3a581e99 authored by Cody Yu, committed by GitHub

Dynamic model class loading (#101)

parent 0147f940
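In short, this change replaces the hard-coded architecture-to-class mapping in ModelRunner.load_model with dynamic discovery: every module under sglang/srt/models/ that exposes an EntryClass attribute is imported and registered under the class's name, and that name is matched against the "architectures" field of the model's Hugging Face config. A minimal, hypothetical sketch of how a new model would hook in under this convention (the file name my_model.py and the class MyModelForCausalLM are placeholders, not part of this commit):

# Hypothetical new model file: sglang/srt/models/my_model.py
# (module and class names are placeholders; only the EntryClass convention
#  comes from this commit)
from torch import nn


class MyModelForCausalLM(nn.Module):
    """Skeleton only; a real model also implements forward() and load_weights()."""

    def __init__(self, config):
        super().__init__()
        self.config = config


# import_model_classes() scans sglang/srt/models/*.py and registers this class
# under its class name, so the name must match an entry in the "architectures"
# list of the model's Hugging Face config (e.g. "MyModelForCausalLM").
EntryClass = MyModelForCausalLM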
pyproject.toml
@@ -20,7 +20,7 @@ dependencies = [
[project.optional-dependencies]
srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
"zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
"pydantic", "diskcache", "cloudpickle"]
"pydantic", "diskcache", "cloudpickle", "pillow"]
openai = ["openai>=1.0", "numpy"]
anthropic = ["anthropic", "numpy"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
sglang/srt/managers/router/model_runner.py
import importlib
import logging
from dataclasses import dataclass
from enum import Enum, auto
from functools import lru_cache
from pathlib import Path
from typing import List
import numpy as np
import torch
import sglang
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.utils import is_multimodal_model
@@ -20,6 +23,32 @@ logger = logging.getLogger("model_runner")
global_model_mode: List[str] = []


@lru_cache()
def import_model_classes():
    model_arch_name_to_cls = {}
    for module_path in (Path(sglang.__file__).parent / "srt" / "models").glob("*.py"):
        module = importlib.import_module(f"sglang.srt.models.{module_path.stem}")
        if hasattr(module, "EntryClass"):
            model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
    return model_arch_name_to_cls


def get_model_cls_by_arch_name(model_arch_names):
    model_arch_name_to_cls = import_model_classes()

    model_class = None
    for arch in model_arch_names:
        if arch in model_arch_name_to_cls:
            model_class = model_arch_name_to_cls[arch]
            break
    else:
        raise ValueError(
            f"Unsupported architectures: {arch}. "
            f"Supported list: {list(model_arch_name_to_cls.keys())}"
        )
    return model_class


@dataclass
class InputMetadata:
    model_runner: "ModelRunner"
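For illustration only (not part of the diff), this is how the lookup above behaves once the registry is populated; import_model_classes() is wrapped in lru_cache(), so the models directory is scanned only once per process, and the architecture names assumed here are the ones registered by the model files touched later in this commit:

# Illustrative calls, assuming the registry contains the four classes
# registered in this commit (Llama, LLaVA, Mixtral, QWen):
get_model_cls_by_arch_name(["QWenLMHeadModel"])
# -> the QWenLMHeadModel class

get_model_cls_by_arch_name(["SomeUnknownArch"])
# -> raises ValueError listing the supported architectures (the for/else above)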
@@ -237,34 +266,9 @@ class ModelRunner:
    def load_model(self):
        """See also vllm/model_executor/model_loader.py::get_model"""
        from sglang.srt.models.llama2 import LlamaForCausalLM
        from sglang.srt.models.llava import LlavaLlamaForCausalLM
        from sglang.srt.models.mixtral import MixtralForCausalLM
        from sglang.srt.models.qwen import QWenLMHeadModel

        # Select model class
        architectures = getattr(self.model_config.hf_config, "architectures", [])
        model_class = None
        for arch in architectures:
            if arch == "LlamaForCausalLM":
                model_class = LlamaForCausalLM
                break
            if arch == "MistralForCausalLM":
                model_class = LlamaForCausalLM
                break
            if arch == "LlavaLlamaForCausalLM":
                model_class = LlavaLlamaForCausalLM
                break
            if arch == "MixtralForCausalLM":
                model_class = MixtralForCausalLM
                break
            if arch == "QWenLMHeadModel":
                model_class = QWenLMHeadModel
                break
        if model_class is None:
            raise ValueError(f"Unsupported architectures: {architectures}")
        model_class = get_model_cls_by_arch_name(architectures)

        logger.info(f"Rank {self.tp_rank}: load weight begin.")

        # Load weights
sglang/srt/models/llama2.py
@@ -318,3 +318,5 @@ class LlamaForCausalLM(nn.Module):
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

EntryClass = LlamaForCausalLM
sglang/srt/models/llava.py
@@ -330,3 +330,5 @@ def monkey_path_clip_vision_embed_forward():
        "forward",
        clip_vision_embed_forward,
    )

EntryClass = LlavaLlamaForCausalLM
sglang/srt/models/mixtral.py
@@ -376,3 +376,5 @@ class MixtralForCausalLM(nn.Module):
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

EntryClass = MixtralForCausalLM
sglang/srt/models/qwen.py
@@ -258,3 +258,5 @@ class QWenLMHeadModel(nn.Module):
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

EntryClass = QWenLMHeadModel