Unverified commit 98cf2ed6, authored by Cyrus Leung, committed by GitHub
Browse files

[Model][Bugfix] Implicit model flags and reenable Phi-3-Vision (#5896)

parent e9d32d07
...@@ -295,8 +295,6 @@ class BaiChuanModel(nn.Module): ...@@ -295,8 +295,6 @@ class BaiChuanModel(nn.Module):
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = { packed_modules_mapping = {
"W_pack": ["W_pack"], "W_pack": ["W_pack"],
"gate_up_proj": [ "gate_up_proj": [
......
...@@ -325,8 +325,6 @@ class ChatGLMModel(nn.Module): ...@@ -325,8 +325,6 @@ class ChatGLMModel(nn.Module):
class ChatGLMForCausalLM(nn.Module, SupportsLoRA): class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = { packed_modules_mapping = {
"query_key_value": ["query_key_value"], "query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"] "dense_h_to_4h": ["dense_h_to_4h"]
......
...@@ -291,8 +291,6 @@ class GemmaModel(nn.Module): ...@@ -291,8 +291,6 @@ class GemmaModel(nn.Module):
class GemmaForCausalLM(nn.Module, SupportsLoRA): class GemmaForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
......
...@@ -233,8 +233,6 @@ class GPTBigCodeModel(nn.Module): ...@@ -233,8 +233,6 @@ class GPTBigCodeModel(nn.Module):
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = {"c_attn": ["c_attn"]} packed_modules_mapping = {"c_attn": ["c_attn"]}
supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"] supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"]
......
...@@ -13,7 +13,14 @@ logger = init_logger(__name__) ...@@ -13,7 +13,14 @@ logger = init_logger(__name__)
class SupportsVision(Protocol): class SupportsVision(Protocol):
"""The interface required for all vision language models (VLMs).""" """The interface required for all vision language models (VLMs)."""
supports_vision: ClassVar[Literal[True]] supports_vision: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports vision inputs.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
def __init__(self, *, vlm_config: VisionLanguageConfig) -> None: def __init__(self, *, vlm_config: VisionLanguageConfig) -> None:
... ...
...@@ -52,7 +59,14 @@ def supports_vision( ...@@ -52,7 +59,14 @@ def supports_vision(
class SupportsLoRA(Protocol): class SupportsLoRA(Protocol):
"""The interface required for all models that support LoRA.""" """The interface required for all models that support LoRA."""
supports_lora: ClassVar[Literal[True]] supports_lora: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports LoRA.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
packed_modules_mapping: ClassVar[Dict[str, List[str]]] packed_modules_mapping: ClassVar[Dict[str, List[str]]]
supported_lora_modules: ClassVar[List[str]] supported_lora_modules: ClassVar[List[str]]
......
...@@ -299,8 +299,6 @@ class LlamaModel(nn.Module): ...@@ -299,8 +299,6 @@ class LlamaModel(nn.Module):
class LlamaForCausalLM(nn.Module, SupportsLoRA): class LlamaForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
......
...@@ -88,8 +88,6 @@ LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs] ...@@ -88,8 +88,6 @@ LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data) @MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
class LlavaForConditionalGeneration(nn.Module, SupportsVision): class LlavaForConditionalGeneration(nn.Module, SupportsVision):
supports_vision = True
def __init__(self, def __init__(self,
config: LlavaConfig, config: LlavaConfig,
vlm_config: VisionLanguageConfig, vlm_config: VisionLanguageConfig,
......
...@@ -108,8 +108,6 @@ def _image_pixel_processor( ...@@ -108,8 +108,6 @@ def _image_pixel_processor(
@MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data) @MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
supports_vision = True
def __init__(self, def __init__(self,
config: LlavaNextConfig, config: LlavaNextConfig,
vlm_config: VisionLanguageConfig, vlm_config: VisionLanguageConfig,
......
...@@ -392,8 +392,6 @@ class MiniCPMModel(nn.Module): ...@@ -392,8 +392,6 @@ class MiniCPMModel(nn.Module):
class MiniCPMForCausalLM(nn.Module, SupportsLoRA): class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
......
...@@ -475,8 +475,6 @@ class MixtralModel(nn.Module): ...@@ -475,8 +475,6 @@ class MixtralModel(nn.Module):
class MixtralForCausalLM(nn.Module, SupportsLoRA): class MixtralForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
packed_modules_mapping = { packed_modules_mapping = {
......
...@@ -232,8 +232,6 @@ class PhiModel(nn.Module): ...@@ -232,8 +232,6 @@ class PhiModel(nn.Module):
class PhiForCausalLM(nn.Module, SupportsLoRA): class PhiForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
......
...@@ -32,12 +32,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead ...@@ -32,12 +32,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import ImagePixelData, get_dummy_image_data from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
from .interfaces import SupportsVision
logger = init_logger(__name__) logger = init_logger(__name__)
_KEYS_TO_MODIFY_MAPPING = { _KEYS_TO_MODIFY_MAPPING = {
...@@ -317,18 +318,21 @@ def _image_processor( ...@@ -317,18 +318,21 @@ def _image_processor(
@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_processor) @MULTIMODAL_REGISTRY.register_image_pixel_input(_image_processor)
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data) @MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
class Phi3VForCausalLM(VisionLanguageModelBase): class Phi3VForCausalLM(nn.Module, SupportsVision):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
vision_language_config: VisionLanguageConfig, vlm_config: VisionLanguageConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None: quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__(vision_language_config) super().__init__()
self.config = config self.config = config
self.vlm_config = vlm_config
self.model = LlamaModel(config, cache_config, quant_config) self.model = LlamaModel(config, cache_config, quant_config)
self.vision_embed_tokens = Phi3HDImageEmbedding( self.vision_embed_tokens = Phi3HDImageEmbedding(
vision_language_config, config, self.model.embed_tokens) vlm_config, config, self.model.embed_tokens)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -338,7 +342,7 @@ class Phi3VForCausalLM(VisionLanguageModelBase): ...@@ -338,7 +342,7 @@ class Phi3VForCausalLM(VisionLanguageModelBase):
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None) image_sizes = kwargs.pop("image_sizes", None)
expected_input_type = self.vision_language_config.image_input_type expected_input_type = self.vlm_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType ImageInputType = VisionLanguageConfig.ImageInputType
if expected_input_type != ImageInputType.PIXEL_VALUES: if expected_input_type != ImageInputType.PIXEL_VALUES:
......
...@@ -266,8 +266,6 @@ class Qwen2Model(nn.Module): ...@@ -266,8 +266,6 @@ class Qwen2Model(nn.Module):
class Qwen2ForCausalLM(nn.Module, SupportsLoRA): class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
......
...@@ -269,8 +269,6 @@ class XverseModel(nn.Module): ...@@ -269,8 +269,6 @@ class XverseModel(nn.Module):
class XverseForCausalLM(nn.Module, SupportsLoRA): class XverseForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment