Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8132365b
Unverified
Commit
8132365b
authored
May 11, 2025
by
Ben Browning
Committed by
GitHub
May 11, 2025
Browse files
[Bugfix]: v1 engine - consider lora adapters in allowed_token_ids (#17855)
Signed-off-by:
Ben Browning
<
bbrownin@redhat.com
>
parent
eea22a56
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
154 additions
and
5 deletions
+154
-5
tests/lora/conftest.py
tests/lora/conftest.py
+14
-2
tests/lora/test_lora_allowed_token_ids.py
tests/lora/test_lora_allowed_token_ids.py
+134
-0
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+6
-3
No files found.
tests/lora/conftest.py
View file @
8132365b
...
@@ -139,6 +139,12 @@ def dummy_model_gate_up() -> nn.Module:
...
@@ -139,6 +139,12 @@ def dummy_model_gate_up() -> nn.Module:
return
model
return
model
@pytest.fixture(scope="session")
def llama_2_7b_base_huggingface_id():
    """Return the Hugging Face repo id of the base model used for testing
    with the sql lora adapter."""
    base_model_id = "meta-llama/Llama-2-7b-hf"
    return base_model_id
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
sql_lora_huggingface_id
():
def
sql_lora_huggingface_id
():
# huggingface repo id is used to test lora runtime downloading.
# huggingface repo id is used to test lora runtime downloading.
...
@@ -198,6 +204,12 @@ def qwen2vl_lora_files():
...
@@ -198,6 +204,12 @@ def qwen2vl_lora_files():
return
snapshot_download
(
repo_id
=
"jeeejeee/qwen2-vl-lora-pokemon"
)
return
snapshot_download
(
repo_id
=
"jeeejeee/qwen2-vl-lora-pokemon"
)
@pytest.fixture(scope="session")
def qwen25vl_base_huggingface_id():
    """Return the Hugging Face repo id of the base model used for testing
    with the qwen25vl lora adapter."""
    base_model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
    return base_model_id
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
qwen25vl_lora_files
():
def
qwen25vl_lora_files
():
return
snapshot_download
(
repo_id
=
"jeeejeee/qwen25-vl-lora-pokemon"
)
return
snapshot_download
(
repo_id
=
"jeeejeee/qwen25-vl-lora-pokemon"
)
...
@@ -261,8 +273,8 @@ def run_with_both_engines_lora(request, monkeypatch):
...
@@ -261,8 +273,8 @@ def run_with_both_engines_lora(request, monkeypatch):
@
pytest
.
fixture
@
pytest
.
fixture
def
reset_default_device
():
def
reset_default_device
():
"""
"""
Some tests, such as `test_punica_ops.py`, explicitly set the
Some tests, such as `test_punica_ops.py`, explicitly set the
default device, which can affect subsequent tests. Adding this fixture
default device, which can affect subsequent tests. Adding this fixture
helps avoid this problem.
helps avoid this problem.
"""
"""
original_device
=
torch
.
get_default_device
()
original_device
=
torch
.
get_default_device
()
...
...
tests/lora/test_lora_allowed_token_ids.py
0 → 100644
View file @
8132365b
# SPDX-License-Identifier: Apache-2.0
import
pytest
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
VllmConfig
)
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.v1.engine.processor
import
Processor
def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id,
                                           sql_lora_files):
    """
    Verify the allowed_token_ids range check when the lora adapter defines
    additional tokens on top of the base model.

    Args:
        llama_2_7b_base_huggingface_id: fixture giving the base model id.
        sql_lora_files: fixture giving the location of the sql lora adapter.
    """
    # Base model compatible with the sql_lora_files adapter, chosen so the
    # number of tokens in the base model is known.
    base_model_config = ModelConfig(
        model=llama_2_7b_base_huggingface_id,
        tokenizer=llama_2_7b_base_huggingface_id,
        tokenizer_mode="auto",
    )
    engine_config = VllmConfig(
        model_config=base_model_config,
        cache_config=CacheConfig(),
        device_config=DeviceConfig(),
        lora_config=LoRAConfig(),
    )
    tokenizer_group = init_tokenizer_from_configs(
        model_config=engine_config.model_config,
        scheduler_config=engine_config.scheduler_config,
        lora_config=engine_config.lora_config,
    )
    processor = Processor(engine_config, tokenizer_group)
    adapter_request = LoRARequest("1", 1, str(sql_lora_files))

    def submit(token_ids, **request_kwargs):
        # Same request id and prompt for every call; only the allowed ids
        # and the presence of a lora request vary between cases.
        return processor.process_inputs(
            "1",
            "a prompt",
            params=SamplingParams(allowed_token_ids=token_ids),
            **request_kwargs,
        )

    lora_token_ids = [32000, 32001, 32002, 32003]
    base_token_ids = [1000, 1001, 1002, 1003]
    invalid_token_ids = [35000, 35001, 35002, 35003]

    # Ids added by the adapter, then ids from the base model: with the lora
    # request attached neither should raise.
    for accepted_ids in (lora_token_ids, base_token_ids):
        submit(accepted_ids, lora_request=adapter_request)

    # Ids that are not in the lora adapter must be rejected.
    with pytest.raises(ValueError):
        submit(invalid_token_ids, lora_request=adapter_request)

    # Adapter ids without a lora request must be rejected as well.
    with pytest.raises(ValueError):
        submit(lora_token_ids)
def test_allowed_token_ids_with_lora_adapter_no_vocab(
        qwen25vl_base_huggingface_id, qwen25vl_lora_files):
    """
    Verify the allowed_token_ids range check when the lora adapter does not
    define any additional tokens.

    Args:
        qwen25vl_base_huggingface_id: fixture giving the base model id.
        qwen25vl_lora_files: fixture giving the location of the qwen25vl
            lora adapter.
    """
    # Base model compatible with the qwen25vl_lora_files adapter, chosen so
    # the number of tokens in the base model is known.
    base_model_config = ModelConfig(
        model=qwen25vl_base_huggingface_id,
        tokenizer=qwen25vl_base_huggingface_id,
        tokenizer_mode="auto",
    )
    engine_config = VllmConfig(
        model_config=base_model_config,
        cache_config=CacheConfig(),
        device_config=DeviceConfig(),
        lora_config=LoRAConfig(),
    )
    tokenizer_group = init_tokenizer_from_configs(
        model_config=engine_config.model_config,
        scheduler_config=engine_config.scheduler_config,
        lora_config=engine_config.lora_config,
    )
    processor = Processor(engine_config, tokenizer_group)
    adapter_request = LoRARequest("1", 1, str(qwen25vl_lora_files))

    def submit(token_ids, **request_kwargs):
        # Same request id and prompt for every call; only the allowed ids
        # and the presence of a lora request vary between cases.
        return processor.process_inputs(
            "1",
            "a prompt",
            params=SamplingParams(allowed_token_ids=token_ids),
            **request_kwargs,
        )

    base_token_ids = [1000, 1001, 1002, 1003]
    invalid_token_ids = [200000, 200001, 200002, 200003]

    # Base-model ids should not raise with the lora request attached ...
    submit(base_token_ids, lora_request=adapter_request)

    # ... nor when no lora request is given.
    submit(base_token_ids)

    # Ids that are not in the base model must be rejected.
    with pytest.raises(ValueError):
        submit(invalid_token_ids, lora_request=adapter_request)
vllm/v1/engine/processor.py
View file @
8132365b
...
@@ -74,6 +74,7 @@ class Processor:
...
@@ -74,6 +74,7 @@ class Processor:
def
_validate_sampling_params
(
def
_validate_sampling_params
(
self
,
self
,
params
:
SamplingParams
,
params
:
SamplingParams
,
lora_request
:
Optional
[
LoRARequest
],
)
->
None
:
)
->
None
:
self
.
_validate_structured_output
(
params
)
self
.
_validate_structured_output
(
params
)
self
.
_validate_logit_bias
(
params
)
self
.
_validate_logit_bias
(
params
)
...
@@ -82,7 +83,8 @@ class Processor:
...
@@ -82,7 +83,8 @@ class Processor:
return
return
if
not
params
.
allowed_token_ids
:
if
not
params
.
allowed_token_ids
:
raise
ValueError
(
"allowed_token_ids is not None and empty!"
)
raise
ValueError
(
"allowed_token_ids is not None and empty!"
)
vocab_size
=
self
.
model_config
.
get_vocab_size
()
tokenizer
=
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
)
vocab_size
=
len
(
tokenizer
)
if
not
all
(
0
<=
tid
<
vocab_size
for
tid
in
params
.
allowed_token_ids
):
if
not
all
(
0
<=
tid
<
vocab_size
for
tid
in
params
.
allowed_token_ids
):
raise
ValueError
(
raise
ValueError
(
"allowed_token_ids contains out-of-vocab token id!"
)
"allowed_token_ids contains out-of-vocab token id!"
)
...
@@ -122,6 +124,7 @@ class Processor:
...
@@ -122,6 +124,7 @@ class Processor:
def
_validate_params
(
def
_validate_params
(
self
,
self
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
lora_request
:
Optional
[
LoRARequest
],
):
):
"""
"""
Validate supported SamplingParam.
Validate supported SamplingParam.
...
@@ -132,7 +135,7 @@ class Processor:
...
@@ -132,7 +135,7 @@ class Processor:
raise
ValueError
(
"V1 does not yet support Pooling models."
)
raise
ValueError
(
"V1 does not yet support Pooling models."
)
self
.
_validate_logprobs
(
params
)
self
.
_validate_logprobs
(
params
)
self
.
_validate_sampling_params
(
params
)
self
.
_validate_sampling_params
(
params
,
lora_request
)
self
.
_validate_supported_sampling_params
(
params
)
self
.
_validate_supported_sampling_params
(
params
)
def
_validate_lora
(
self
,
lora_request
:
Optional
[
LoRARequest
])
->
None
:
def
_validate_lora
(
self
,
lora_request
:
Optional
[
LoRARequest
])
->
None
:
...
@@ -207,7 +210,7 @@ class Processor:
...
@@ -207,7 +210,7 @@ class Processor:
# TODO(woosuk): Support pooling models.
# TODO(woosuk): Support pooling models.
# TODO(woosuk): Support encoder-decoder models.
# TODO(woosuk): Support encoder-decoder models.
self
.
_validate_lora
(
lora_request
)
self
.
_validate_lora
(
lora_request
)
self
.
_validate_params
(
params
)
self
.
_validate_params
(
params
,
lora_request
)
if
priority
!=
0
:
if
priority
!=
0
:
raise
ValueError
(
"V1 does not support priority yet."
)
raise
ValueError
(
"V1 does not support priority yet."
)
if
trace_headers
is
not
None
:
if
trace_headers
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment