Unverified commit 0abbf289 authored by Lianmin Zheng, committed by GitHub

Unify the model type checking (#1905)

parent c17c5781
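
In short, this commit moves the model-type helpers (is_generation_model, is_multimodal_model, is_encoder_decoder_model) out of sglang.srt.utils and into ModelConfig, which computes the flags once in its constructor; call sites then read the flags instead of re-deriving them from hf_config.architectures. A hedged sketch of the pattern the call sites converge on (the keyword arguments mirror the call sites in this diff; the printed comments are illustrative):

# Sketch only, assuming an SGLang checkout that includes this commit.
# The model path is the one from the llava example in this diff.
from sglang.srt.configs.model_config import ModelConfig

model_config = ModelConfig(
    "lmms-lab/llama3-llava-next-8b",
    trust_remote_code=False,
    context_length=None,
    model_override_args="{}",   # now a JSON string; ModelConfig calls json.loads itself
    is_embedding=False,         # mirrors server_args.is_embedding
)

print(model_config.is_generation)       # False for embedding/classification architectures
print(model_config.is_multimodal)       # True for the LLaVA / Qwen2-VL / Mllama families
print(model_config.is_encoder_decoder)  # True only for MllamaForConditionalGeneration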
......@@ -50,7 +50,7 @@ if __name__ == "__main__":
mp.set_start_method("spawn", force=True)
runtime = sgl.Runtime(model_path="lmms-lab/llama3-llava-next-8b")
runtime.endpoint.chat_template = get_chat_template("llama-3-instruct")
runtime.endpoint.chat_template = get_chat_template("llama-3-instruct-llava")
# Or you can use the 72B model
# runtime = sgl.Runtime(model_path="lmms-lab/llava-next-72b", tp_size=8)
......
......@@ -129,9 +129,9 @@ def load_model(server_args, port_args, tp_rank):
model_config = ModelConfig(
server_args.model_path,
server_args.trust_remote_code,
trust_remote_code=server_args.trust_remote_code,
context_length=server_args.context_length,
model_override_args=json.loads(server_args.json_model_override_args),
model_override_args=server_args.json_model_override_args,
)
model_runner = ModelRunner(
model_config=model_config,
......
......@@ -116,12 +116,10 @@ register_chat_template(
)
)
# There is a default system prompt for qwen
# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
register_chat_template(
ChatTemplate(
name="qwen",
name="chatml-llava",
default_system_prompt="You are a helpful assistant.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
......@@ -130,13 +128,17 @@ register_chat_template(
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
image_token="<image>\n",
)
)
# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
# There is a default system prompt for qwen
# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
register_chat_template(
ChatTemplate(
name="qwen2-vl",
name="qwen",
default_system_prompt="You are a helpful assistant.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
......@@ -145,14 +147,13 @@ register_chat_template(
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
image_token="<|vision_start|><|image_pad|><|vision_end|>",
)
)
# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
register_chat_template(
ChatTemplate(
name="chatml-llava",
name="qwen2-vl",
default_system_prompt="You are a helpful assistant.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
......@@ -161,7 +162,7 @@ register_chat_template(
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
image_token="<image>\n",
image_token="<|vision_start|><|image_pad|><|vision_end|>",
)
)
......@@ -182,37 +183,46 @@ register_chat_template(
)
)
# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
register_chat_template(
ChatTemplate(
name="yi-1.5",
name="llama-2-chat",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("", ""),
"user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
"assistant": ("", "<|im_end|>\n"),
"system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
"user": ("[INST] ", " [/INST]"),
"assistant": ("", " </s><s>"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
style=ChatTemplateStyle.LLAMA2,
)
)
register_chat_template(
ChatTemplate(
name="llama-2-chat",
name="llama-3-instruct",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
"user": ("[INST] ", " [/INST]"),
"assistant": ("", " </s><s>"),
"system": (
"<|start_header_id|>system<|end_header_id|>\n\n",
"<|eot_id|>",
),
"user": (
"<|start_header_id|>user<|end_header_id|>\n\n",
"<|eot_id|>",
),
"assistant": (
"<|start_header_id|>assistant<|end_header_id|>\n\n",
"<|eot_id|>",
),
},
style=ChatTemplateStyle.LLAMA2,
stop_str=("<|eot_id|>",),
image_token="<|image|>",
)
)
# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
register_chat_template(
ChatTemplate(
name="llama-3-instruct",
name="llama-3-instruct-llava",
default_system_prompt=None,
role_prefix_and_suffix={
"system": (
......@@ -229,7 +239,22 @@ register_chat_template(
),
},
stop_str=("<|eot_id|>",),
image_token="<|image|>",
image_token="<image>\n",
)
)
# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
register_chat_template(
ChatTemplate(
name="yi-1.5",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("", ""),
"user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
"assistant": ("", "<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
)
)
......
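
For context on how the registered templates are consumed, the following is an illustrative, hand-rolled rendering of the "qwen2-vl" fields shown above. It is not sglang's formatting code, and the user/assistant prefixes are assumed to follow the ChatML pattern quoted in the tokenizer_config comment:

# Illustrative sketch only: compose a prompt from the "qwen2-vl" template fields.
role_prefix_and_suffix = {
    "system": ("<|im_start|>system\n", "<|im_end|>\n"),
    "user": ("<|im_start|>user\n", "<|im_end|>\n"),            # assumed ChatML pattern
    "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),  # assumed ChatML pattern
}
image_token = "<|vision_start|><|image_pad|><|vision_end|>"

messages = [
    ("system", "You are a helpful assistant."),
    ("user", image_token + "Describe this image."),
]
prompt = ""
for role, content in messages:
    prefix, suffix = role_prefix_and_suffix[role]
    prompt += prefix + content + suffix
prompt += role_prefix_and_suffix["assistant"][0]  # open the assistant turn
print(prompt)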
......@@ -13,10 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
import os
from enum import IntEnum, auto
from typing import Optional
from typing import List, Optional
from transformers import PretrainedConfig
......@@ -38,18 +39,24 @@ class ModelConfig:
revision: Optional[str] = None,
context_length: Optional[int] = None,
model_override_args: Optional[dict] = None,
is_embedding: Optional[bool] = None
) -> None:
self.path = path
self.trust_remote_code = trust_remote_code
self.revision = revision
self.model_override_args = model_override_args
# Parse args
self.model_override_args = json.loads(model_override_args)
self.hf_config = get_config(
self.path,
trust_remote_code,
revision,
model_override_args=model_override_args,
path,
trust_remote_code=trust_remote_code,
revision=revision,
model_override_args=self.model_override_args,
)
self.hf_text_config = get_hf_text_config(self.hf_config)
# Check model type
self.is_generation = is_generation_model(self.hf_config.architectures, is_embedding)
self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
# Derive context length
derived_context_len = get_context_length(self.hf_text_config)
allow_long_context = os.environ.get(
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
......@@ -81,7 +88,7 @@ class ModelConfig:
self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
)
# FIXME: temporary special case for the deepseek v2 MLA architecture
# FIXME: temporary special case for the MLA architecture
if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
self.head_dim = 256
self.attention_arch = AttentionArch.MLA
......@@ -112,8 +119,6 @@ class ModelConfig:
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
self.vocab_size = self.hf_text_config.vocab_size
self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads."""
......@@ -163,7 +168,6 @@ class ModelConfig:
# equal to the number of attention heads.
return self.hf_text_config.num_attention_heads
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
def get_num_kv_heads(self, tensor_parallel_size) -> int:
"""Returns the number of KV heads per GPU."""
total_num_kv_heads = self.get_total_num_kv_heads()
......@@ -192,3 +196,37 @@ def get_hf_text_config(config: PretrainedConfig):
return config.text_config
else:
return config
def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
# We have two ways to determine whether a model is a generative model.
# 1. Check the model architecture
# 2. Check the `is_embedding` server argument
if (
"LlamaEmbeddingModel" in model_architectures
or "MistralModel" in model_architectures
or "LlamaForSequenceClassification" in model_architectures
or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
):
return False
else:
return not is_embedding
def is_multimodal_model(model_architectures: List[str]):
if (
"LlavaLlamaForCausalLM" in model_architectures
or "LlavaQwenForCausalLM" in model_architectures
or "LlavaMistralForCausalLM" in model_architectures
or "LlavaVidForCausalLM" in model_architectures
or "MllamaForConditionalGeneration" in model_architectures
or "Qwen2VLForConditionalGeneration" in model_architectures
):
return True
else:
return False
def is_encoder_decoder_model(model_architectures: List[str]):
return "MllamaForConditionalGeneration" in model_architectures
......@@ -180,7 +180,7 @@ class LlavaImageProcessor(BaseImageProcessor):
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"image_sizes": image_sizes,
"modalities": request_obj.modalities,
"modalities": request_obj.modalities or ["image"],
}
......
......@@ -15,7 +15,6 @@ limitations under the License.
"""A scheduler that manages a tensor parallel GPU worker."""
import json
import logging
import os
import threading
......@@ -23,7 +22,7 @@ import time
import warnings
from collections import deque
from types import SimpleNamespace
from typing import List, Optional, Union
from typing import List, Optional
import torch
import zmq
......@@ -68,8 +67,6 @@ from sglang.srt.utils import (
broadcast_pyobj,
configure_logger,
get_zmq_socket,
is_generation_model,
is_multimodal_model,
kill_parent_process,
set_random_seed,
suppress_other_loggers,
......@@ -133,15 +130,17 @@ class Scheduler:
# Init tokenizer
self.model_config = ModelConfig(
server_args.model_path,
server_args.trust_remote_code,
trust_remote_code=server_args.trust_remote_code,
context_length=server_args.context_length,
model_override_args=json.loads(server_args.json_model_override_args),
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
)
self.is_generation = self.model_config.is_generation
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
else:
if is_multimodal_model(self.model_config.hf_config.architectures):
if self.model_config.is_multimodal:
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
......@@ -154,9 +153,6 @@ class Scheduler:
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.is_generation = is_generation_model(
self.model_config.hf_config.architectures, self.server_args.is_embedding
)
# Launch a tensor parallel worker
if self.enable_overlap:
......
......@@ -18,7 +18,6 @@ limitations under the License.
import asyncio
import copy
import dataclasses
import json
import logging
import os
import signal
......@@ -31,12 +30,8 @@ import zmq
import zmq.asyncio
from fastapi import BackgroundTasks
from sglang.srt.hf_transformers_utils import (
get_config,
get_context_length,
get_processor,
get_tokenizer,
)
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.image_processor import (
get_dummy_image_processor,
get_image_processor,
......@@ -59,12 +54,7 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
get_zmq_socket,
is_generation_model,
is_multimodal_model,
kill_child_process,
)
from sglang.srt.utils import get_zmq_socket, kill_child_process
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
......@@ -103,18 +93,17 @@ class TokenizerManager:
# Read model args
self.model_path = server_args.model_path
self.served_model_name = server_args.served_model_name
self.hf_config = get_config(
self.model_path,
self.model_config = ModelConfig(
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
model_override_args=json.loads(server_args.json_model_override_args),
)
self.is_generation = is_generation_model(
self.hf_config.architectures, self.server_args.is_embedding
)
self.context_len = server_args.context_length or get_context_length(
self.hf_config
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
)
self.is_generation = self.model_config.is_generation
self.context_len = self.model_config.context_len
# Create image processor placeholder
self.image_processor = get_dummy_image_processor()
......@@ -122,7 +111,7 @@ class TokenizerManager:
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
else:
if is_multimodal_model(self.hf_config.architectures):
if self.model_config.is_multimodal:
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
......@@ -133,7 +122,7 @@ class TokenizerManager:
# We want to parallelize the image pre-processing so we create an executor for it
self.image_processor = get_image_processor(
self.hf_config, server_args, self.processor
self.model_config.hf_config, server_args, self.processor
)
else:
self.tokenizer = get_tokenizer(
......
......@@ -15,7 +15,6 @@ limitations under the License.
"""A tensor parallel worker."""
import json
import logging
from typing import Optional
......@@ -26,7 +25,7 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_a
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed
from sglang.srt.utils import broadcast_pyobj, set_random_seed
logger = logging.getLogger(__name__)
......@@ -48,9 +47,10 @@ class TpModelWorker:
# Init model and tokenizer
self.model_config = ModelConfig(
server_args.model_path,
server_args.trust_remote_code,
trust_remote_code=server_args.trust_remote_code,
context_length=server_args.context_length,
model_override_args=json.loads(server_args.json_model_override_args),
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
)
self.model_runner = ModelRunner(
model_config=self.model_config,
......@@ -64,7 +64,7 @@ class TpModelWorker:
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
else:
if is_multimodal_model(self.model_config.hf_config.architectures):
if self.model_config.is_multimodal:
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
......
......@@ -59,11 +59,6 @@ from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
enable_show_time_cost,
get_available_gpu_memory,
is_attention_free_model,
is_embedding_model,
is_generation_model,
is_multimodal_model,
model_has_inner_state,
monkey_patch_vllm_dummy_weight_loader,
monkey_patch_vllm_p2p_access_check,
)
......@@ -93,9 +88,8 @@ class ModelRunner:
self.tp_size = tp_size
self.dist_port = nccl_port
self.server_args = server_args
self.is_multimodal_model = is_multimodal_model(
self.model_config.hf_config.architectures
)
self.is_generation = model_config.is_generation
self.is_multimodal = model_config.is_multimodal
# Model-specific adjustment
if (
......@@ -119,7 +113,7 @@ class ModelRunner:
self.server_args.ds_heavy_channel_type
)
if self.is_multimodal_model:
if self.is_multimodal:
logger.warning(
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
)
......@@ -270,9 +264,6 @@ class ModelRunner:
if hasattr(self.model, "get_attention_sliding_window_size")
else None
)
self.is_generation = is_generation_model(
self.model_config.hf_config.architectures, self.server_args.is_embedding
)
logger.info(
f"Load weight end. "
......@@ -679,7 +670,7 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
# Monkey patch model loader
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
setattr(ModelRegistry, "is_multimodal_model", is_multimodal_model)
setattr(ModelRegistry, "is_attention_free_model", is_attention_free_model)
setattr(ModelRegistry, "model_has_inner_state", model_has_inner_state)
setattr(ModelRegistry, "is_embedding_model", is_embedding_model)
setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)
......@@ -409,11 +409,13 @@ class LlamaForCausalLM(nn.Module):
if (
hasattr(self.config, "tie_word_embeddings")
and self.config.tie_word_embeddings
and "lm_head.weight" in params_dict
):
# Tie the output embedding layer to the input embedding layer, to handle checkpoints where lm_head.weight is missing
param = self.lm_head.weight
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, self.model.embed_tokens.weight)
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
......
......@@ -23,7 +23,7 @@
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import lru_cache, partial
from typing import Iterable, List, Mapping, Optional, Tuple, Type, TypedDict, Union
from typing import Iterable, List, Optional, Tuple, Type, TypedDict
import numpy as np
import torch
......@@ -36,7 +36,6 @@ from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
from sglang.srt.hf_transformers_utils import get_processor
......@@ -486,7 +485,7 @@ class Qwen2VisionTransformer(nn.Module):
cached_get_processor = lru_cache(get_processor)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
class Qwen2VLForConditionalGeneration(nn.Module):
def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
processor = cached_get_processor(self.config._name_or_path)
grid_t, grid_h, grid_w = image_grid_thw
......@@ -536,15 +535,12 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(
self,
config: Qwen2VLConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.multimodal_config = multimodal_config
self.visual = Qwen2VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
......
......@@ -204,56 +204,6 @@ def is_port_available(port):
return False
def is_multimodal_model(model_architectures):
if (
"LlavaLlamaForCausalLM" in model_architectures
or "LlavaQwenForCausalLM" in model_architectures
or "LlavaMistralForCausalLM" in model_architectures
or "LlavaVidForCausalLM" in model_architectures
or "MllamaForConditionalGeneration" in model_architectures
or "Qwen2VLForConditionalGeneration" in model_architectures
):
return True
else:
return False
def is_attention_free_model(model_architectures):
return False
def model_has_inner_state(model_architectures):
return False
def is_embedding_model(model_architectures):
if (
"LlamaEmbeddingModel" in model_architectures
or "MistralModel" in model_architectures
or "LlamaForSequenceClassification" in model_architectures
or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
):
return True
else:
return False
def is_generation_model(model_architectures, is_embedding: bool = False):
# We have two ways to determine whether a model is a generative model.
# 1. Check the model architecture
# 2. Check the `is_embedding` server argument
if (
"LlamaEmbeddingModel" in model_architectures
or "MistralModel" in model_architectures
or "LlamaForSequenceClassification" in model_architectures
or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
):
return False
else:
return not is_embedding
def decode_video_base64(video_base64):
from PIL import Image
......
import asyncio
import json
import unittest
import openai
......