Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3d49776b
Unverified
Commit
3d49776b
authored
Sep 29, 2024
by
Jee Jee Li
Committed by
GitHub
Sep 29, 2024
Browse files
[Model][LoRA]LoRA support added for MiniCPMV2.5 (#7199)
parent
bc2ef1f7
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
378 additions
and
31 deletions
+378
-31
tests/lora/conftest.py
tests/lora/conftest.py
+5
-0
tests/lora/test_minicpmv.py
tests/lora/test_minicpmv.py
+71
-0
tests/lora/test_minicpmv_tp.py
tests/lora/test_minicpmv_tp.py
+95
-0
vllm/lora/models.py
vllm/lora/models.py
+41
-4
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+72
-22
vllm/model_executor/models/module_mapping.py
vllm/model_executor/models/module_mapping.py
+69
-0
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+20
-2
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+5
-3
No files found.
tests/lora/conftest.py
View file @
3d49776b
...
...
@@ -194,6 +194,11 @@ def baichuan_zero_lora_files():
return
snapshot_download
(
repo_id
=
"jeeejeee/baichuan7b-zero-init"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
minicpmv_lora_files
():
return
snapshot_download
(
repo_id
=
"jeeejeee/minicpmv25-lora-pokemon"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
tinyllama_lora_files
():
return
snapshot_download
(
repo_id
=
"jashing/tinyllama-colorist-lora"
)
...
...
tests/lora/test_minicpmv.py
0 → 100644
View file @
3d49776b
from
typing
import
List
import
vllm
from
vllm.assets.image
import
ImageAsset
from
vllm.lora.request
import
LoRARequest
MODEL_PATH
=
"openbmb/MiniCPM-Llama3-V-2_5"
PROMPT_TEMPLATE
=
(
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>
\n\n
"
"(<image>./</image>)
\n
What is in the image?<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>
\n\n
"
)
IMAGE_ASSETS
=
[
ImageAsset
(
"stop_sign"
),
ImageAsset
(
"cherry_blossom"
),
]
# After fine-tuning with LoRA, all generated content should start begin `A`.
EXPECTED_OUTPUT
=
[
"A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents."
,
# noqa: E501
"A pink cherry blossom tree with a blue sky in the background."
,
]
def
do_sample
(
llm
:
vllm
.
LLM
,
lora_path
:
str
,
lora_id
:
int
)
->
List
[
str
]:
sampling_params
=
vllm
.
SamplingParams
(
temperature
=
0
,
max_tokens
=
5
,
stop_token_ids
=
[
128001
,
128009
],
# eos_id, eot_id
)
inputs
=
[{
"prompt"
:
PROMPT_TEMPLATE
,
"multi_modal_data"
:
{
"image"
:
asset
.
pil_image
},
}
for
asset
in
IMAGE_ASSETS
]
outputs
=
llm
.
generate
(
inputs
,
sampling_params
,
lora_request
=
LoRARequest
(
str
(
lora_id
),
lora_id
,
lora_path
)
if
lora_id
else
None
,
)
# Print the outputs.
generated_texts
:
List
[
str
]
=
[]
for
output
in
outputs
:
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
.
strip
()
generated_texts
.
append
(
generated_text
)
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
return
generated_texts
def
test_minicpmv_lora
(
minicpmv_lora_files
):
llm
=
vllm
.
LLM
(
MODEL_PATH
,
max_num_seqs
=
2
,
enable_lora
=
True
,
max_loras
=
4
,
max_lora_rank
=
64
,
trust_remote_code
=
True
,
)
output1
=
do_sample
(
llm
,
minicpmv_lora_files
,
lora_id
=
1
)
for
i
in
range
(
len
(
EXPECTED_OUTPUT
)):
assert
EXPECTED_OUTPUT
[
i
].
startswith
(
output1
[
i
])
output2
=
do_sample
(
llm
,
minicpmv_lora_files
,
lora_id
=
2
)
for
i
in
range
(
len
(
EXPECTED_OUTPUT
)):
assert
EXPECTED_OUTPUT
[
i
].
startswith
(
output2
[
i
])
tests/lora/test_minicpmv_tp.py
0 → 100644
View file @
3d49776b
from
typing
import
List
import
pytest
import
vllm
from
vllm.assets.image
import
ImageAsset
from
vllm.lora.request
import
LoRARequest
from
..utils
import
multi_gpu_test
MODEL_PATH
=
"openbmb/MiniCPM-Llama3-V-2_5"
PROMPT_TEMPLATE
=
(
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>
\n\n
"
"(<image>./</image>)
\n
What is in the image?<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>
\n\n
"
)
IMAGE_ASSETS
=
[
ImageAsset
(
"stop_sign"
),
ImageAsset
(
"cherry_blossom"
),
]
# After fine-tuning with LoRA, all generated content should start begin `A`.
EXPECTED_OUTPUT
=
[
"A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents."
,
# noqa: E501
"A pink cherry blossom tree with a blue sky in the background."
,
]
def
do_sample
(
llm
:
vllm
.
LLM
,
lora_path
:
str
,
lora_id
:
int
)
->
List
[
str
]:
sampling_params
=
vllm
.
SamplingParams
(
temperature
=
0
,
max_tokens
=
5
,
stop_token_ids
=
[
128001
,
128009
],
# eos_id, eot_id
)
inputs
=
[{
"prompt"
:
PROMPT_TEMPLATE
,
"multi_modal_data"
:
{
"image"
:
asset
.
pil_image
},
}
for
asset
in
IMAGE_ASSETS
]
outputs
=
llm
.
generate
(
inputs
,
sampling_params
,
lora_request
=
LoRARequest
(
str
(
lora_id
),
lora_id
,
lora_path
)
if
lora_id
else
None
,
)
# Print the outputs.
generated_texts
:
List
[
str
]
=
[]
for
output
in
outputs
:
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
.
strip
()
generated_texts
.
append
(
generated_text
)
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
return
generated_texts
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
parametrize
(
"fully_sharded"
,
[
True
,
False
])
def
test_minicpmv_tp2
(
minicpmv_lora_files
,
fully_sharded
):
llm
=
vllm
.
LLM
(
MODEL_PATH
,
enable_lora
=
True
,
max_num_seqs
=
2
,
max_loras
=
4
,
max_lora_rank
=
64
,
tensor_parallel_size
=
2
,
trust_remote_code
=
True
,
fully_sharded_loras
=
fully_sharded
,
)
output_tp
=
do_sample
(
llm
,
minicpmv_lora_files
,
lora_id
=
1
)
for
i
in
range
(
len
(
EXPECTED_OUTPUT
)):
assert
EXPECTED_OUTPUT
[
i
].
startswith
(
output_tp
[
i
])
@
multi_gpu_test
(
num_gpus
=
4
)
@
pytest
.
mark
.
parametrize
(
"fully_sharded"
,
[
True
,
False
])
def
test_minicpmv_tp4
(
minicpmv_lora_files
,
fully_sharded
):
llm
=
vllm
.
LLM
(
MODEL_PATH
,
enable_lora
=
True
,
max_num_seqs
=
2
,
max_loras
=
4
,
max_lora_rank
=
64
,
tensor_parallel_size
=
4
,
trust_remote_code
=
True
,
fully_sharded_loras
=
fully_sharded
,
)
output_tp
=
do_sample
(
llm
,
minicpmv_lora_files
,
lora_id
=
1
)
for
i
in
range
(
len
(
EXPECTED_OUTPUT
)):
assert
EXPECTED_OUTPUT
[
i
].
startswith
(
output_tp
[
i
])
vllm/lora/models.py
View file @
3d49776b
...
...
@@ -24,7 +24,9 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from
vllm.lora.punica
import
PunicaWrapper
from
vllm.lora.utils
import
(
from_layer
,
from_layer_logits_processor
,
parse_fine_tuned_lora_name
,
replace_submodule
)
from
vllm.model_executor.models.interfaces
import
SupportsLoRA
from
vllm.model_executor.models.interfaces
import
(
SupportsLoRA
,
supports_multimodal
)
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.utils
import
PPMissingLayer
from
vllm.utils
import
is_pin_memory_available
...
...
@@ -332,6 +334,8 @@ class LoRAModelManager(AdapterModelManager):
self
.
supported_lora_modules
.
append
(
"rotary_emb"
)
self
.
packed_modules_mapping
=
copy
.
deepcopy
(
self
.
model
.
packed_modules_mapping
)
# Used to indicate whether the model is a multimodal model
self
.
supports_mm
:
bool
=
supports_multimodal
(
self
.
model
)
self
.
packed_modules
:
Dict
[
str
,
List
[
str
]]
=
{}
self
.
modules
:
Dict
[
str
,
"BaseLayerWithLoRA"
]
=
{}
# Dict instead of a Set for compatibility with LRUCache.
...
...
@@ -437,12 +441,22 @@ class LoRAModelManager(AdapterModelManager):
continue
if
not
self
.
_match_target_modules
(
module_name
):
continue
# A temporary approach for multimodal models to support LoRA
# TODO: Remove this restriction
if
self
.
_filter_unsupported_mm_module
(
module_name
):
logger
.
warning
(
"Regarding multimodal models, vLLM currently only supports "
"adding LoRA to language model, %s will be ignored."
,
module_name
,
)
continue
parts
=
module_name
.
split
(
"."
)[
-
1
]
packed_moduled_lst
=
self
.
packed_modules_mapping
.
get
(
parts
,
[])
new_module
=
replace_submodule
(
self
.
model
,
module_name
,
from_layer
(
module
,
self
.
lora_slots
,
self
.
lora_config
,
packed_moduled_lst
,
self
.
model
.
config
))
# LinearScalingRotaryEmbeddingWithLora is used to handle
# long context lora. Register relevant metadata.
if
isinstance
(
new_module
,
LinearScalingRotaryEmbeddingWithLora
):
...
...
@@ -460,6 +474,15 @@ class LoRAModelManager(AdapterModelManager):
module
,
self
.
lora_slots
,
self
.
lora_config
,
self
.
model
.
config
))
# In some models, especially multimodal ones, layers with the same
# name may have different types, such as nn.Linear and
# ReplicatedLinear. The nn.Linear layers cannot be replaced with
# LoRA layers, leading to assertion error. The following check
# aims to prevent this error
if
self
.
supports_mm
and
not
isinstance
(
new_module
,
BaseLayerWithLoRA
):
continue
self
.
register_module
(
module_name
,
new_module
)
self
.
_register_packed_modules
(
module_name
)
# All lora layers share the same punica_wrapper based on reference.
...
...
@@ -478,9 +501,10 @@ class LoRAModelManager(AdapterModelManager):
"""Create zero-initialized LoRAModel for warmup."""
model
=
LoRAModel
(
lora_id
,
rank
,
{},
scaling_factor
)
for
module_name
,
module
in
self
.
model
.
named_modules
():
if
not
self
.
_match_target_modules
(
module_name
)
or
not
isinstance
(
module
,
BaseLayerWithLoRA
)
or
isinstance
(
module
,
LinearScalingRotaryEmbeddingWithLora
):
if
(
not
self
.
_match_target_modules
(
module_name
)
or
not
isinstance
(
module
,
BaseLayerWithLoRA
)
or
isinstance
(
module
,
LinearScalingRotaryEmbeddingWithLora
)
or
self
.
_filter_unsupported_mm_module
(
module_name
)):
continue
parts
=
module_name
.
split
(
"."
)
if
module_name
not
in
self
.
packed_modules
:
...
...
@@ -541,6 +565,19 @@ class LoRAModelManager(AdapterModelManager):
module_name
)
or
target_module
==
module_name
for
target_module
in
self
.
supported_lora_modules
)
def
_filter_unsupported_mm_module
(
self
,
module_name
:
str
)
->
bool
:
"""
Regarding multimodal models, vLLM currently only supports adding LoRA to
language model. LoRA for other modules, such as the vision tower, will
be filtered out.
"""
if
self
.
supports_mm
:
prefix
=
module_name
.
split
(
"."
)[
0
]
module_mapping
:
MultiModelKeys
=
self
.
model
.
get_mm_mapping
()
return
(
prefix
in
module_mapping
.
connector
or
prefix
in
module_mapping
.
tower_model
)
return
False
def
_register_packed_modules
(
self
,
module_full_name
:
str
)
->
None
:
parts
=
module_full_name
.
split
(
"."
)
module_name
=
parts
[
-
1
]
...
...
vllm/model_executor/models/minicpmv.py
View file @
3d49776b
...
...
@@ -36,7 +36,7 @@ from transformers import PretrainedConfig
from
typing_extensions
import
NotRequired
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
...
...
@@ -50,7 +50,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.models.minicpm
import
MiniCPMModel
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.qwen2
import
Qwen2Model
from
vllm.model_executor.models.utils
import
LLMWrapper
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
...
...
@@ -59,10 +61,10 @@ from vllm.multimodal.utils import cached_get_tokenizer
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
.idefics2_vision_model
import
Idefics2VisionTransformer
from
.interfaces
import
SupportsLoRA
_KEYS_TO_MODIFY_MAPPING
=
{
"llm.lm_head"
:
"lm_head"
,
"llm.model"
:
"llm"
,
}
...
...
@@ -621,6 +623,14 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
def
get_mm_mapping
(
self
)
->
MultiModelKeys
:
"""
Get the module prefix in multimodal models
"""
return
MultiModelKeys
.
from_string_field
(
language_model
=
"llm"
,
connector
=
"resampler"
,
tower_model
=
"vpm"
)
def
init_llm
(
self
,
config
:
PretrainedConfig
,
...
...
@@ -669,9 +679,11 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
nn
.
Module
:
return
MiniCPMModel
(
config
,
return
LLMWrapper
(
MiniCPMModel
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
),
name
=
"model"
)
def
init_vision_module
(
self
)
->
nn
.
Module
:
# TODO :refactor this vision model
...
...
@@ -697,6 +709,9 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
return
model
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
embed_tokens
(
input_ids
)
def
init_resampler
(
self
,
embed_dim
:
int
,
vision_dim
:
int
)
->
nn
.
Module
:
with
set_default_torch_dtype
(
torch
.
float16
):
resampler
=
Resampler2
(
...
...
@@ -743,7 +758,34 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
return
"resampler"
in
name
or
"vpm"
in
name
class
MiniCPMV2_5
(
MiniCPMVBaseModel
):
class
MiniCPMV2_5
(
MiniCPMVBaseModel
,
SupportsLoRA
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
# LoRA specific attributes
supported_lora_modules
=
[
# vision encoder
"fc1"
,
"fc2"
,
"out_proj"
,
# language model
"qkv_proj"
,
# same name with vision encoder
"o_proj"
,
"gate_up_proj"
,
"down_proj"
,
# resampler
"kv_proj"
,
]
embedding_modules
=
{}
embedding_padding_modules
=
[]
def
__init__
(
self
,
...
...
@@ -751,6 +793,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
multimodal_config
:
MultiModalConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
):
super
().
__init__
(
config
,
multimodal_config
,
cache_config
,
quant_config
)
assert
self
.
version
==
(
2
,
5
)
...
...
@@ -761,9 +804,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
nn
.
Module
:
return
LlamaModel
(
config
,
return
LLMWrapper
(
LlamaModel
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
),
name
=
"model"
)
def
init_vision_module
(
self
)
->
nn
.
Module
:
model
=
Idefics2VisionTransformer
(
self
.
config
.
vision_config
)
...
...
@@ -843,9 +887,11 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
nn
.
Module
:
return
Qwen2Model
(
config
,
return
LLMWrapper
(
Qwen2Model
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
),
name
=
"model"
)
def
init_vision_module
(
self
)
->
nn
.
Module
:
# A custom version of SiglipVisionTransformer, won't work with TP
...
...
@@ -870,7 +916,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
num_heads
=
embed_dim
//
128
,
kv_dim
=
vision_dim
,
)
return
resampler
def
get_vision_embedding
(
...
...
@@ -934,20 +979,25 @@ _SUPPORT_VERSION = {
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_minicpmv_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_minicpmv
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_minicpmv
)
class
MiniCPMV
(
MiniCPMVBaseModel
):
class
MiniCPMV
(
MiniCPMVBaseModel
,
SupportsLoRA
):
"""
Different versions of MiniCPMV use different visual encoders and LLMs,
which is not conducive to the current integration logic of LoRA and
bitsandbytes in vLLM. Therefore, it is necessary to separate them.
"""
def
__new__
(
cls
,
# Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty.
packed_modules_mapping
=
{}
supported_lora_modules
=
[]
embedding_modules
=
{}
embedding_padding_modules
=
[]
def
__new__
(
cls
,
config
:
PretrainedConfig
,
multimodal_config
:
MultiModalConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
lora_config
:
Optional
[
LoRAConfig
]
=
None
):
if
not
hasattr
(
config
,
"version"
):
if
config
.
hidden_size
==
2304
and
config
.
query_num
==
64
:
version
=
(
2
,
0
)
...
...
vllm/model_executor/models/module_mapping.py
0 → 100644
View file @
3d49776b
# Adapted from
# https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py
from
dataclasses
import
dataclass
,
field
from
typing
import
List
,
Union
@
dataclass
class
ModelKeys
:
model_type
:
str
=
None
module_list
:
str
=
None
embedding
:
str
=
None
mlp
:
str
=
None
down_proj
:
str
=
None
attention
:
str
=
None
o_proj
:
str
=
None
q_proj
:
str
=
None
k_proj
:
str
=
None
v_proj
:
str
=
None
qkv_proj
:
str
=
None
qk_proj
:
str
=
None
qa_proj
:
str
=
None
qb_proj
:
str
=
None
kva_proj
:
str
=
None
kvb_proj
:
str
=
None
output
:
str
=
None
@
dataclass
class
MultiModelKeys
(
ModelKeys
):
language_model
:
List
[
str
]
=
field
(
default_factory
=
list
)
connector
:
List
[
str
]
=
field
(
default_factory
=
list
)
# vision tower and audio tower
tower_model
:
List
[
str
]
=
field
(
default_factory
=
list
)
generator
:
List
[
str
]
=
field
(
default_factory
=
list
)
@
staticmethod
def
from_string_field
(
language_model
:
Union
[
str
,
List
[
str
]]
=
None
,
connector
:
Union
[
str
,
List
[
str
]]
=
None
,
tower_model
:
Union
[
str
,
List
[
str
]]
=
None
,
generator
:
Union
[
str
,
List
[
str
]]
=
None
,
**
kwargs
)
->
'MultiModelKeys'
:
def
to_list
(
value
):
if
value
is
None
:
return
[]
return
[
value
]
if
isinstance
(
value
,
str
)
else
list
(
value
)
return
MultiModelKeys
(
language_model
=
to_list
(
language_model
),
connector
=
to_list
(
connector
),
tower_model
=
to_list
(
tower_model
),
generator
=
to_list
(
generator
),
**
kwargs
)
vllm/model_executor/models/utils.py
View file @
3d49776b
import
itertools
from
collections
import
UserDict
from
typing
import
(
Dict
,
Iterable
,
List
,
Literal
,
Optional
,
Protocol
,
Tuple
,
Union
,
overload
)
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Literal
,
Optional
,
Protocol
,
Tuple
,
Union
,
overload
)
import
torch
import
torch.nn
as
nn
...
...
@@ -329,3 +329,21 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
})
return
make_empty_intermediate_tensors
class
LLMWrapper
(
nn
.
Module
):
"""
To align with the key names of LoRA trained with PEFT, we need to add an
additional layer to the llm's implementation.
"""
def
__init__
(
self
,
llm
:
nn
.
Module
,
name
:
str
)
->
None
:
super
().
__init__
()
self
.
model_name
=
name
setattr
(
self
,
name
,
llm
)
def
forward
(
self
,
*
args
,
**
kwargs
)
->
Any
:
return
getattr
(
self
,
self
.
model_name
)(
*
args
,
**
kwargs
)
def
embed_tokens
(
self
,
*
args
,
**
kwargs
)
->
Any
:
return
getattr
(
self
,
self
.
model_name
).
embed_tokens
(
*
args
,
**
kwargs
)
vllm/worker/model_runner.py
View file @
3d49776b
...
...
@@ -1034,10 +1034,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
model_memory_usage
/
float
(
2
**
30
))
if
self
.
lora_config
:
assert
supports_lora
(
self
.
model
),
"Model does not support LoRA"
assert
not
supports_multimodal
(
assert
supports_lora
(
self
.
model
),
"To be tested: Multi-modal model with LoRA settings."
),
f
"
{
self
.
model
.
__class__
.
__name__
}
does not support LoRA yet."
if
supports_multimodal
(
self
.
model
):
logger
.
warning
(
"Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model."
)
self
.
lora_manager
=
LRUCacheWorkerLoRAManager
(
self
.
scheduler_config
.
max_num_seqs
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment