Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
db3bf7c9
Unverified
Commit
db3bf7c9
authored
Sep 05, 2024
by
Jiaxin Shan
Committed by
GitHub
Sep 05, 2024
Browse files
[Core] Support load and unload LoRA in api server (#6566)
Co-authored-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
2febcf27
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
336 additions
and
6 deletions
+336
-6
docs/requirements-docs.txt
docs/requirements-docs.txt
+0
-1
docs/source/models/lora.rst
docs/source/models/lora.rst
+52
-0
tests/entrypoints/llm/test_generate_multiple_loras.py
tests/entrypoints/llm/test_generate_multiple_loras.py
+1
-1
tests/entrypoints/openai/test_serving_engine.py
tests/entrypoints/openai/test_serving_engine.py
+107
-0
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+38
-2
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+10
-0
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+78
-1
vllm/envs.py
vllm/envs.py
+7
-0
vllm/lora/request.py
vllm/lora/request.py
+18
-1
vllm/utils.py
vllm/utils.py
+25
-0
No files found.
docs/requirements-docs.txt
View file @
db3bf7c9
...
...
@@ -11,6 +11,5 @@ pydantic >= 2.8
torch
py-cpuinfo
transformers
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
mistral_common >= 1.3.4
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
\ No newline at end of file
docs/source/models/lora.rst
View file @
db3bf7c9
...
...
@@ -107,3 +107,55 @@ The following is an example request
"max_tokens": 7,
"temperature": 0
}' | jq
Dynamically serving LoRA Adapters
---------------------------------
In addition to serving LoRA adapters at server startup, the vLLM server now supports dynamically loading and unloading
LoRA adapters at runtime through dedicated API endpoints. This feature can be particularly useful when the flexibility
to change models on-the-fly is needed.
Note: Enabling this feature in production environments is risky as user may participate model adapter management.
To enable dynamic LoRA loading and unloading, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING`
is set to `True`. When this option is enabled, the API server will log a warning to indicate that dynamic loading is active.
.. code-block:: bash
export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True
Loading a LoRA Adapter:
To dynamically load a LoRA adapter, send a POST request to the `/v1/load_lora_adapter` endpoint with the necessary
details of the adapter to be loaded. The request payload should include the name and path to the LoRA adapter.
Example request to load a LoRA adapter:
.. code-block:: bash
curl -X POST http://localhost:8000/v1/load_lora_adapter \
-H "Content-Type: application/json" \
-d '{
"lora_name": "sql_adapter",
"lora_path": "/path/to/sql-lora-adapter"
}'
Upon a successful request, the API will respond with a 200 OK status code. If an error occurs, such as if the adapter
cannot be found or loaded, an appropriate error message will be returned.
Unloading a LoRA Adapter:
To unload a LoRA adapter that has been previously loaded, send a POST request to the `/v1/unload_lora_adapter` endpoint
with the name or ID of the adapter to be unloaded.
Example request to unload a LoRA adapter:
.. code-block:: bash
curl -X POST http://localhost:8000/v1/unload_lora_adapter \
-H "Content-Type: application/json" \
-d '{
"lora_name": "sql_adapter"
}'
tests/entrypoints/llm/test_generate_multiple_loras.py
View file @
db3bf7c9
...
...
@@ -50,7 +50,7 @@ def zephyr_lora_files():
@
pytest
.
mark
.
skip_global_cleanup
def
test_multiple_lora_requests
(
llm
:
LLM
,
zephyr_lora_files
):
lora_request
=
[
LoRARequest
(
LORA_NAME
,
idx
+
1
,
zephyr_lora_files
)
LoRARequest
(
LORA_NAME
+
str
(
idx
)
,
idx
+
1
,
zephyr_lora_files
)
for
idx
in
range
(
len
(
PROMPTS
))
]
# Multiple SamplingParams should be matched with each prompt
...
...
tests/entrypoints/openai/test_serving_engine.py
0 → 100644
View file @
db3bf7c9
from
http
import
HTTPStatus
from
unittest.mock
import
MagicMock
import
pytest
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.entrypoints.openai.protocol
import
(
ErrorResponse
,
LoadLoraAdapterRequest
,
UnloadLoraAdapterRequest
)
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
MODEL_NAME
=
"meta-llama/Llama-2-7b"
LORA_LOADING_SUCCESS_MESSAGE
=
(
"Success: LoRA adapter '{lora_name}' added successfully."
)
LORA_UNLOADING_SUCCESS_MESSAGE
=
(
"Success: LoRA adapter '{lora_name}' removed successfully."
)
async
def
_async_serving_engine_init
():
mock_engine_client
=
MagicMock
(
spec
=
AsyncEngineClient
)
mock_model_config
=
MagicMock
(
spec
=
ModelConfig
)
# Set the max_model_len attribute to avoid missing attribute
mock_model_config
.
max_model_len
=
2048
serving_engine
=
OpenAIServing
(
mock_engine_client
,
mock_model_config
,
served_model_names
=
[
MODEL_NAME
],
lora_modules
=
None
,
prompt_adapters
=
None
,
request_logger
=
None
)
return
serving_engine
@
pytest
.
mark
.
asyncio
async
def
test_load_lora_adapter_success
():
serving_engine
=
await
_async_serving_engine_init
()
request
=
LoadLoraAdapterRequest
(
lora_name
=
"adapter"
,
lora_path
=
"/path/to/adapter2"
)
response
=
await
serving_engine
.
load_lora_adapter
(
request
)
assert
response
==
LORA_LOADING_SUCCESS_MESSAGE
.
format
(
lora_name
=
'adapter'
)
assert
len
(
serving_engine
.
lora_requests
)
==
1
assert
serving_engine
.
lora_requests
[
0
].
lora_name
==
"adapter"
@
pytest
.
mark
.
asyncio
async
def
test_load_lora_adapter_missing_fields
():
serving_engine
=
await
_async_serving_engine_init
()
request
=
LoadLoraAdapterRequest
(
lora_name
=
""
,
lora_path
=
""
)
response
=
await
serving_engine
.
load_lora_adapter
(
request
)
assert
isinstance
(
response
,
ErrorResponse
)
assert
response
.
type
==
"InvalidUserInput"
assert
response
.
code
==
HTTPStatus
.
BAD_REQUEST
@
pytest
.
mark
.
asyncio
async
def
test_load_lora_adapter_duplicate
():
serving_engine
=
await
_async_serving_engine_init
()
request
=
LoadLoraAdapterRequest
(
lora_name
=
"adapter1"
,
lora_path
=
"/path/to/adapter1"
)
response
=
await
serving_engine
.
load_lora_adapter
(
request
)
assert
response
==
LORA_LOADING_SUCCESS_MESSAGE
.
format
(
lora_name
=
'adapter1'
)
assert
len
(
serving_engine
.
lora_requests
)
==
1
request
=
LoadLoraAdapterRequest
(
lora_name
=
"adapter1"
,
lora_path
=
"/path/to/adapter1"
)
response
=
await
serving_engine
.
load_lora_adapter
(
request
)
assert
isinstance
(
response
,
ErrorResponse
)
assert
response
.
type
==
"InvalidUserInput"
assert
response
.
code
==
HTTPStatus
.
BAD_REQUEST
assert
len
(
serving_engine
.
lora_requests
)
==
1
@
pytest
.
mark
.
asyncio
async
def
test_unload_lora_adapter_success
():
serving_engine
=
await
_async_serving_engine_init
()
request
=
LoadLoraAdapterRequest
(
lora_name
=
"adapter1"
,
lora_path
=
"/path/to/adapter1"
)
response
=
await
serving_engine
.
load_lora_adapter
(
request
)
assert
len
(
serving_engine
.
lora_requests
)
==
1
request
=
UnloadLoraAdapterRequest
(
lora_name
=
"adapter1"
)
response
=
await
serving_engine
.
unload_lora_adapter
(
request
)
assert
response
==
LORA_UNLOADING_SUCCESS_MESSAGE
.
format
(
lora_name
=
'adapter1'
)
assert
len
(
serving_engine
.
lora_requests
)
==
0
@
pytest
.
mark
.
asyncio
async
def
test_unload_lora_adapter_missing_fields
():
serving_engine
=
await
_async_serving_engine_init
()
request
=
UnloadLoraAdapterRequest
(
lora_name
=
""
,
lora_int_id
=
None
)
response
=
await
serving_engine
.
unload_lora_adapter
(
request
)
assert
isinstance
(
response
,
ErrorResponse
)
assert
response
.
type
==
"InvalidUserInput"
assert
response
.
code
==
HTTPStatus
.
BAD_REQUEST
@
pytest
.
mark
.
asyncio
async
def
test_unload_lora_adapter_not_found
():
serving_engine
=
await
_async_serving_engine_init
()
request
=
UnloadLoraAdapterRequest
(
lora_name
=
"nonexistent_adapter"
)
response
=
await
serving_engine
.
unload_lora_adapter
(
request
)
assert
isinstance
(
response
,
ErrorResponse
)
assert
response
.
type
==
"InvalidUserInput"
assert
response
.
code
==
HTTPStatus
.
BAD_REQUEST
vllm/entrypoints/openai/api_server.py
View file @
db3bf7c9
...
...
@@ -35,11 +35,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DetokenizeResponse
,
EmbeddingRequest
,
EmbeddingResponse
,
ErrorResponse
,
LoadLoraAdapterRequest
,
TokenizeRequest
,
TokenizeResponse
)
# yapf: enable
TokenizeResponse
,
UnloadLoraAdapterRequest
)
from
vllm.entrypoints.openai.rpc.client
import
AsyncEngineRPCClient
from
vllm.entrypoints.openai.rpc.server
import
run_rpc_server
# yapf: enable
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
...
...
@@ -343,6 +345,40 @@ if envs.VLLM_TORCH_PROFILER_DIR:
return
Response
(
status_code
=
200
)
if
envs
.
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
logger
.
warning
(
"Lora dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!"
)
@
router
.
post
(
"/v1/load_lora_adapter"
)
async
def
load_lora_adapter
(
request
:
LoadLoraAdapterRequest
):
response
=
await
openai_serving_chat
.
load_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
response
=
await
openai_serving_completion
.
load_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
return
Response
(
status_code
=
200
,
content
=
response
)
@
router
.
post
(
"/v1/unload_lora_adapter"
)
async
def
unload_lora_adapter
(
request
:
UnloadLoraAdapterRequest
):
response
=
await
openai_serving_chat
.
unload_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
response
=
await
openai_serving_completion
.
unload_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
return
Response
(
status_code
=
200
,
content
=
response
)
def
build_app
(
args
:
Namespace
)
->
FastAPI
:
app
=
FastAPI
(
lifespan
=
lifespan
)
app
.
include_router
(
router
)
...
...
vllm/entrypoints/openai/protocol.py
View file @
db3bf7c9
...
...
@@ -878,3 +878,13 @@ class DetokenizeRequest(OpenAIBaseModel):
class
DetokenizeResponse
(
OpenAIBaseModel
):
prompt
:
str
class
LoadLoraAdapterRequest
(
BaseModel
):
lora_name
:
str
lora_path
:
str
class
UnloadLoraAdapterRequest
(
BaseModel
):
lora_name
:
str
lora_int_id
:
Optional
[
int
]
=
Field
(
default
=
None
)
vllm/entrypoints/openai/serving_engine.py
View file @
db3bf7c9
...
...
@@ -16,11 +16,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest
,
DetokenizeRequest
,
EmbeddingRequest
,
ErrorResponse
,
LoadLoraAdapterRequest
,
ModelCard
,
ModelList
,
ModelPermission
,
TokenizeChatRequest
,
TokenizeCompletionRequest
,
TokenizeRequest
)
TokenizeRequest
,
UnloadLoraAdapterRequest
)
# yapf: enable
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.logger
import
init_logger
...
...
@@ -32,6 +34,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from
vllm.sampling_params
import
LogitsProcessor
,
SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
AtomicCounter
logger
=
init_logger
(
__name__
)
...
...
@@ -78,6 +81,7 @@ class OpenAIServing:
self
.
served_model_names
=
served_model_names
self
.
lora_id_counter
=
AtomicCounter
(
0
)
self
.
lora_requests
=
[]
if
lora_modules
is
not
None
:
self
.
lora_requests
=
[
...
...
@@ -403,3 +407,76 @@ class OpenAIServing:
if
logprob
.
decoded_token
is
not
None
:
return
logprob
.
decoded_token
return
tokenizer
.
decode
(
token_id
)
async
def
_check_load_lora_adapter_request
(
self
,
request
:
LoadLoraAdapterRequest
)
->
Optional
[
ErrorResponse
]:
# Check if both 'lora_name' and 'lora_path' are provided
if
not
request
.
lora_name
or
not
request
.
lora_path
:
return
self
.
create_error_response
(
message
=
"Both 'lora_name' and 'lora_path' must be provided."
,
err_type
=
"InvalidUserInput"
,
status_code
=
HTTPStatus
.
BAD_REQUEST
)
# Check if the lora adapter with the given name already exists
if
any
(
lora_request
.
lora_name
==
request
.
lora_name
for
lora_request
in
self
.
lora_requests
):
return
self
.
create_error_response
(
message
=
f
"The lora adapter '
{
request
.
lora_name
}
' has already been"
"loaded."
,
err_type
=
"InvalidUserInput"
,
status_code
=
HTTPStatus
.
BAD_REQUEST
)
return
None
async
def
_check_unload_lora_adapter_request
(
self
,
request
:
UnloadLoraAdapterRequest
)
->
Optional
[
ErrorResponse
]:
# Check if either 'lora_name' or 'lora_int_id' is provided
if
not
request
.
lora_name
and
not
request
.
lora_int_id
:
return
self
.
create_error_response
(
message
=
"either 'lora_name' and 'lora_int_id' needs to be provided."
,
err_type
=
"InvalidUserInput"
,
status_code
=
HTTPStatus
.
BAD_REQUEST
)
# Check if the lora adapter with the given name exists
if
not
any
(
lora_request
.
lora_name
==
request
.
lora_name
for
lora_request
in
self
.
lora_requests
):
return
self
.
create_error_response
(
message
=
f
"The lora adapter '
{
request
.
lora_name
}
' cannot be found."
,
err_type
=
"InvalidUserInput"
,
status_code
=
HTTPStatus
.
BAD_REQUEST
)
return
None
async
def
load_lora_adapter
(
self
,
request
:
LoadLoraAdapterRequest
)
->
Union
[
ErrorResponse
,
str
]:
error_check_ret
=
await
self
.
_check_load_lora_adapter_request
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
lora_name
,
lora_path
=
request
.
lora_name
,
request
.
lora_path
unique_id
=
self
.
lora_id_counter
.
inc
(
1
)
self
.
lora_requests
.
append
(
LoRARequest
(
lora_name
=
lora_name
,
lora_int_id
=
unique_id
,
lora_path
=
lora_path
))
return
f
"Success: LoRA adapter '
{
lora_name
}
' added successfully."
async
def
unload_lora_adapter
(
self
,
request
:
UnloadLoraAdapterRequest
)
->
Union
[
ErrorResponse
,
str
]:
error_check_ret
=
await
self
.
_check_unload_lora_adapter_request
(
request
)
if
error_check_ret
is
not
None
:
return
error_check_ret
lora_name
=
request
.
lora_name
self
.
lora_requests
=
[
lora_request
for
lora_request
in
self
.
lora_requests
if
lora_request
.
lora_name
!=
lora_name
]
return
f
"Success: LoRA adapter '
{
lora_name
}
' removed successfully."
vllm/envs.py
View file @
db3bf7c9
...
...
@@ -61,6 +61,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_ENGINE_USE_RAY
:
bool
=
False
VLLM_PLUGINS
:
Optional
[
List
[
str
]]
=
None
VLLM_TORCH_PROFILER_DIR
:
Optional
[
str
]
=
None
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
def
get_default_cache_root
():
...
...
@@ -409,6 +410,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# If set, vLLM will use Triton implementations of AWQ.
"VLLM_USE_TRITON_AWQ"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_TRITON_AWQ"
,
"0"
))),
# If set, allow loading or unloading lora adapters in runtime,
"VLLM_ALLOW_RUNTIME_LORA_UPDATING"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_ALLOW_RUNTIME_LORA_UPDATING"
,
"0"
).
strip
().
lower
()
in
(
"1"
,
"true"
)),
}
# end-env-vars-definition
...
...
vllm/lora/request.py
View file @
db3bf7c9
...
...
@@ -28,7 +28,6 @@ class LoRARequest(
lora_path
:
str
=
""
lora_local_path
:
Optional
[
str
]
=
msgspec
.
field
(
default
=
None
)
long_lora_max_len
:
Optional
[
int
]
=
None
__hash__
=
AdapterRequest
.
__hash__
def
__post_init__
(
self
):
if
'lora_local_path'
in
self
.
__struct_fields__
:
...
...
@@ -75,3 +74,21 @@ class LoRARequest(
DeprecationWarning
,
stacklevel
=
2
)
self
.
lora_path
=
value
def
__eq__
(
self
,
value
:
object
)
->
bool
:
"""
Overrides the equality method to compare LoRARequest
instances based on lora_name. This allows for identification
and comparison lora adapter across engines.
"""
return
isinstance
(
value
,
self
.
__class__
)
and
self
.
lora_name
==
value
.
lora_name
def
__hash__
(
self
)
->
int
:
"""
Overrides the hash method to hash LoRARequest instances
based on lora_name. This ensures that LoRARequest instances
can be used in hash-based collections such as sets and dictionaries,
identified by their names across engines.
"""
return
hash
(
self
.
lora_name
)
vllm/utils.py
View file @
db3bf7c9
...
...
@@ -1224,3 +1224,28 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
def
supports_dynamo
()
->
bool
:
base_torch_version
=
Version
(
Version
(
torch
.
__version__
).
base_version
)
return
base_torch_version
>=
Version
(
"2.4.0"
)
class
AtomicCounter
:
"""An atomic, thread-safe counter"""
def
__init__
(
self
,
initial
=
0
):
"""Initialize a new atomic counter to given initial value"""
self
.
_value
=
initial
self
.
_lock
=
threading
.
Lock
()
def
inc
(
self
,
num
=
1
):
"""Atomically increment the counter by num and return the new value"""
with
self
.
_lock
:
self
.
_value
+=
num
return
self
.
_value
def
dec
(
self
,
num
=
1
):
"""Atomically decrement the counter by num and return the new value"""
with
self
.
_lock
:
self
.
_value
-=
num
return
self
.
_value
@
property
def
value
(
self
):
return
self
.
_value
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment