Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
98cf2ed6
Unverified
Commit
98cf2ed6
authored
Jun 28, 2024
by
Cyrus Leung
Committed by
GitHub
Jun 27, 2024
Browse files
[Model][Bugfix] Implicit model flags and reenable Phi-3-Vision (#5896)
parent
e9d32d07
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
26 additions
and
32 deletions
+26
-32
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+0
-2
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+0
-2
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+0
-2
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+0
-2
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+16
-2
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+0
-2
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+0
-2
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+0
-2
vllm/model_executor/models/minicpm.py
vllm/model_executor/models/minicpm.py
+0
-2
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+0
-2
vllm/model_executor/models/phi.py
vllm/model_executor/models/phi.py
+0
-2
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+10
-6
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+0
-2
vllm/model_executor/models/xverse.py
vllm/model_executor/models/xverse.py
+0
-2
No files found.
vllm/model_executor/models/baichuan.py
View file @
98cf2ed6
...
...
@@ -295,8 +295,6 @@ class BaiChuanModel(nn.Module):
class
BaiChuanBaseForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
packed_modules_mapping
=
{
"W_pack"
:
[
"W_pack"
],
"gate_up_proj"
:
[
...
...
vllm/model_executor/models/chatglm.py
View file @
98cf2ed6
...
...
@@ -325,8 +325,6 @@ class ChatGLMModel(nn.Module):
class
ChatGLMForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
packed_modules_mapping
=
{
"query_key_value"
:
[
"query_key_value"
],
"dense_h_to_4h"
:
[
"dense_h_to_4h"
]
...
...
vllm/model_executor/models/gemma.py
View file @
98cf2ed6
...
...
@@ -291,8 +291,6 @@ class GemmaModel(nn.Module):
class
GemmaForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
vllm/model_executor/models/gpt_bigcode.py
View file @
98cf2ed6
...
...
@@ -233,8 +233,6 @@ class GPTBigCodeModel(nn.Module):
class
GPTBigCodeForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
packed_modules_mapping
=
{
"c_attn"
:
[
"c_attn"
]}
supported_lora_modules
=
[
"c_fc"
,
"c_proj"
,
"wte"
,
"lm_head"
,
"c_attn"
]
...
...
vllm/model_executor/models/interfaces.py
View file @
98cf2ed6
...
...
@@ -13,7 +13,14 @@ logger = init_logger(__name__)
class
SupportsVision
(
Protocol
):
"""The interface required for all vision language models (VLMs)."""
supports_vision
:
ClassVar
[
Literal
[
True
]]
supports_vision
:
ClassVar
[
Literal
[
True
]]
=
True
"""
A flag that indicates this model supports vision inputs.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
def
__init__
(
self
,
*
,
vlm_config
:
VisionLanguageConfig
)
->
None
:
...
...
...
@@ -52,7 +59,14 @@ def supports_vision(
class
SupportsLoRA
(
Protocol
):
"""The interface required for all models that support LoRA."""
supports_lora
:
ClassVar
[
Literal
[
True
]]
supports_lora
:
ClassVar
[
Literal
[
True
]]
=
True
"""
A flag that indicates this model supports LoRA.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
packed_modules_mapping
:
ClassVar
[
Dict
[
str
,
List
[
str
]]]
supported_lora_modules
:
ClassVar
[
List
[
str
]]
...
...
vllm/model_executor/models/llama.py
View file @
98cf2ed6
...
...
@@ -299,8 +299,6 @@ class LlamaModel(nn.Module):
class
LlamaForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
vllm/model_executor/models/llava.py
View file @
98cf2ed6
...
...
@@ -88,8 +88,6 @@ LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
@
MULTIMODAL_REGISTRY
.
register_dummy_data
(
get_dummy_image_data
)
class
LlavaForConditionalGeneration
(
nn
.
Module
,
SupportsVision
):
supports_vision
=
True
def
__init__
(
self
,
config
:
LlavaConfig
,
vlm_config
:
VisionLanguageConfig
,
...
...
vllm/model_executor/models/llava_next.py
View file @
98cf2ed6
...
...
@@ -108,8 +108,6 @@ def _image_pixel_processor(
@
MULTIMODAL_REGISTRY
.
register_dummy_data
(
_get_dummy_image_data
)
class
LlavaNextForConditionalGeneration
(
nn
.
Module
,
SupportsVision
):
supports_vision
=
True
def
__init__
(
self
,
config
:
LlavaNextConfig
,
vlm_config
:
VisionLanguageConfig
,
...
...
vllm/model_executor/models/minicpm.py
View file @
98cf2ed6
...
...
@@ -392,8 +392,6 @@ class MiniCPMModel(nn.Module):
class
MiniCPMForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
vllm/model_executor/models/mixtral.py
View file @
98cf2ed6
...
...
@@ -475,8 +475,6 @@ class MixtralModel(nn.Module):
class
MixtralForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
fall_back_to_pt_during_load
=
False
packed_modules_mapping
=
{
...
...
vllm/model_executor/models/phi.py
View file @
98cf2ed6
...
...
@@ -232,8 +232,6 @@ class PhiModel(nn.Module):
class
PhiForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
vllm/model_executor/models/phi3v.py
View file @
98cf2ed6
...
...
@@ -32,12 +32,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.clip
import
CLIPVisionModel
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.models.vlm_base
import
VisionLanguageModelBase
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
ImagePixelData
,
get_dummy_image_data
from
vllm.sequence
import
SamplerOutput
from
.interfaces
import
SupportsVision
logger
=
init_logger
(
__name__
)
_KEYS_TO_MODIFY_MAPPING
=
{
...
...
@@ -317,18 +318,21 @@ def _image_processor(
@
MULTIMODAL_REGISTRY
.
register_image_pixel_input
(
_image_processor
)
@
MULTIMODAL_REGISTRY
.
register_dummy_data
(
get_dummy_image_data
)
class
Phi3VForCausalLM
(
VisionLanguageModelBase
):
class
Phi3VForCausalLM
(
nn
.
Module
,
SupportsVision
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
v
ision_language
_config
:
VisionLanguageConfig
,
v
lm
_config
:
VisionLanguageConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
super
().
__init__
(
vision_language_config
)
super
().
__init__
()
self
.
config
=
config
self
.
vlm_config
=
vlm_config
self
.
model
=
LlamaModel
(
config
,
cache_config
,
quant_config
)
self
.
vision_embed_tokens
=
Phi3HDImageEmbedding
(
v
ision_language
_config
,
config
,
self
.
model
.
embed_tokens
)
v
lm
_config
,
config
,
self
.
model
.
embed_tokens
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
...
...
@@ -338,7 +342,7 @@ class Phi3VForCausalLM(VisionLanguageModelBase):
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
image_sizes
=
kwargs
.
pop
(
"image_sizes"
,
None
)
expected_input_type
=
self
.
v
ision_language
_config
.
image_input_type
expected_input_type
=
self
.
v
lm
_config
.
image_input_type
ImageInputType
=
VisionLanguageConfig
.
ImageInputType
if
expected_input_type
!=
ImageInputType
.
PIXEL_VALUES
:
...
...
vllm/model_executor/models/qwen2.py
View file @
98cf2ed6
...
...
@@ -266,8 +266,6 @@ class Qwen2Model(nn.Module):
class
Qwen2ForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
vllm/model_executor/models/xverse.py
View file @
98cf2ed6
...
...
@@ -269,8 +269,6 @@ class XverseModel(nn.Module):
class
XverseForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment