Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f7030df3
Unverified
Commit
f7030df3
authored
Apr 11, 2025
by
Jee Jee Li
Committed by
GitHub
Apr 11, 2025
Browse files
[Core][LoRA][1/N] Add LoRA for EncoderDecoderModelRunner (#15990)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
905e91e9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
1 deletion
+41
-1
vllm/lora/layers.py
vllm/lora/layers.py
+5
-0
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+11
-0
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+25
-1
No files found.
vllm/lora/layers.py
View file @
f7030df3
...
@@ -866,6 +866,11 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
...
@@ -866,6 +866,11 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
and
len
(
packed_modules_list
)
==
3
)
and
len
(
packed_modules_list
)
==
3
)
#TODO: Implement this
class
QKVCrossParallelLinearWithLoRA
(
BaseLayerWithLoRA
):
pass
class
RowParallelLinearWithLoRA
(
BaseLinearLayerWithLoRA
):
class
RowParallelLinearWithLoRA
(
BaseLinearLayerWithLoRA
):
def
__init__
(
self
,
base_layer
:
RowParallelLinear
)
->
None
:
def
__init__
(
self
,
base_layer
:
RowParallelLinear
)
->
None
:
...
...
vllm/model_executor/models/mllama.py
View file @
f7030df3
...
@@ -52,6 +52,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -52,6 +52,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
...
@@ -1181,6 +1182,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1181,6 +1182,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
super
().
__init__
()
super
().
__init__
()
config
:
MllamaConfig
=
vllm_config
.
model_config
.
hf_config
config
:
MllamaConfig
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
vocab_size
=
config
.
text_config
.
vocab_size
self
.
vocab_size
=
config
.
text_config
.
vocab_size
self
.
hidden_size
=
config
.
text_config
.
hidden_size
self
.
hidden_size
=
config
.
text_config
.
hidden_size
...
@@ -1517,6 +1519,15 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1517,6 +1519,15 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
updated_params
.
add
(
name
)
updated_params
.
add
(
name
)
return
updated_params
return
updated_params
def
get_mm_mapping
(
self
)
->
MultiModelKeys
:
"""
Get the module prefix in multimodal models
"""
return
MultiModelKeys
.
from_string_field
(
language_model
=
"language_model"
,
connector
=
"multi_modal_projector"
,
tower_model
=
"vision_model"
)
def
skip_attention_mask
(
sparse_mask
:
List
[
List
[
int
]])
->
bool
:
def
skip_attention_mask
(
sparse_mask
:
List
[
List
[
int
]])
->
bool
:
for
mask
in
sparse_mask
:
for
mask
in
sparse_mask
:
...
...
vllm/worker/enc_dec_model_runner.py
View file @
f7030df3
...
@@ -16,6 +16,7 @@ from vllm.config import VllmConfig
...
@@ -16,6 +16,7 @@ from vllm.config import VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalKwargs
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalKwargs
,
...
@@ -34,6 +35,7 @@ from vllm.worker.model_runner_base import (
...
@@ -34,6 +35,7 @@ from vllm.worker.model_runner_base import (
from
vllm.worker.utils
import
assert_enc_dec_mr_supported_scenario
from
vllm.worker.utils
import
assert_enc_dec_mr_supported_scenario
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
LORA_WARMUP_RANK
=
8
@
dataclasses
.
dataclass
(
frozen
=
True
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
...
@@ -160,7 +162,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -160,7 +162,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
if
num_steps
>
1
:
if
num_steps
>
1
:
raise
ValueError
(
"num_steps > 1 is not supported in "
raise
ValueError
(
"num_steps > 1 is not supported in "
"EncoderDecoderModelRunner"
)
"EncoderDecoderModelRunner"
)
if
self
.
lora_config
:
assert
model_input
.
lora_requests
is
not
None
assert
model_input
.
lora_mapping
is
not
None
self
.
set_active_loras
(
model_input
.
lora_requests
,
model_input
.
lora_mapping
)
if
(
model_input
.
attn_metadata
is
not
None
if
(
model_input
.
attn_metadata
is
not
None
and
model_input
.
attn_metadata
.
prefill_metadata
is
None
and
model_input
.
attn_metadata
.
prefill_metadata
is
None
and
model_input
.
attn_metadata
.
decode_metadata
.
use_cuda_graph
):
and
model_input
.
attn_metadata
.
decode_metadata
.
use_cuda_graph
):
...
@@ -268,6 +274,22 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -268,6 +274,22 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
max_num_batched_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
max_num_batched_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
# This represents the maximum number of different requests
# that will have unique loras, and therefore the max amount of
# memory consumption. Create dummy lora request copies from the
# lora request passed in, which contains a lora from the lora
# warmup path.
dummy_lora_requests
:
List
[
LoRARequest
]
=
[]
dummy_lora_requests_per_seq
:
List
[
LoRARequest
]
=
[]
if
self
.
lora_config
:
dummy_lora_requests
=
self
.
_add_dummy_loras
(
self
.
lora_config
.
max_loras
)
assert
len
(
dummy_lora_requests
)
==
self
.
lora_config
.
max_loras
dummy_lora_requests_per_seq
=
[
dummy_lora_requests
[
idx
%
len
(
dummy_lora_requests
)]
for
idx
in
range
(
max_num_seqs
)
]
# Profile memory usage with max_num_sequences sequences and the total
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
# number of tokens equal to max_num_batched_tokens.
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
...
@@ -315,6 +337,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -315,6 +337,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
block_tables
=
None
,
block_tables
=
None
,
encoder_seq_data
=
encoder_dummy_data
.
seq_data
,
encoder_seq_data
=
encoder_dummy_data
.
seq_data
,
cross_block_table
=
None
,
cross_block_table
=
None
,
lora_request
=
dummy_lora_requests_per_seq
[
group_id
]
if
dummy_lora_requests_per_seq
else
None
,
multi_modal_data
=
decoder_dummy_data
.
multi_modal_data
multi_modal_data
=
decoder_dummy_data
.
multi_modal_data
or
encoder_dummy_data
.
multi_modal_data
,
or
encoder_dummy_data
.
multi_modal_data
,
multi_modal_placeholders
=
decoder_dummy_data
.
multi_modal_placeholders
=
decoder_dummy_data
.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment