Unverified Commit 0abbf289 authored by Lianmin Zheng, committed by GitHub

Unify the model type checking (#1905)

parent c17c5781
@@ -50,7 +50,7 @@ if __name__ == "__main__":
     mp.set_start_method("spawn", force=True)
     runtime = sgl.Runtime(model_path="lmms-lab/llama3-llava-next-8b")
-    runtime.endpoint.chat_template = get_chat_template("llama-3-instruct")
+    runtime.endpoint.chat_template = get_chat_template("llama-3-instruct-llava")
     # Or you can use the 72B model
     # runtime = sgl.Runtime(model_path="lmms-lab/llava-next-72b", tp_size=8)
...
@@ -129,9 +129,9 @@ def load_model(server_args, port_args, tp_rank):
     model_config = ModelConfig(
         server_args.model_path,
-        server_args.trust_remote_code,
+        trust_remote_code=server_args.trust_remote_code,
         context_length=server_args.context_length,
-        model_override_args=json.loads(server_args.json_model_override_args),
+        model_override_args=server_args.json_model_override_args,
     )
     model_runner = ModelRunner(
         model_config=model_config,
...
@@ -116,12 +116,10 @@ register_chat_template(
     )
 )

-# There is default system prompt for qwen
-# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
-# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
 register_chat_template(
     ChatTemplate(
-        name="qwen",
+        name="chatml-llava",
         default_system_prompt="You are a helpful assistant.",
         role_prefix_and_suffix={
             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
@@ -130,13 +128,17 @@ register_chat_template(
         },
         style=ChatTemplateStyle.PLAIN,
         stop_str=("<|im_end|>",),
+        image_token="<image>\n",
     )
 )

-# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
+# There is default system prompt for qwen
+# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
+# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
 register_chat_template(
     ChatTemplate(
-        name="qwen2-vl",
+        name="qwen",
         default_system_prompt="You are a helpful assistant.",
         role_prefix_and_suffix={
             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
@@ -145,14 +147,13 @@ register_chat_template(
         },
         style=ChatTemplateStyle.PLAIN,
         stop_str=("<|im_end|>",),
-        image_token="<|vision_start|><|image_pad|><|vision_end|>",
     )
 )

+# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
 register_chat_template(
     ChatTemplate(
-        name="chatml-llava",
+        name="qwen2-vl",
         default_system_prompt="You are a helpful assistant.",
         role_prefix_and_suffix={
             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
@@ -161,7 +162,7 @@ register_chat_template(
         },
         style=ChatTemplateStyle.PLAIN,
         stop_str=("<|im_end|>",),
-        image_token="<image>\n",
+        image_token="<|vision_start|><|image_pad|><|vision_end|>",
     )
 )
@@ -182,37 +183,46 @@ register_chat_template(
     )
 )

-# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
 register_chat_template(
     ChatTemplate(
-        name="yi-1.5",
+        name="llama-2-chat",
         default_system_prompt=None,
         role_prefix_and_suffix={
-            "system": ("", ""),
-            "user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
-            "assistant": ("", "<|im_end|>\n"),
+            "system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
+            "user": ("[INST] ", " [/INST]"),
+            "assistant": ("", " </s><s>"),
         },
-        style=ChatTemplateStyle.PLAIN,
-        stop_str=("<|im_end|>",),
+        style=ChatTemplateStyle.LLAMA2,
     )
 )

 register_chat_template(
     ChatTemplate(
-        name="llama-2-chat",
+        name="llama-3-instruct",
         default_system_prompt=None,
         role_prefix_and_suffix={
-            "system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
-            "user": ("[INST] ", " [/INST]"),
-            "assistant": ("", " </s><s>"),
+            "system": (
+                "<|start_header_id|>system<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
+            "user": (
+                "<|start_header_id|>user<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
+            "assistant": (
+                "<|start_header_id|>assistant<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
         },
-        style=ChatTemplateStyle.LLAMA2,
+        stop_str=("<|eot_id|>",),
+        image_token="<|image|>",
     )
 )

+# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
 register_chat_template(
     ChatTemplate(
-        name="llama-3-instruct",
+        name="llama-3-instruct-llava",
         default_system_prompt=None,
         role_prefix_and_suffix={
             "system": (
@@ -229,7 +239,22 @@ register_chat_template(
             ),
         },
         stop_str=("<|eot_id|>",),
-        image_token="<|image|>",
+        image_token="<image>\n",
+    )
+)
+
+# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
+register_chat_template(
+    ChatTemplate(
+        name="yi-1.5",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": ("", ""),
+            "user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
+            "assistant": ("", "<|im_end|>\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+        stop_str=("<|im_end|>",),
     )
 )
...
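
Note: a minimal sketch of how a template registered above is looked up and what its prefix/suffix pairs encode. The `get_chat_template` lookup is the accessor already used in the example at the top of this commit; the manual prompt assembly below is only illustrative, not the runtime's rendering code.

from sglang.lang.chat_template import get_chat_template

template = get_chat_template("llama-3-instruct-llava")
messages = [("user", "Describe this image. " + template.image_token)]

# Illustrative only: stitch a prompt from the registered prefix/suffix pairs,
# then append the assistant prefix so the model continues as the assistant.
prompt = ""
for role, content in messages:
    prefix, suffix = template.role_prefix_and_suffix[role]
    prompt += prefix + content + suffix
prompt += template.role_prefix_and_suffix["assistant"][0]
print(prompt)
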
@@ -13,10 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+import json
 import logging
 import os
 from enum import IntEnum, auto
-from typing import Optional
+from typing import List, Optional

 from transformers import PretrainedConfig
@@ -38,18 +39,24 @@ class ModelConfig:
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
         model_override_args: Optional[dict] = None,
+        is_embedding: Optional[bool] = None
     ) -> None:
-        self.path = path
-        self.trust_remote_code = trust_remote_code
-        self.revision = revision
-        self.model_override_args = model_override_args
+        # Parse args
+        self.model_override_args = json.loads(model_override_args)
         self.hf_config = get_config(
-            self.path,
-            trust_remote_code,
-            revision,
-            model_override_args=model_override_args,
+            path,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            model_override_args=self.model_override_args,
         )
         self.hf_text_config = get_hf_text_config(self.hf_config)

+        # Check model type
+        self.is_generation = is_generation_model(self.hf_config.architectures, is_embedding)
+        self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
+        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
+
+        # Derive context length
         derived_context_len = get_context_length(self.hf_text_config)
         allow_long_context = os.environ.get(
             "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
@@ -81,7 +88,7 @@ class ModelConfig:
             self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
         )

-        # FIXME: temporary special judge for deepseek v2 MLA architecture
+        # FIXME: temporary special judge for MLA architecture
         if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA
@@ -112,8 +119,6 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size

-        self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
-
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -163,7 +168,6 @@ class ModelConfig:
         # equal to the number of attention heads.
         return self.hf_text_config.num_attention_heads

-    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
     def get_num_kv_heads(self, tensor_parallel_size) -> int:
         """Returns the number of KV heads per GPU."""
         total_num_kv_heads = self.get_total_num_kv_heads()
@@ -192,3 +196,37 @@ def get_hf_text_config(config: PretrainedConfig):
         return config.text_config
     else:
         return config
+
+
+def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
+    # We have two ways to determine whether a model is a generative model.
+    # 1. Check the model architecture.
+    # 2. Check the `is_embedding` server arg.
+    if (
+        "LlamaEmbeddingModel" in model_architectures
+        or "MistralModel" in model_architectures
+        or "LlamaForSequenceClassification" in model_architectures
+        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
+    ):
+        return False
+    else:
+        return not is_embedding
+
+
+def is_multimodal_model(model_architectures: List[str]):
+    if (
+        "LlavaLlamaForCausalLM" in model_architectures
+        or "LlavaQwenForCausalLM" in model_architectures
+        or "LlavaMistralForCausalLM" in model_architectures
+        or "LlavaVidForCausalLM" in model_architectures
+        or "MllamaForConditionalGeneration" in model_architectures
+        or "Qwen2VLForConditionalGeneration" in model_architectures
+    ):
+        return True
+    else:
+        return False
+
+
+def is_encoder_decoder_model(model_architectures: List[str]):
+    return "MllamaForConditionalGeneration" in model_architectures
@@ -180,7 +180,7 @@ class LlavaImageProcessor(BaseImageProcessor):
             "pixel_values": pixel_values,
             "image_hashes": image_hashes,
             "image_sizes": image_sizes,
-            "modalities": request_obj.modalities,
+            "modalities": request_obj.modalities or ["image"],
         }
...
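
Note: the `or ["image"]` fallback means requests that leave `modalities` unset (or pass an empty list) are treated as single-image requests. A quick illustration of the idiom:

modalities = None
print(modalities or ["image"])   # -> ['image']
print([] or ["image"])           # -> ['image'] (an empty list also falls back)
print(["video"] or ["image"])    # -> ['video'] (an explicit value wins)
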
@@ -15,7 +15,6 @@ limitations under the License.
 """A scheduler that manages a tensor parallel GPU worker."""

-import json
 import logging
 import os
 import threading
@@ -23,7 +22,7 @@ import time
 import warnings
 from collections import deque
 from types import SimpleNamespace
-from typing import List, Optional, Union
+from typing import List, Optional

 import torch
 import zmq
@@ -68,8 +67,6 @@ from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
     get_zmq_socket,
-    is_generation_model,
-    is_multimodal_model,
     kill_parent_process,
     set_random_seed,
     suppress_other_loggers,
@@ -133,15 +130,17 @@ class Scheduler:
         # Init tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
-            server_args.trust_remote_code,
+            trust_remote_code=server_args.trust_remote_code,
             context_length=server_args.context_length,
-            model_override_args=json.loads(server_args.json_model_override_args),
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
         )
+        self.is_generation = self.model_config.is_generation

         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(self.model_config.hf_config.architectures):
+            if self.model_config.is_multimodal:
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
@@ -154,9 +153,6 @@ class Scheduler:
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
             )
-        self.is_generation = is_generation_model(
-            self.model_config.hf_config.architectures, self.server_args.is_embedding
-        )

         # Launch a tensor parallel worker
         if self.enable_overlap:
...
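
Note: the same tokenizer/processor branch now appears in Scheduler, TokenizerManager, and TpModelWorker, driven by `model_config.is_multimodal` instead of the removed helper. Sketched once below as a hypothetical standalone function (`init_tokenizer` is not part of the codebase; `get_processor`/`get_tokenizer` are the helpers imported in these files):

from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer


def init_tokenizer(model_config, server_args):
    """Sketch of the branch shared by the three consumers above."""
    # Multimodal models need an AutoProcessor (tokenizer + image preprocessing);
    # text-only models need just the tokenizer.
    if model_config.is_multimodal:
        processor = get_processor(
            server_args.tokenizer_path,
            tokenizer_mode=server_args.tokenizer_mode,
            trust_remote_code=server_args.trust_remote_code,
        )
        return processor, processor.tokenizer
    tokenizer = get_tokenizer(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
    return None, tokenizer
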
@@ -18,7 +18,6 @@ limitations under the License.
 import asyncio
 import copy
 import dataclasses
-import json
 import logging
 import os
 import signal
@@ -31,12 +30,8 @@ import zmq
 import zmq.asyncio
 from fastapi import BackgroundTasks

-from sglang.srt.hf_transformers_utils import (
-    get_config,
-    get_context_length,
-    get_processor,
-    get_tokenizer,
-)
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.image_processor import (
     get_dummy_image_processor,
     get_image_processor,
@@ -59,12 +54,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import (
-    get_zmq_socket,
-    is_generation_model,
-    is_multimodal_model,
-    kill_child_process,
-)
+from sglang.srt.utils import get_zmq_socket, kill_child_process

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -103,18 +93,17 @@ class TokenizerManager:
         # Read model args
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
-        self.hf_config = get_config(
-            self.model_path,
+        self.model_config = ModelConfig(
+            server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
-            model_override_args=json.loads(server_args.json_model_override_args),
-        )
-        self.is_generation = is_generation_model(
-            self.hf_config.architectures, self.server_args.is_embedding
-        )
-        self.context_len = server_args.context_length or get_context_length(
-            self.hf_config
+            context_length=server_args.context_length,
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
         )

+        self.is_generation = self.model_config.is_generation
+        self.context_len = self.model_config.context_len

         # Create image processor placeholder
         self.image_processor = get_dummy_image_processor()
@@ -122,7 +111,7 @@ class TokenizerManager:
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(self.hf_config.architectures):
+            if self.model_config.is_multimodal:
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
@@ -133,7 +122,7 @@ class TokenizerManager:
                 # We want to parallelize the image pre-processing so we create an executor for it
                 self.image_processor = get_image_processor(
-                    self.hf_config, server_args, self.processor
+                    self.model_config.hf_config, server_args, self.processor
                 )
         else:
             self.tokenizer = get_tokenizer(
...
@@ -15,7 +15,6 @@ limitations under the License.
 """A tensor parallel worker."""

-import json
 import logging
 from typing import Optional
@@ -26,7 +25,7 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_a
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed
+from sglang.srt.utils import broadcast_pyobj, set_random_seed

 logger = logging.getLogger(__name__)
@@ -48,9 +47,10 @@ class TpModelWorker:
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
-            server_args.trust_remote_code,
+            trust_remote_code=server_args.trust_remote_code,
             context_length=server_args.context_length,
-            model_override_args=json.loads(server_args.json_model_override_args),
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
         )
         self.model_runner = ModelRunner(
             model_config=self.model_config,
@@ -64,7 +64,7 @@ class TpModelWorker:
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(self.model_config.hf_config.architectures):
+            if self.model_config.is_multimodal:
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
...
@@ -59,11 +59,6 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
-    is_attention_free_model,
-    is_embedding_model,
-    is_generation_model,
-    is_multimodal_model,
-    model_has_inner_state,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
 )
@@ -93,9 +88,8 @@ class ModelRunner:
         self.tp_size = tp_size
         self.dist_port = nccl_port
         self.server_args = server_args
-        self.is_multimodal_model = is_multimodal_model(
-            self.model_config.hf_config.architectures
-        )
+        self.is_generation = model_config.is_generation
+        self.is_multimodal = model_config.is_multimodal

         # Model-specific adjustment
         if (
@@ -119,7 +113,7 @@ class ModelRunner:
                 self.server_args.ds_heavy_channel_type
             )

-        if self.is_multimodal_model:
+        if self.is_multimodal:
             logger.warning(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
@@ -270,9 +264,6 @@ class ModelRunner:
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
-        self.is_generation = is_generation_model(
-            self.model_config.hf_config.architectures, self.server_args.is_embedding
-        )

         logger.info(
             f"Load weight end. "
@@ -679,7 +670,7 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:

 # Monkey patch model loader
 setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
-setattr(ModelRegistry, "is_multimodal_model", is_multimodal_model)
-setattr(ModelRegistry, "is_attention_free_model", is_attention_free_model)
-setattr(ModelRegistry, "model_has_inner_state", model_has_inner_state)
-setattr(ModelRegistry, "is_embedding_model", is_embedding_model)
+setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
+setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
+setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
+setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)
@@ -409,11 +409,13 @@ class LlamaForCausalLM(nn.Module):
         if (
             hasattr(self.config, "tie_word_embeddings")
             and self.config.tie_word_embeddings
+            and "lm_head.weight" in params_dict
         ):
             # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
             param = self.lm_head.weight
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, self.model.embed_tokens.weight)

         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
...
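
Note: the extra `and "lm_head.weight" in params_dict` guard makes the tying branch run only when `lm_head.weight` is actually present in `params_dict`. A toy illustration of what the tying accomplishes (not the loader path above, just the general idea):

import torch

# When tie_word_embeddings is set and the checkpoint has no separate lm_head,
# the output projection is filled from the input embedding matrix.
embed_tokens = torch.nn.Embedding(128, 64)
lm_head = torch.nn.Linear(64, 128, bias=False)
with torch.no_grad():
    lm_head.weight.copy_(embed_tokens.weight)
print(torch.equal(lm_head.weight, embed_tokens.weight))  # True
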
@@ -23,7 +23,7 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 from functools import lru_cache, partial
-from typing import Iterable, List, Mapping, Optional, Tuple, Type, TypedDict, Union
+from typing import Iterable, List, Optional, Tuple, Type, TypedDict

 import numpy as np
 import torch
@@ -36,7 +36,6 @@ from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import QuickGELU
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import SupportsMultiModal

 from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
 from sglang.srt.hf_transformers_utils import get_processor
@@ -486,7 +485,7 @@ class Qwen2VisionTransformer(nn.Module):
 cached_get_processor = lru_cache(get_processor)


-class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
+class Qwen2VLForConditionalGeneration(nn.Module):
     def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
         processor = cached_get_processor(self.config._name_or_path)
         grid_t, grid_h, grid_w = image_grid_thw
@@ -536,15 +535,12 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
     def __init__(
         self,
         config: Qwen2VLConfig,
-        multimodal_config: MultiModalConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.multimodal_config = multimodal_config
         self.visual = Qwen2VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
...
@@ -204,56 +204,6 @@ def is_port_available(port):
         return False


-def is_multimodal_model(model_architectures):
-    if (
-        "LlavaLlamaForCausalLM" in model_architectures
-        or "LlavaQwenForCausalLM" in model_architectures
-        or "LlavaMistralForCausalLM" in model_architectures
-        or "LlavaVidForCausalLM" in model_architectures
-        or "MllamaForConditionalGeneration" in model_architectures
-        or "Qwen2VLForConditionalGeneration" in model_architectures
-    ):
-        return True
-    else:
-        return False
-
-
-def is_attention_free_model(model_architectures):
-    return False
-
-
-def model_has_inner_state(model_architectures):
-    return False
-
-
-def is_embedding_model(model_architectures):
-    if (
-        "LlamaEmbeddingModel" in model_architectures
-        or "MistralModel" in model_architectures
-        or "LlamaForSequenceClassification" in model_architectures
-        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
-    ):
-        return True
-    else:
-        return False
-
-
-def is_generation_model(model_architectures, is_embedding: bool = False):
-    # We have two ways to determine whether a model is a generative model.
-    # 1. Check the model architectue
-    # 2. check the `is_embedding` server args
-    if (
-        "LlamaEmbeddingModel" in model_architectures
-        or "MistralModel" in model_architectures
-        or "LlamaForSequenceClassification" in model_architectures
-        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
-    ):
-        return False
-    else:
-        return not is_embedding
-
-
 def decode_video_base64(video_base64):
     from PIL import Image
...
 import asyncio
-import json
 import unittest

 import openai
...