Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
828da0d4
Unverified
Commit
828da0d4
authored
Jun 06, 2024
by
Matthew Goldey
Committed by
GitHub
Jun 06, 2024
Browse files
[Frontend] enable passing multiple LoRA adapters at once to generate() (#5300)
parent
abe855d6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
91 additions
and
17 deletions
+91
-17
tests/entrypoints/test_llm_generate_multiple_loras.py
tests/entrypoints/test_llm_generate_multiple_loras.py
+69
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+22
-17
No files found.
tests/entrypoints/test_llm_generate_multiple_loras.py
0 → 100644
View file @
828da0d4
import
weakref
import
pytest
# downloading lora to test lora requests
from
huggingface_hub
import
snapshot_download
from
vllm
import
LLM
from
vllm.lora.request
import
LoRARequest
from
..conftest
import
cleanup
MODEL_NAME
=
"HuggingFaceH4/zephyr-7b-beta"
PROMPTS
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
LORA_NAME
=
"typeof/zephyr-7b-beta-lora"
pytestmark
=
pytest
.
mark
.
llm
@
pytest
.
fixture
(
scope
=
"module"
)
def
llm
():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm
=
LLM
(
model
=
MODEL_NAME
,
tensor_parallel_size
=
1
,
max_model_len
=
8192
,
enable_lora
=
True
,
max_loras
=
4
,
max_lora_rank
=
64
,
max_num_seqs
=
128
,
enforce_eager
=
True
)
with
llm
.
deprecate_legacy_api
():
yield
weakref
.
proxy
(
llm
)
del
llm
cleanup
()
@
pytest
.
fixture
(
scope
=
"session"
)
def
zephyr_lora_files
():
return
snapshot_download
(
repo_id
=
LORA_NAME
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_multiple_lora_requests
(
llm
:
LLM
,
zephyr_lora_files
):
lora_request
=
[
LoRARequest
(
LORA_NAME
,
idx
+
1
,
zephyr_lora_files
)
for
idx
in
range
(
len
(
PROMPTS
))
]
# Multiple SamplingParams should be matched with each prompt
outputs
=
llm
.
generate
(
PROMPTS
,
lora_request
=
lora_request
)
assert
len
(
PROMPTS
)
==
len
(
outputs
)
# Exception raised, if the size of params does not match the size of prompts
with
pytest
.
raises
(
ValueError
):
outputs
=
llm
.
generate
(
PROMPTS
,
lora_request
=
lora_request
[:
1
])
# Single LoRARequest should be applied to every prompt
single_lora_request
=
lora_request
[
0
]
outputs
=
llm
.
generate
(
PROMPTS
,
lora_request
=
single_lora_request
)
assert
len
(
PROMPTS
)
==
len
(
outputs
)
vllm/entrypoints/llm.py
View file @
828da0d4
...
@@ -170,7 +170,7 @@ class LLM:
...
@@ -170,7 +170,7 @@ class LLM:
List
[
SamplingParams
]]]
=
None
,
List
[
SamplingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]
]
=
None
,
)
->
List
[
RequestOutput
]:
)
->
List
[
RequestOutput
]:
...
...
...
@@ -182,7 +182,7 @@ class LLM:
...
@@ -182,7 +182,7 @@ class LLM:
List
[
SamplingParams
]]]
=
None
,
List
[
SamplingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]
]
=
None
,
)
->
List
[
RequestOutput
]:
)
->
List
[
RequestOutput
]:
...
...
...
@@ -195,7 +195,7 @@ class LLM:
...
@@ -195,7 +195,7 @@ class LLM:
*
,
*
,
prompt_token_ids
:
List
[
int
],
prompt_token_ids
:
List
[
int
],
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]
]
=
None
,
)
->
List
[
RequestOutput
]:
)
->
List
[
RequestOutput
]:
...
...
...
@@ -208,7 +208,7 @@ class LLM:
...
@@ -208,7 +208,7 @@ class LLM:
*
,
*
,
prompt_token_ids
:
List
[
List
[
int
]],
prompt_token_ids
:
List
[
List
[
int
]],
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]
]
=
None
,
)
->
List
[
RequestOutput
]:
)
->
List
[
RequestOutput
]:
...
...
...
@@ -219,7 +219,7 @@ class LLM:
...
@@ -219,7 +219,7 @@ class LLM:
sampling_params
:
None
,
sampling_params
:
None
,
prompt_token_ids
:
Union
[
List
[
int
],
List
[
List
[
int
]]],
prompt_token_ids
:
Union
[
List
[
int
],
List
[
List
[
int
]]],
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]
]
=
None
,
)
->
List
[
RequestOutput
]:
)
->
List
[
RequestOutput
]:
...
...
...
@@ -232,7 +232,7 @@ class LLM:
...
@@ -232,7 +232,7 @@ class LLM:
sampling_params
:
Optional
[
Union
[
SamplingParams
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
Sequence
[
SamplingParams
]]]
=
None
,
Sequence
[
SamplingParams
]]]
=
None
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]
]
=
None
,
)
->
List
[
RequestOutput
]:
)
->
List
[
RequestOutput
]:
...
...
...
@@ -249,7 +249,7 @@ class LLM:
...
@@ -249,7 +249,7 @@ class LLM:
Sequence
[
SamplingParams
]]]
=
None
,
Sequence
[
SamplingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
Union
[
List
[
int
],
List
[
List
[
int
]]]]
=
None
,
prompt_token_ids
:
Optional
[
Union
[
List
[
int
],
List
[
List
[
int
]]]]
=
None
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]
]
=
None
,
)
->
List
[
RequestOutput
]:
)
->
List
[
RequestOutput
]:
"""Generates the completions for the input prompts.
"""Generates the completions for the input prompts.
...
@@ -312,7 +312,7 @@ class LLM:
...
@@ -312,7 +312,7 @@ class LLM:
Sequence
[
PoolingParams
]]]
=
None
,
Sequence
[
PoolingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]
]
=
None
,
)
->
List
[
EmbeddingRequestOutput
]:
)
->
List
[
EmbeddingRequestOutput
]:
...
...
...
@@ -324,7 +324,7 @@ class LLM:
...
@@ -324,7 +324,7 @@ class LLM:
Sequence
[
PoolingParams
]]]
=
None
,
Sequence
[
PoolingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
prompt_token_ids
:
Optional
[
List
[
List
[
int
]]]
=
None
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]
]
=
None
,
)
->
List
[
EmbeddingRequestOutput
]:
)
->
List
[
EmbeddingRequestOutput
]:
...
...
...
@@ -337,7 +337,7 @@ class LLM:
...
@@ -337,7 +337,7 @@ class LLM:
*
,
*
,
prompt_token_ids
:
List
[
int
],
prompt_token_ids
:
List
[
int
],
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]
]
=
None
,
)
->
List
[
EmbeddingRequestOutput
]:
)
->
List
[
EmbeddingRequestOutput
]:
...
...
...
@@ -350,7 +350,7 @@ class LLM:
...
@@ -350,7 +350,7 @@ class LLM:
*
,
*
,
prompt_token_ids
:
List
[
List
[
int
]],
prompt_token_ids
:
List
[
List
[
int
]],
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]
]
=
None
,
)
->
List
[
EmbeddingRequestOutput
]:
)
->
List
[
EmbeddingRequestOutput
]:
...
...
...
@@ -361,7 +361,7 @@ class LLM:
...
@@ -361,7 +361,7 @@ class LLM:
pooling_params
:
None
,
pooling_params
:
None
,
prompt_token_ids
:
Union
[
List
[
int
],
List
[
List
[
int
]]],
prompt_token_ids
:
Union
[
List
[
int
],
List
[
List
[
int
]]],
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]
]
=
None
,
)
->
List
[
EmbeddingRequestOutput
]:
)
->
List
[
EmbeddingRequestOutput
]:
...
...
...
@@ -374,7 +374,7 @@ class LLM:
...
@@ -374,7 +374,7 @@ class LLM:
pooling_params
:
Optional
[
Union
[
PoolingParams
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
Sequence
[
PoolingParams
]]]
=
None
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]
]
=
None
,
)
->
List
[
EmbeddingRequestOutput
]:
)
->
List
[
EmbeddingRequestOutput
]:
...
...
...
@@ -391,7 +391,7 @@ class LLM:
...
@@ -391,7 +391,7 @@ class LLM:
Sequence
[
PoolingParams
]]]
=
None
,
Sequence
[
PoolingParams
]]]
=
None
,
prompt_token_ids
:
Optional
[
Union
[
List
[
int
],
List
[
List
[
int
]]]]
=
None
,
prompt_token_ids
:
Optional
[
Union
[
List
[
int
],
List
[
List
[
int
]]]]
=
None
,
use_tqdm
:
bool
=
True
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]
]
=
None
,
)
->
List
[
EmbeddingRequestOutput
]:
)
->
List
[
EmbeddingRequestOutput
]:
"""Generates the completions for the input prompts.
"""Generates the completions for the input prompts.
...
@@ -498,7 +498,7 @@ class LLM:
...
@@ -498,7 +498,7 @@ class LLM:
inputs
:
Union
[
PromptStrictInputs
,
Sequence
[
PromptStrictInputs
]],
inputs
:
Union
[
PromptStrictInputs
,
Sequence
[
PromptStrictInputs
]],
params
:
Union
[
SamplingParams
,
Sequence
[
SamplingParams
],
PoolingParams
,
params
:
Union
[
SamplingParams
,
Sequence
[
SamplingParams
],
PoolingParams
,
Sequence
[
PoolingParams
]],
Sequence
[
PoolingParams
]],
lora_request
:
Optional
[
LoRARequest
],
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]
],
)
->
None
:
)
->
None
:
if
isinstance
(
inputs
,
(
str
,
dict
)):
if
isinstance
(
inputs
,
(
str
,
dict
)):
# Convert a single prompt to a list.
# Convert a single prompt to a list.
...
@@ -509,20 +509,25 @@ class LLM:
...
@@ -509,20 +509,25 @@ class LLM:
if
isinstance
(
params
,
list
)
and
len
(
params
)
!=
num_requests
:
if
isinstance
(
params
,
list
)
and
len
(
params
)
!=
num_requests
:
raise
ValueError
(
"The lengths of prompts and params "
raise
ValueError
(
"The lengths of prompts and params "
"must be the same."
)
"must be the same."
)
if
isinstance
(
lora_request
,
list
)
and
len
(
lora_request
)
!=
num_requests
:
raise
ValueError
(
"The lengths of prompts and lora_request "
"must be the same."
)
# Add requests to the engine.
# Add requests to the engine.
for
i
,
request_inputs
in
enumerate
(
inputs
):
for
i
,
request_inputs
in
enumerate
(
inputs
):
self
.
_add_request
(
self
.
_add_request
(
request_inputs
,
request_inputs
,
params
[
i
]
if
isinstance
(
params
,
Sequence
)
else
params
,
params
[
i
]
if
isinstance
(
params
,
Sequence
)
else
params
,
lora_request
=
lora_request
,
lora_request
=
lora_request
[
i
]
if
isinstance
(
lora_request
,
Sequence
)
else
lora_request
,
)
)
def
_add_request
(
def
_add_request
(
self
,
self
,
inputs
:
PromptInputs
,
inputs
:
PromptInputs
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]
]
=
None
,
)
->
None
:
)
->
None
:
request_id
=
str
(
next
(
self
.
request_counter
))
request_id
=
str
(
next
(
self
.
request_counter
))
self
.
llm_engine
.
add_request
(
request_id
,
self
.
llm_engine
.
add_request
(
request_id
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment