Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
96354d6a
Unverified
Commit
96354d6a
authored
Jun 27, 2024
by
Cyrus Leung
Committed by
GitHub
Jun 27, 2024
Browse files
[Model] Add base class for LoRA-supported models (#5018)
parent
d12af207
Changes
20
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
270 additions
and
75 deletions
+270
-75
docs/source/models/lora.rst
docs/source/models/lora.rst
+3
-0
vllm/lora/lora.py
vllm/lora/lora.py
+2
-1
vllm/lora/models.py
vllm/lora/models.py
+3
-3
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+13
-7
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+9
-2
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+9
-2
vllm/model_executor/models/decilm.py
vllm/model_executor/models/decilm.py
+2
-2
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+8
-2
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+8
-1
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+130
-0
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+8
-1
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+12
-10
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+11
-9
vllm/model_executor/models/minicpm.py
vllm/model_executor/models/minicpm.py
+10
-2
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+8
-1
vllm/model_executor/models/phi.py
vllm/model_executor/models/phi.py
+14
-8
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+8
-2
vllm/model_executor/models/vlm_base.py
vllm/model_executor/models/vlm_base.py
+0
-12
vllm/model_executor/models/xverse.py
vllm/model_executor/models/xverse.py
+9
-2
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+3
-8
No files found.
docs/source/models/lora.rst
View file @
96354d6a
...
...
@@ -4,6 +4,9 @@ Using LoRA adapters
===================
This document shows you how to use `LoRA adapters <https://arxiv.org/abs/2106.09685>`_ with vLLM on top of a base model.
LoRA adapters can be used with any vLLM model that implements :class:`~vllm.model_executor.models.interfaces.SupportsLoRA`.
Adapters can be efficiently served on a per-request basis with minimal overhead. First, we download the adapter(s) and save
them locally with
...
...
vllm/lora/lora.py
View file @
96354d6a
...
...
@@ -2,6 +2,7 @@ from typing import List, Optional
from
typing
import
Sequence
as
GenericSequence
import
torch
import
torch.types
from
vllm.utils
import
is_pin_memory_available
...
...
@@ -64,7 +65,7 @@ class LoRALayerWeights:
output_dim
:
int
,
rank
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
d
evice
,
device
:
torch
.
types
.
D
evice
,
embeddings_tensor_dim
:
Optional
[
int
]
=
None
)
->
"LoRALayerWeights"
:
pin_memory
=
str
(
device
)
==
"cpu"
and
is_pin_memory_available
()
lora_a
=
torch
.
zeros
([
input_dim
,
rank
],
...
...
vllm/lora/models.py
View file @
96354d6a
...
...
@@ -18,6 +18,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
from
vllm.lora.lora
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.lora.utils
import
(
from_layer
,
from_layer_logits_processor
,
parse_fine_tuned_lora_name
,
replace_submodule
)
from
vllm.model_executor.models.interfaces
import
SupportsLoRA
from
vllm.utils
import
LRUCache
,
is_pin_memory_available
logger
=
init_logger
(
__name__
)
...
...
@@ -363,7 +364,7 @@ class LoRAModelManager:
def
__init__
(
self
,
model
:
nn
.
Module
,
model
:
SupportsLoRA
,
max_num_seqs
:
int
,
max_num_batched_tokens
:
int
,
vocab_size
:
int
,
...
...
@@ -411,7 +412,7 @@ class LoRAModelManager:
# embeddings_indices
self
.
indices_len
:
List
[
Optional
[
int
]]
=
[
None
]
*
4
self
.
model
:
nn
.
Module
=
model
self
.
model
=
model
if
hasattr
(
self
.
model
,
"supported_lora_modules"
):
self
.
supported_lora_modules
=
copy
.
deepcopy
(
self
.
model
.
supported_lora_modules
)
...
...
@@ -428,7 +429,6 @@ class LoRAModelManager:
self
.
_active_loras
:
Dict
[
int
,
None
]
=
{}
self
.
_last_mapping
:
Optional
[
LoRAMapping
]
=
None
self
.
_create_lora_modules
()
self
.
model
.
lora_manager
=
self
@
property
def
capacity
(
self
)
->
int
:
...
...
vllm/model_executor/model_loader/loader.py
View file @
96354d6a
...
...
@@ -32,7 +32,8 @@ from vllm.model_executor.model_loader.weight_utils import (
filter_duplicate_safetensors_files
,
filter_files_not_needed_for_inference
,
get_quant_config
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
pt_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.model_executor.models.vlm_base
import
VisionLanguageModelBase
from
vllm.model_executor.models.interfaces
import
(
supports_lora
,
supports_vision
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
is_tpu
...
...
@@ -64,12 +65,15 @@ def _get_quantization_config(
def
_get_model_initialization_kwargs
(
model_class
:
Type
[
nn
.
Module
],
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
]
model_class
:
Type
[
nn
.
Module
],
lora_config
:
Optional
[
LoRAConfig
],
vlm_config
:
Optional
[
VisionLanguageConfig
],
)
->
Dict
[
str
,
Any
]:
"""Get extra kwargs for model initialization."""
extra_kwargs
:
Dict
[
str
,
Any
]
=
{}
if
hasattr
(
model_class
,
"supported_lora_modules"
):
if
supports_lora
(
model_class
):
# lora_config=None is used to disable LoRA
extra_kwargs
[
"lora_config"
]
=
lora_config
elif
lora_config
:
raise
ValueError
(
...
...
@@ -77,13 +81,15 @@ def _get_model_initialization_kwargs(
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github."
)
elif
issubclass
(
model_class
,
VisionLanguageModelBase
):
if
vision_language_config
is
None
:
if
supports_vision
(
model_class
):
if
vlm_config
is
None
:
raise
ValueError
(
"Provide `image_input_type` and other vision "
"related configurations through LLM entrypoint "
"or engine arguments."
)
extra_kwargs
[
"vision_language_config"
]
=
vision_language_config
extra_kwargs
[
"vlm_config"
]
=
vlm_config
return
extra_kwargs
...
...
vllm/model_executor/models/baichuan.py
View file @
96354d6a
...
...
@@ -45,6 +45,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
.interfaces
import
SupportsLoRA
def
_get_alibi_slopes
(
total_num_heads
:
int
)
->
torch
.
Tensor
:
closest_power_of_2
=
2
**
math
.
floor
(
math
.
log2
(
total_num_heads
))
...
...
@@ -292,7 +294,9 @@ class BaiChuanModel(nn.Module):
return
hidden_states
class
BaiChuanBaseForCausalLM
(
nn
.
Module
):
class
BaiChuanBaseForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
packed_modules_mapping
=
{
"W_pack"
:
[
"W_pack"
],
"gate_up_proj"
:
[
...
...
@@ -312,14 +316,17 @@ class BaiChuanBaseForCausalLM(nn.Module):
def
__init__
(
self
,
config
,
config
:
PretrainedConfig
,
position_embedding
:
str
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
BaiChuanModel
(
config
,
position_embedding
,
cache_config
,
quant_config
)
...
...
vllm/model_executor/models/chatglm.py
View file @
96354d6a
...
...
@@ -28,6 +28,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm.transformers_utils.configs
import
ChatGLMConfig
from
.interfaces
import
SupportsLoRA
class
GLMAttention
(
nn
.
Module
):
...
...
@@ -322,7 +324,9 @@ class ChatGLMModel(nn.Module):
return
hidden_states
class
ChatGLMForCausalLM
(
nn
.
Module
):
class
ChatGLMForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
packed_modules_mapping
=
{
"query_key_value"
:
[
"query_key_value"
],
"dense_h_to_4h"
:
[
"dense_h_to_4h"
]
...
...
@@ -345,7 +349,10 @@ class ChatGLMForCausalLM(nn.Module):
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
:
ChatGLMConfig
=
config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
max_position_embeddings
=
getattr
(
config
,
"max_sequence_length"
,
8192
)
...
...
vllm/model_executor/models/decilm.py
View file @
96354d6a
...
...
@@ -26,7 +26,7 @@
from
typing
import
Iterable
,
Optional
,
Tuple
import
torch
from
transformers
import
Pretrained
Config
from
transformers
import
Llama
Config
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
...
...
@@ -55,7 +55,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
def
__init__
(
self
,
config
:
Optional
[
PretrainedConfig
]
=
None
,
config
:
LlamaConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
...
...
vllm/model_executor/models/gemma.py
View file @
96354d6a
...
...
@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
.interfaces
import
SupportsLoRA
logger
=
init_logger
(
__name__
)
...
...
@@ -288,7 +290,9 @@ class GemmaModel(nn.Module):
return
hidden_states
class
GemmaForCausalLM
(
nn
.
Module
):
class
GemmaForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -319,9 +323,11 @@ class GemmaForCausalLM(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
del
lora_config
# Unused.
super
().
__init__
()
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
GemmaModel
(
config
,
cache_config
,
quant_config
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
...
...
vllm/model_executor/models/gpt_bigcode.py
View file @
96354d6a
...
...
@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
.interfaces
import
SupportsLoRA
class
GPTBigCodeAttention
(
nn
.
Module
):
...
...
@@ -230,7 +232,9 @@ class GPTBigCodeModel(nn.Module):
return
hidden_states
class
GPTBigCodeForCausalLM
(
nn
.
Module
):
class
GPTBigCodeForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
packed_modules_mapping
=
{
"c_attn"
:
[
"c_attn"
]}
supported_lora_modules
=
[
"c_fc"
,
"c_proj"
,
"wte"
,
"lm_head"
,
"c_attn"
]
...
...
@@ -250,7 +254,10 @@ class GPTBigCodeForCausalLM(nn.Module):
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
transformer
=
GPTBigCodeModel
(
config
,
cache_config
,
quant_config
,
lora_config
)
...
...
vllm/model_executor/models/interfaces.py
0 → 100644
View file @
96354d6a
from
typing
import
(
ClassVar
,
Dict
,
List
,
Literal
,
Optional
,
Protocol
,
Type
,
Union
,
overload
,
runtime_checkable
)
from
typing_extensions
import
TypeGuard
from
vllm.config
import
LoRAConfig
,
VisionLanguageConfig
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
@runtime_checkable
class SupportsVision(Protocol):
    """The interface required for all vision language models (VLMs).

    A model satisfies this protocol by exposing a ``supports_vision``
    class attribute and a constructor that takes the vision-language
    configuration as the keyword-only argument ``vlm_config``.
    """

    # Marker flag; implementing classes set `supports_vision = True`.
    supports_vision: ClassVar[Literal[True]]

    # The vision-language config is passed by keyword only.
    def __init__(self, *, vlm_config: VisionLanguageConfig) -> None:
        ...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsVisionType(Protocol):
    """Class-object counterpart of the vision interface.

    Intended to be matched with ``isinstance`` against a model *class*
    (not an instance): the class object itself carries the
    ``supports_vision`` attribute, and calling it plays the role of the
    constructor.
    """

    # Same marker flag, declared here as a plain (non-ClassVar) attribute
    # because the class object is the thing being inspected.
    supports_vision: Literal[True]

    # Calling the class object corresponds to constructing the model.
    def __call__(self, *, vlm_config: VisionLanguageConfig) -> None:
        ...
@overload
def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]:
    ...


@overload
def supports_vision(model: object) -> TypeGuard[SupportsVision]:
    ...


def supports_vision(
    model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]:
    """Report whether a model class or model instance is vision-capable.

    Args:
        model: Either a class object or an instance.

    Returns:
        The result of a runtime protocol check: classes are checked
        against ``_SupportsVisionType``, everything else against
        ``SupportsVision``.
    """
    # A class object has to be inspected as if it were an instance, so it
    # is matched against the class-level protocol instead.
    protocol = _SupportsVisionType if isinstance(model, type) else SupportsVision
    return isinstance(model, protocol)
@runtime_checkable
class SupportsLoRA(Protocol):
    """The interface required for all models that support LoRA.

    A model satisfies this protocol by exposing the class attributes
    declared below and a constructor that accepts the keyword-only
    argument ``lora_config``.
    """

    # Marker flag; implementing classes set `supports_lora = True`.
    supports_lora: ClassVar[Literal[True]]

    # Maps a module name to a list of module names,
    # e.g. {"c_attn": ["c_attn"]}.
    packed_modules_mapping: ClassVar[Dict[str, List[str]]]
    # Module names, e.g. ["c_fc", "c_proj", "wte", "lm_head", "c_attn"].
    supported_lora_modules: ClassVar[List[str]]
    # String-to-string mapping describing the embedding modules.
    # NOTE(review): exact key/value meaning is not visible here - confirm
    # against the LoRA manager.
    embedding_modules: ClassVar[Dict[str, str]]
    # Names of embedding modules, presumably those that need padding -
    # TODO confirm against the LoRA manager.
    embedding_padding_modules: ClassVar[List[str]]

    # lora_config is None when LoRA is not enabled
    def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
        ...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
    """Class-object counterpart of the LoRA interface.

    Intended to be matched with ``isinstance`` against a model *class*
    (not an instance): the class object itself carries the LoRA
    attributes, and calling it plays the role of the constructor.
    """

    # Same marker flag, declared as a plain (non-ClassVar) attribute
    # because the class object is the thing being inspected.
    supports_lora: Literal[True]

    # Same attributes as on the instance-level protocol, without ClassVar.
    packed_modules_mapping: Dict[str, List[str]]
    supported_lora_modules: List[str]
    embedding_modules: Dict[str, str]
    embedding_padding_modules: List[str]

    # Calling the class object corresponds to constructing the model;
    # lora_config defaults to None.
    def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
        ...
@overload
def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]:
    ...


@overload
def supports_lora(model: object) -> TypeGuard[SupportsLoRA]:
    ...


def supports_lora(
    model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
    """Report whether a model class or model instance supports LoRA.

    The answer comes from ``_supports_lora``. When that check fails, the
    model is inspected once more purely for diagnostics, and a warning
    is logged if it looks half-configured: either it sets
    ``supports_lora`` but lacks some LoRA attributes, or it has every
    LoRA attribute but does not set ``supports_lora``.

    Args:
        model: Either a class object or an instance.

    Returns:
        The result of the protocol check; the diagnostics never alter it.
    """
    is_supported = _supports_lora(model)
    if is_supported:
        return is_supported

    # The protocol check failed: work out why, so a likely
    # misconfiguration can be reported.
    required_attrs = (
        "packed_modules_mapping",
        "supported_lora_modules",
        "embedding_modules",
        "embedding_padding_modules",
    )
    absent = tuple(name for name in required_attrs
                   if not hasattr(model, name))

    if getattr(model, "supports_lora", False):
        # Flag is set, but the attribute set is incomplete.
        if absent:
            logger.warning(
                "The model (%s) sets `supports_lora=True`, "
                "but is missing LoRA-specific attributes: %s",
                model,
                absent,
            )
    elif not absent:
        # Every attribute is present, but the flag was never set.
        logger.warning(
            "The model (%s) contains all LoRA-specific attributes, "
            "but does not set `supports_lora=True`.", model)

    return is_supported
def _supports_lora(
    model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
    """Run the raw LoRA protocol check, without any diagnostics.

    Args:
        model: Either a class object or an instance.

    Returns:
        The result of a runtime protocol check: classes are checked
        against ``_SupportsLoRAType``, everything else against
        ``SupportsLoRA``.
    """
    # A class object has to be inspected as if it were an instance, so it
    # is matched against the class-level protocol instead.
    protocol = _SupportsLoRAType if isinstance(model, type) else SupportsLoRA
    return isinstance(model, protocol)
vllm/model_executor/models/llama.py
View file @
96354d6a
...
...
@@ -49,6 +49,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm.utils
import
is_hip
,
print_warning_once
from
.interfaces
import
SupportsLoRA
class
LlamaMLP
(
nn
.
Module
):
...
...
@@ -296,7 +298,9 @@ class LlamaModel(nn.Module):
return
hidden_states
class
LlamaForCausalLM
(
nn
.
Module
):
class
LlamaForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -336,7 +340,10 @@ class LlamaForCausalLM(nn.Module):
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
model
=
LlamaModel
(
config
,
cache_config
,
quant_config
,
...
...
vllm/model_executor/models/llava.py
View file @
96354d6a
...
...
@@ -20,7 +20,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
get_dummy_image_data
from
vllm.sequence
import
SamplerOutput
from
.
vlm_base
import
VisionLanguageModelBase
from
.
interfaces
import
SupportsVision
_KEYS_TO_MODIFY_MAPPING
=
{
"language_model.lm_head"
:
"lm_head"
,
...
...
@@ -86,18 +86,21 @@ LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
@
MULTIMODAL_REGISTRY
.
register_image_feature_input
()
@
MULTIMODAL_REGISTRY
.
register_image_pixel_input
()
@
MULTIMODAL_REGISTRY
.
register_dummy_data
(
get_dummy_image_data
)
class
LlavaForConditionalGeneration
(
VisionLanguageModelBase
):
class
LlavaForConditionalGeneration
(
nn
.
Module
,
SupportsVision
):
supports_vision
=
True
def
__init__
(
self
,
config
:
LlavaConfig
,
v
ision_language
_config
:
VisionLanguageConfig
,
v
lm
_config
:
VisionLanguageConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
super
().
__init__
(
vision_language_config
)
super
().
__init__
()
self
.
config
=
config
self
.
vlm_config
=
vlm_config
if
self
.
v
ision_language
_config
.
image_input_type
==
(
if
self
.
v
lm
_config
.
image_input_type
==
(
VisionLanguageConfig
.
ImageInputType
.
PIXEL_VALUES
):
self
.
vision_tower
=
CLIPVisionModel
(
config
.
vision_config
)
else
:
...
...
@@ -122,11 +125,10 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
self
.
sampler
=
Sampler
()
def
_validate_image_data
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
list
(
data
.
shape
[
1
:])
!=
list
(
self
.
vision_language_config
.
image_input_shape
[
1
:]):
if
list
(
data
.
shape
[
1
:])
!=
list
(
self
.
vlm_config
.
image_input_shape
[
1
:]):
raise
ValueError
(
f
"The expected image tensor shape is batch dimension plus "
f
"
{
self
.
v
ision_language
_config
.
image_input_shape
[
1
:]
}
. "
f
"
{
self
.
v
lm
_config
.
image_input_shape
[
1
:]
}
. "
f
"You supplied
{
data
.
shape
}
. "
f
"If you are using vLLM's entrypoint, make sure your "
f
"supplied image input is consistent with "
...
...
@@ -139,7 +141,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
image_features
=
kwargs
.
pop
(
"image_features"
,
None
)
expected_input_type
=
self
.
v
ision_language
_config
.
image_input_type
expected_input_type
=
self
.
v
lm
_config
.
image_input_type
ImageInputType
=
VisionLanguageConfig
.
ImageInputType
if
expected_input_type
==
ImageInputType
.
PIXEL_VALUES
:
...
...
@@ -273,7 +275,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
inputs_embeds
=
merge_vision_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
v
ision_language
_config
.
image_token_id
)
self
.
v
lm
_config
.
image_token_id
)
input_ids
=
None
else
:
...
...
vllm/model_executor/models/llava_next.py
View file @
96354d6a
...
...
@@ -25,8 +25,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
from
vllm.multimodal.image
import
ImagePixelData
,
get_dummy_image_data
from
vllm.sequence
import
SamplerOutput
,
SequenceData
from
.interfaces
import
SupportsVision
from
.llava
import
LlavaMultiModalProjector
,
merge_vision_embeddings
from
.vlm_base
import
VisionLanguageModelBase
logger
=
init_logger
(
__name__
)
...
...
@@ -106,19 +106,21 @@ def _image_pixel_processor(
@
MULTIMODAL_REGISTRY
.
register_image_pixel_input
(
_image_pixel_processor
)
@
MULTIMODAL_REGISTRY
.
register_dummy_data
(
_get_dummy_image_data
)
class
LlavaNextForConditionalGeneration
(
VisionLanguageModelBase
):
class
LlavaNextForConditionalGeneration
(
nn
.
Module
,
SupportsVision
):
supports_vision
=
True
def
__init__
(
self
,
config
:
LlavaNextConfig
,
v
ision_language
_config
:
VisionLanguageConfig
,
v
lm
_config
:
VisionLanguageConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
super
().
__init__
(
vision_language_config
)
super
().
__init__
()
# Update the type annotation from that of its superclass
self
.
config
=
config
self
.
vlm_config
=
vlm_config
if
self
.
v
ision_language
_config
.
image_input_type
==
(
if
self
.
v
lm
_config
.
image_input_type
==
(
VisionLanguageConfig
.
ImageInputType
.
PIXEL_VALUES
):
self
.
vision_tower
=
CLIPVisionModel
(
config
=
config
.
vision_config
)
else
:
...
...
@@ -146,7 +148,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
torch
.
empty
(
config
.
text_config
.
hidden_size
))
def
_validate_image_pixels
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
_
,
num_channels
,
_
,
_
=
self
.
v
ision_language
_config
.
image_input_shape
_
,
num_channels
,
_
,
_
=
self
.
v
lm
_config
.
image_input_shape
# Note that this is different from that of vLLM vision_language_config
# since the image is resized by the HuggingFace preprocessor
...
...
@@ -177,7 +179,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
image_sizes
=
kwargs
.
pop
(
"image_sizes"
,
None
)
image_features
=
kwargs
.
pop
(
"image_features"
,
None
)
expected_input_type
=
self
.
v
ision_language
_config
.
image_input_type
expected_input_type
=
self
.
v
lm
_config
.
image_input_type
ImageInputType
=
VisionLanguageConfig
.
ImageInputType
if
expected_input_type
==
ImageInputType
.
PIXEL_VALUES
:
...
...
@@ -386,7 +388,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
inputs_embeds
=
merge_vision_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
v
ision_language
_config
.
image_token_id
)
self
.
v
lm
_config
.
image_token_id
)
input_ids
=
None
else
:
...
...
vllm/model_executor/models/minicpm.py
View file @
96354d6a
...
...
@@ -26,6 +26,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
...
...
@@ -51,6 +52,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.sequence
import
SamplerOutput
from
.interfaces
import
SupportsLoRA
class
MiniCPMMoE
(
nn
.
Module
):
"""A tensor-parallel MoE implementation that shards each expert
...
...
@@ -388,7 +391,9 @@ class MiniCPMModel(nn.Module):
return
hidden_states
class
MiniCPMForCausalLM
(
nn
.
Module
):
class
MiniCPMForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -418,13 +423,16 @@ class MiniCPMForCausalLM(nn.Module):
def
__init__
(
self
,
config
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
num_experts
=
getattr
(
self
.
config
,
"num_experts"
,
0
)
self
.
quant_config
=
quant_config
self
.
model
=
MiniCPMModel
(
config
,
...
...
vllm/model_executor/models/mixtral.py
View file @
96354d6a
...
...
@@ -54,6 +54,8 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.sequence
import
SamplerOutput
from
vllm.utils
import
print_warning_once
from
.interfaces
import
SupportsLoRA
class
MixtralMoE
(
nn
.
Module
):
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
...
...
@@ -472,7 +474,9 @@ class MixtralModel(nn.Module):
return
hidden_states
class
MixtralForCausalLM
(
nn
.
Module
):
class
MixtralForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
fall_back_to_pt_during_load
=
False
packed_modules_mapping
=
{
...
...
@@ -504,7 +508,10 @@ class MixtralForCausalLM(nn.Module):
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
model
=
MixtralModel
(
config
,
cache_config
,
quant_config
,
...
...
vllm/model_executor/models/phi.py
View file @
96354d6a
...
...
@@ -39,7 +39,7 @@ from typing import Iterable, List, Optional, Tuple
import
torch
from
torch
import
nn
from
transformers
import
P
retrained
Config
from
transformers
import
P
hi
Config
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
...
...
@@ -59,11 +59,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
.interfaces
import
SupportsLoRA
class
PhiAttention
(
nn
.
Module
):
def
__init__
(
self
,
config
:
P
retrained
Config
,
config
:
P
hi
Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
...
...
@@ -131,7 +133,7 @@ class PhiAttention(nn.Module):
class
PhiMLP
(
nn
.
Module
):
def
__init__
(
self
,
config
:
P
retrained
Config
,
config
:
P
hi
Config
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
...
...
@@ -160,7 +162,7 @@ class PhiMLP(nn.Module):
class
PhiLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
P
retrained
Config
,
config
:
P
hi
Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
...
...
@@ -192,7 +194,7 @@ class PhiLayer(nn.Module):
class
PhiModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
P
retrained
Config
,
config
:
P
hi
Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
...
...
@@ -229,7 +231,9 @@ class PhiModel(nn.Module):
return
hidden_states
class
PhiForCausalLM
(
nn
.
Module
):
class
PhiForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -250,14 +254,16 @@ class PhiForCausalLM(nn.Module):
def
__init__
(
self
,
config
:
P
retrained
Config
,
config
:
P
hi
Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
del
lora_config
# Unused.
super
().
__init__
()
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
PhiModel
(
config
,
cache_config
,
quant_config
)
...
...
vllm/model_executor/models/qwen2.py
View file @
96354d6a
...
...
@@ -48,6 +48,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm.utils
import
print_warning_once
from
.interfaces
import
SupportsLoRA
class
Qwen2MLP
(
nn
.
Module
):
...
...
@@ -263,7 +265,9 @@ class Qwen2Model(nn.Module):
return
hidden_states
class
Qwen2ForCausalLM
(
nn
.
Module
):
class
Qwen2ForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -293,7 +297,6 @@ class Qwen2ForCausalLM(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
del
lora_config
# TODO (@robertgshaw2): see if this can be moved out
if
(
cache_config
.
sliding_window
is
not
None
and
hasattr
(
config
,
"max_window_layers"
)):
...
...
@@ -307,7 +310,10 @@ class Qwen2ForCausalLM(nn.Module):
))
super
().
__init__
()
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
Qwen2Model
(
config
,
cache_config
,
quant_config
)
...
...
vllm/model_executor/models/vlm_base.py
deleted
100644 → 0
View file @
d12af207
from
torch
import
nn
from
vllm.config
import
VisionLanguageConfig
class VisionLanguageModelBase(nn.Module):
    """Base class for all vision language models (VLMs).

    Initializes the underlying ``nn.Module`` and stores the given
    vision-language configuration on the instance.
    """

    def __init__(self, vision_language_config: VisionLanguageConfig) -> None:
        super().__init__()

        # Kept on the instance so subclasses can read it as
        # `self.vision_language_config`.
        self.vision_language_config = vision_language_config
vllm/model_executor/models/xverse.py
View file @
96354d6a
...
...
@@ -45,6 +45,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
.interfaces
import
SupportsLoRA
class
XverseMLP
(
nn
.
Module
):
...
...
@@ -266,7 +268,9 @@ class XverseModel(nn.Module):
return
hidden_states
class
XverseForCausalLM
(
nn
.
Module
):
class
XverseForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
supports_lora
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -299,10 +303,13 @@ class XverseForCausalLM(nn.Module):
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
XverseModel
(
config
,
cache_config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
...
...
vllm/worker/model_runner.py
View file @
96354d6a
...
...
@@ -22,6 +22,7 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.models.interfaces
import
supports_lora
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
...
...
@@ -225,14 +226,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
model_memory_usage
/
float
(
2
**
30
))
if
self
.
lora_config
:
assert
hasattr
(
self
.
model
,
"supported_lora_modules"
)
and
self
.
model
.
supported_lora_modules
,
(
"Model does not support LoRA"
)
assert
hasattr
(
self
.
model
,
"embedding_modules"
),
"Model does not have embedding_modules"
assert
hasattr
(
self
.
model
,
"embedding_padding_modules"
),
"Model does not have embedding_padding_modules"
assert
supports_lora
(
self
.
model
),
"Model does not support LoRA"
self
.
lora_manager
=
LRUCacheWorkerLoRAManager
(
self
.
scheduler_config
.
max_num_seqs
,
self
.
scheduler_config
.
max_num_batched_tokens
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment