Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8132365b
Unverified
Commit
8132365b
authored
May 11, 2025
by
Ben Browning
Committed by
GitHub
May 11, 2025
Browse files
[Bugfix]: v1 engine - consider lora adapters in allowed_token_ids (#17855)
Signed-off-by:
Ben Browning
<
bbrownin@redhat.com
>
parent
eea22a56
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
154 additions
and
5 deletions
+154
-5
tests/lora/conftest.py
tests/lora/conftest.py
+14
-2
tests/lora/test_lora_allowed_token_ids.py
tests/lora/test_lora_allowed_token_ids.py
+134
-0
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+6
-3
No files found.
tests/lora/conftest.py
View file @
8132365b
...
@@ -139,6 +139,12 @@ def dummy_model_gate_up() -> nn.Module:
...
@@ -139,6 +139,12 @@ def dummy_model_gate_up() -> nn.Module:
return
model
return
model
@
pytest
.
fixture
(
scope
=
"session"
)
def
llama_2_7b_base_huggingface_id
():
# used as a base model for testing with sql lora adapter
return
"meta-llama/Llama-2-7b-hf"
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
sql_lora_huggingface_id
():
def
sql_lora_huggingface_id
():
# huggingface repo id is used to test lora runtime downloading.
# huggingface repo id is used to test lora runtime downloading.
...
@@ -198,6 +204,12 @@ def qwen2vl_lora_files():
...
@@ -198,6 +204,12 @@ def qwen2vl_lora_files():
return
snapshot_download
(
repo_id
=
"jeeejeee/qwen2-vl-lora-pokemon"
)
return
snapshot_download
(
repo_id
=
"jeeejeee/qwen2-vl-lora-pokemon"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
qwen25vl_base_huggingface_id
():
# used as a base model for testing with qwen25vl lora adapter
return
"Qwen/Qwen2.5-VL-3B-Instruct"
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
qwen25vl_lora_files
():
def
qwen25vl_lora_files
():
return
snapshot_download
(
repo_id
=
"jeeejeee/qwen25-vl-lora-pokemon"
)
return
snapshot_download
(
repo_id
=
"jeeejeee/qwen25-vl-lora-pokemon"
)
...
...
tests/lora/test_lora_allowed_token_ids.py
0 → 100644
View file @
8132365b
# SPDX-License-Identifier: Apache-2.0
import
pytest
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
VllmConfig
)
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.v1.engine.processor
import
Processor
def
test_allowed_token_ids_with_lora_vocab
(
llama_2_7b_base_huggingface_id
,
sql_lora_files
):
"""
Test that we properly resolve the range of allowed token ids for lora
adapters that define additional tokens.
"""
# Setup a base model compatible with the sql_lora_files adapter and
# a known number of tokens in the base model.
model_config
=
ModelConfig
(
model
=
llama_2_7b_base_huggingface_id
,
tokenizer
=
llama_2_7b_base_huggingface_id
,
tokenizer_mode
=
"auto"
,
)
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
cache_config
=
CacheConfig
(),
device_config
=
DeviceConfig
(),
lora_config
=
LoRAConfig
(),
)
tokenizer
=
init_tokenizer_from_configs
(
model_config
=
vllm_config
.
model_config
,
scheduler_config
=
vllm_config
.
scheduler_config
,
lora_config
=
vllm_config
.
lora_config
)
processor
=
Processor
(
vllm_config
,
tokenizer
)
lora_request
=
LoRARequest
(
"1"
,
1
,
str
(
sql_lora_files
))
request_id
=
"1"
prompt
=
"a prompt"
# tokens added in the lora adapter should not raise an error
lora_token_ids
=
[
32000
,
32001
,
32002
,
32003
]
processor
.
process_inputs
(
request_id
,
prompt
,
params
=
SamplingParams
(
allowed_token_ids
=
lora_token_ids
),
lora_request
=
lora_request
)
# tokens in the base model should not raise an error
base_token_ids
=
[
1000
,
1001
,
1002
,
1003
]
processor
.
process_inputs
(
request_id
,
prompt
,
params
=
SamplingParams
(
allowed_token_ids
=
base_token_ids
),
lora_request
=
lora_request
)
# tokens not in the lora adapter should raise an error
invalid_token_ids
=
[
35000
,
35001
,
35002
,
35003
]
with
pytest
.
raises
(
ValueError
):
processor
.
process_inputs
(
request_id
,
prompt
,
params
=
SamplingParams
(
allowed_token_ids
=
invalid_token_ids
),
lora_request
=
lora_request
)
# tokens in the lora adapter with no lora request should raise an error
with
pytest
.
raises
(
ValueError
):
processor
.
process_inputs
(
request_id
,
prompt
,
params
=
SamplingParams
(
allowed_token_ids
=
lora_token_ids
),
)
def
test_allowed_token_ids_with_lora_adapter_no_vocab
(
qwen25vl_base_huggingface_id
,
qwen25vl_lora_files
):
"""
Test that we properly resolve the range of allowed token ids for lora
adapters that do not define additional tokens.
"""
# Setup a base model compatible with the qwen25vl_lora_files adapter and
# a known number of tokens in the base model.
model_config
=
ModelConfig
(
model
=
qwen25vl_base_huggingface_id
,
tokenizer
=
qwen25vl_base_huggingface_id
,
tokenizer_mode
=
"auto"
,
)
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
cache_config
=
CacheConfig
(),
device_config
=
DeviceConfig
(),
lora_config
=
LoRAConfig
(),
)
tokenizer
=
init_tokenizer_from_configs
(
model_config
=
vllm_config
.
model_config
,
scheduler_config
=
vllm_config
.
scheduler_config
,
lora_config
=
vllm_config
.
lora_config
)
processor
=
Processor
(
vllm_config
,
tokenizer
)
lora_request
=
LoRARequest
(
"1"
,
1
,
str
(
qwen25vl_lora_files
))
request_id
=
"1"
prompt
=
"a prompt"
# tokens in the base model should not raise an error
base_token_ids
=
[
1000
,
1001
,
1002
,
1003
]
processor
.
process_inputs
(
request_id
,
prompt
,
params
=
SamplingParams
(
allowed_token_ids
=
base_token_ids
),
lora_request
=
lora_request
)
# tokens in the base model with no lora request should not raise an error
base_token_ids
=
[
1000
,
1001
,
1002
,
1003
]
processor
.
process_inputs
(
request_id
,
prompt
,
params
=
SamplingParams
(
allowed_token_ids
=
base_token_ids
),
)
# tokens not in the base model should raise an error
invalid_token_ids
=
[
200000
,
200001
,
200002
,
200003
]
with
pytest
.
raises
(
ValueError
):
processor
.
process_inputs
(
request_id
,
prompt
,
params
=
SamplingParams
(
allowed_token_ids
=
invalid_token_ids
),
lora_request
=
lora_request
)
vllm/v1/engine/processor.py
View file @
8132365b
...
@@ -74,6 +74,7 @@ class Processor:
...
@@ -74,6 +74,7 @@ class Processor:
def
_validate_sampling_params
(
def
_validate_sampling_params
(
self
,
self
,
params
:
SamplingParams
,
params
:
SamplingParams
,
lora_request
:
Optional
[
LoRARequest
],
)
->
None
:
)
->
None
:
self
.
_validate_structured_output
(
params
)
self
.
_validate_structured_output
(
params
)
self
.
_validate_logit_bias
(
params
)
self
.
_validate_logit_bias
(
params
)
...
@@ -82,7 +83,8 @@ class Processor:
...
@@ -82,7 +83,8 @@ class Processor:
return
return
if
not
params
.
allowed_token_ids
:
if
not
params
.
allowed_token_ids
:
raise
ValueError
(
"allowed_token_ids is not None and empty!"
)
raise
ValueError
(
"allowed_token_ids is not None and empty!"
)
vocab_size
=
self
.
model_config
.
get_vocab_size
()
tokenizer
=
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
)
vocab_size
=
len
(
tokenizer
)
if
not
all
(
0
<=
tid
<
vocab_size
for
tid
in
params
.
allowed_token_ids
):
if
not
all
(
0
<=
tid
<
vocab_size
for
tid
in
params
.
allowed_token_ids
):
raise
ValueError
(
raise
ValueError
(
"allowed_token_ids contains out-of-vocab token id!"
)
"allowed_token_ids contains out-of-vocab token id!"
)
...
@@ -122,6 +124,7 @@ class Processor:
...
@@ -122,6 +124,7 @@ class Processor:
def
_validate_params
(
def
_validate_params
(
self
,
self
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
lora_request
:
Optional
[
LoRARequest
],
):
):
"""
"""
Validate supported SamplingParam.
Validate supported SamplingParam.
...
@@ -132,7 +135,7 @@ class Processor:
...
@@ -132,7 +135,7 @@ class Processor:
raise
ValueError
(
"V1 does not yet support Pooling models."
)
raise
ValueError
(
"V1 does not yet support Pooling models."
)
self
.
_validate_logprobs
(
params
)
self
.
_validate_logprobs
(
params
)
self
.
_validate_sampling_params
(
params
)
self
.
_validate_sampling_params
(
params
,
lora_request
)
self
.
_validate_supported_sampling_params
(
params
)
self
.
_validate_supported_sampling_params
(
params
)
def
_validate_lora
(
self
,
lora_request
:
Optional
[
LoRARequest
])
->
None
:
def
_validate_lora
(
self
,
lora_request
:
Optional
[
LoRARequest
])
->
None
:
...
@@ -207,7 +210,7 @@ class Processor:
...
@@ -207,7 +210,7 @@ class Processor:
# TODO(woosuk): Support pooling models.
# TODO(woosuk): Support pooling models.
# TODO(woosuk): Support encoder-decoder models.
# TODO(woosuk): Support encoder-decoder models.
self
.
_validate_lora
(
lora_request
)
self
.
_validate_lora
(
lora_request
)
self
.
_validate_params
(
params
)
self
.
_validate_params
(
params
,
lora_request
)
if
priority
!=
0
:
if
priority
!=
0
:
raise
ValueError
(
"V1 does not support priority yet."
)
raise
ValueError
(
"V1 does not support priority yet."
)
if
trace_headers
is
not
None
:
if
trace_headers
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment