Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
03f48b3d
"vscode:/vscode.git/clone" did not exist on "99ffef472b5f2e56269f019ffff42e526a7b5814"
Unverified
Commit
03f48b3d
authored
Feb 25, 2025
by
Varun Sundar Rabindranath
Committed by
GitHub
Feb 25, 2025
Browse files
[Core] LoRA V1 - Add add/pin/list/remove_lora functions (#13705)
parent
4d251ad0
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
270 additions
and
22 deletions
+270
-22
tests/lora/test_add_lora.py
tests/lora/test_add_lora.py
+9
-4
tests/lora/test_lora_functions.py
tests/lora/test_lora_functions.py
+137
-0
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+15
-3
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+12
-3
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+54
-9
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+17
-1
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+10
-1
vllm/v1/worker/lora_model_runner_mixin.py
vllm/v1/worker/lora_model_runner_mixin.py
+16
-1
No files found.
tests/lora/test_add_lora.py
View file @
03f48b3d
...
@@ -7,6 +7,7 @@ from typing import List
...
@@ -7,6 +7,7 @@ from typing import List
import
pytest
import
pytest
from
huggingface_hub
import
snapshot_download
from
huggingface_hub
import
snapshot_download
import
vllm.envs
as
env
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.inputs
import
TextPrompt
from
vllm.inputs
import
TextPrompt
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -144,10 +145,14 @@ async def test_add_lora():
...
@@ -144,10 +145,14 @@ async def test_add_lora():
await
requests_processing_time
(
llm
,
dummy_run_requests
)
await
requests_processing_time
(
llm
,
dummy_run_requests
)
# Run with warmup
# Run with warmup
for
lr
in
warmup_run_requests
:
add_lora_tasks
=
[
llm
.
add_lora
(
lr
)
for
lr
in
warmup_run_requests
]
await
llm
.
add_lora
(
lr
)
add_lora_results
=
await
asyncio
.
gather
(
*
add_lora_tasks
)
# Wait for the add_lora function to complete on the server side.
if
env
.
VLLM_USE_V1
:
await
asyncio
.
sleep
(
30
)
# Test that all add_lora calls are successful.
assert
all
(
add_lora_results
)
else
:
# No way to check V0 engine results as the calls just return None.
pass
time_with_add_lora
=
await
requests_processing_time
(
time_with_add_lora
=
await
requests_processing_time
(
llm
,
warmup_run_requests
)
llm
,
warmup_run_requests
)
...
...
tests/lora/test_lora_functions.py
0 → 100644
View file @
03f48b3d
# SPDX-License-Identifier: Apache-2.0
"""
Script to test add_lora, remove_lora, pin_lora, list_loras functions.
"""
import
os
from
typing
import
List
import
pytest
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
EngineArgs
from
vllm.entrypoints.llm
import
LLM
from
vllm.lora.request
import
LoRARequest
# Base model handed to the engine as `model=` in both tests below.
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
# Adapter location used as `lora_path` for every LoRARequest built in this file.
LORA_MODULE_PATH = "yard1/llama-2-7b-sql-lora-test"
# Passed to the engine as `max_lora_rank`.
LORA_RANK = 8
@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
    """Autouse shim that requests ``run_with_both_engines_lora`` for each test.

    The body is intentionally empty: depending on the upstream fixture is the
    whole purpose. This wrapper could be promoted to conftest.py so that it
    applies to every test in a package.
    """
def make_lora_request(lora_id: int):
    """Build a LoRARequest for ``lora_id`` that points at LORA_MODULE_PATH.

    The id is used twice: formatted as a string for ``lora_name`` and passed
    unchanged as ``lora_int_id``.
    """
    adapter_name = f"{lora_id}"
    request = LoRARequest(
        lora_name=adapter_name,
        lora_int_id=lora_id,
        lora_path=LORA_MODULE_PATH,
    )
    return request
def test_lora_functions_sync():
    """Drive add/pin/remove/list LoRA calls on the synchronous engine.

    Each step applies one operation and then checks that ``list_loras()``
    reports exactly the expected set of adapter ids.
    """
    max_loras = 4
    # Build the engine in eager mode: with a high max_loras the CI can OOM
    # during cuda-graph capture.
    engine_args = EngineArgs(
        model=MODEL_PATH,
        enable_lora=True,
        max_loras=max_loras,
        max_lora_rank=LORA_RANK,
        max_model_len=128,
        gpu_memory_utilization=0.8,
        enforce_eager=True,
    )
    llm = LLM.get_engine_class().from_engine_args(engine_args)

    def apply_and_verify(fn, args, expected: List):
        # Run the operation, then compare the engine's adapter set.
        fn(args)
        assert set(llm.list_loras()) == set(expected)

    # (operation, lora id, adapter ids expected afterwards)
    steps = [
        ("add", 1, [1]),
        ("add", 2, [1, 2]),
        # Pin LoRA 1 and test that it is never removed on subsequent adds.
        ("pin", 1, [1, 2]),
        ("add", 3, [1, 2, 3]),
        ("add", 4, [1, 2, 3, 4]),
        ("add", 5, [1, 5, 3, 4]),
        ("add", 6, [1, 5, 6, 4]),
        ("add", 7, [1, 5, 6, 7]),
        ("add", 8, [1, 8, 6, 7]),
        ("add", 9, [1, 8, 9, 7]),
        ("add", 10, [1, 8, 9, 10]),
        # Remove LoRA 1 and continue adding.
        ("remove", 1, [8, 9, 10]),
        ("add", 11, [8, 9, 10, 11]),
        ("add", 12, [12, 9, 10, 11]),
        ("add", 13, [12, 13, 10, 11]),
        # Remove all LoRAs.
        ("remove", 13, [12, 10, 11]),
        ("remove", 12, [10, 11]),
        ("remove", 11, [10]),
        ("remove", 10, []),
    ]

    for op, lora_id, expected_ids in steps:
        if op == "add":
            apply_and_verify(llm.add_lora, make_lora_request(lora_id),
                             expected_ids)
        elif op == "pin":
            apply_and_verify(llm.pin_lora, lora_id, expected_ids)
        else:
            apply_and_verify(llm.remove_lora, lora_id, expected_ids)
@pytest.mark.asyncio
async def test_lora_functions_async():
    """Drive add/pin/remove/list LoRA calls on the async engine client.

    Skipped when ``VLLM_USE_V1`` is "0". Each step awaits one operation and
    then checks that ``list_loras()`` reports exactly the expected adapter ids.
    """
    if os.getenv("VLLM_USE_V1") == "0":
        pytest.skip(
            reason=
            "V0 AsyncLLMEngine does not expose remove/list/pin LoRA functions")

    # The run_with_both_engines_lora fixture sets up the `VLLM_USE_V1`
    # environment variable. Reload vllm.engine.async_llm_engine, because
    # vllm.engine.async_llm_engine.AsyncLLMEngine changes depending on the
    # env var.
    import importlib

    import vllm.engine.async_llm_engine
    importlib.reload(vllm.engine.async_llm_engine)
    from vllm.entrypoints.openai.api_server import (
        build_async_engine_client_from_engine_args)

    max_loras = 4
    engine_args = AsyncEngineArgs(
        model=MODEL_PATH,
        enable_lora=True,
        max_loras=max_loras,
        max_lora_rank=LORA_RANK,
        max_model_len=128,
        gpu_memory_utilization=0.8,
        enforce_eager=True,
    )

    async def apply_and_verify(fn, args, expected: List):
        # `llm` is bound by the `async with` below before this is first called.
        await fn(args)
        assert set(await llm.list_loras()) == set(expected)

    # (operation, lora id, adapter ids expected afterwards)
    steps = [
        ("add", 1, [1]),
        ("add", 2, [1, 2]),
        # Pin LoRA 1 and test that it is never removed on subsequent adds.
        ("pin", 1, [1, 2]),
        ("add", 3, [1, 2, 3]),
        ("add", 4, [1, 2, 3, 4]),
        ("add", 5, [1, 5, 3, 4]),
        ("add", 6, [1, 5, 6, 4]),
        ("add", 7, [1, 5, 6, 7]),
        ("add", 8, [1, 8, 6, 7]),
        ("add", 9, [1, 8, 9, 7]),
        ("add", 10, [1, 8, 9, 10]),
        # Remove LoRA 1 and continue adding.
        ("remove", 1, [8, 9, 10]),
        ("add", 11, [8, 9, 10, 11]),
        ("add", 12, [12, 9, 10, 11]),
        ("add", 13, [12, 13, 10, 11]),
        # Remove all LoRAs.
        ("remove", 13, [12, 10, 11]),
        ("remove", 12, [10, 11]),
        ("remove", 11, [10]),
        ("remove", 10, []),
    ]

    async with build_async_engine_client_from_engine_args(engine_args) as llm:
        for op, lora_id, expected_ids in steps:
            if op == "add":
                await apply_and_verify(llm.add_lora,
                                       make_lora_request(lora_id),
                                       expected_ids)
            elif op == "pin":
                await apply_and_verify(llm.pin_lora, lora_id, expected_ids)
            else:
                await apply_and_verify(llm.remove_lora, lora_id, expected_ids)
vllm/v1/engine/async_llm.py
View file @
03f48b3d
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
import
asyncio
import
asyncio
import
os
import
os
from
typing
import
AsyncGenerator
,
List
,
Mapping
,
Optional
,
Type
,
Union
from
typing
import
AsyncGenerator
,
List
,
Mapping
,
Optional
,
Set
,
Type
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -392,9 +392,21 @@ class AsyncLLM(EngineClient):
...
@@ -392,9 +392,21 @@ class AsyncLLM(EngineClient):
async
def
wake_up
(
self
)
->
None
:
async
def
wake_up
(
self
)
->
None
:
await
self
.
engine_core
.
wake_up_async
()
await
self
.
engine_core
.
wake_up_async
()
async
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
async
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
"""Load a new LoRA adapter into the engine for future requests."""
"""Load a new LoRA adapter into the engine for future requests."""
await
self
.
engine_core
.
add_lora_async
(
lora_request
)
return
await
self
.
engine_core
.
add_lora_async
(
lora_request
)
async def remove_lora(self, lora_id: int) -> bool:
    """Unload the LoRA adapter identified by ``lora_id``.

    Awaits ``engine_core.remove_lora_async`` and returns its result.
    """
    outcome = await self.engine_core.remove_lora_async(lora_id)
    return outcome
async def list_loras(self) -> Set[int]:
    """Return the ids of all registered adapters.

    Awaits ``engine_core.list_loras_async`` and returns its result.
    """
    registered = await self.engine_core.list_loras_async()
    return registered
async def pin_lora(self, lora_id: int) -> bool:
    """Mark the adapter ``lora_id`` so that it is not evicted.

    Awaits ``engine_core.pin_lora_async`` and returns its result.
    """
    outcome = await self.engine_core.pin_lora_async(lora_id)
    return outcome
@
property
@
property
def
is_running
(
self
)
->
bool
:
def
is_running
(
self
)
->
bool
:
...
...
vllm/v1/engine/core.py
View file @
03f48b3d
...
@@ -7,7 +7,7 @@ import time
...
@@ -7,7 +7,7 @@ import time
from
concurrent.futures
import
Future
from
concurrent.futures
import
Future
from
inspect
import
isclass
,
signature
from
inspect
import
isclass
,
signature
from
multiprocessing.connection
import
Connection
from
multiprocessing.connection
import
Connection
from
typing
import
Any
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
Any
,
List
,
Optional
,
Set
,
Tuple
,
Type
import
msgspec
import
msgspec
import
psutil
import
psutil
...
@@ -222,8 +222,17 @@ class EngineCore:
...
@@ -222,8 +222,17 @@ class EngineCore:
def
execute_dummy_batch
(
self
):
def
execute_dummy_batch
(
self
):
self
.
model_executor
.
collective_rpc
(
"execute_dummy_batch"
)
self
.
model_executor
.
collective_rpc
(
"execute_dummy_batch"
)
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
self
.
model_executor
.
add_lora
(
lora_request
)
return
self
.
model_executor
.
add_lora
(
lora_request
)
def remove_lora(self, lora_id: int) -> bool:
    """Forward a LoRA removal for ``lora_id`` to the model executor.

    Returns whatever ``model_executor.remove_lora`` returns.
    """
    outcome = self.model_executor.remove_lora(lora_id)
    return outcome
def list_loras(self) -> Set[int]:
    """Return the adapter ids reported by ``model_executor.list_loras``."""
    registered = self.model_executor.list_loras()
    return registered
def pin_lora(self, lora_id: int) -> bool:
    """Forward a pin request for ``lora_id`` to the model executor.

    Returns whatever ``model_executor.pin_lora`` returns.
    """
    outcome = self.model_executor.pin_lora(lora_id)
    return outcome
class
EngineCoreProc
(
EngineCore
):
class
EngineCoreProc
(
EngineCore
):
...
...
vllm/v1/engine/core_client.py
View file @
03f48b3d
...
@@ -10,7 +10,7 @@ from abc import ABC, abstractmethod
...
@@ -10,7 +10,7 @@ from abc import ABC, abstractmethod
from
concurrent.futures
import
Future
from
concurrent.futures
import
Future
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
threading
import
Thread
from
threading
import
Thread
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
,
Union
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Type
,
Union
import
zmq
import
zmq
import
zmq.asyncio
import
zmq.asyncio
...
@@ -97,7 +97,16 @@ class EngineCoreClient(ABC):
...
@@ -97,7 +97,16 @@ class EngineCoreClient(ABC):
def
abort_requests
(
self
,
request_ids
:
List
[
str
])
->
None
:
def
abort_requests
(
self
,
request_ids
:
List
[
str
])
->
None
:
raise
NotImplementedError
raise
NotImplementedError
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
raise
NotImplementedError
def remove_lora(self, lora_id: int) -> bool:
    """Interface stub for ``remove_lora``.

    Raises:
        NotImplementedError: always; this implementation has no behavior.
    """
    raise NotImplementedError
def list_loras(self) -> Set[int]:
    """Interface stub for ``list_loras``.

    Raises:
        NotImplementedError: always; this implementation has no behavior.
    """
    raise NotImplementedError
def pin_lora(self, lora_id: int) -> bool:
    """Interface stub for ``pin_lora``.

    Raises:
        NotImplementedError: always; this implementation has no behavior.
    """
    raise NotImplementedError
raise
NotImplementedError
async
def
get_output_async
(
self
)
->
EngineCoreOutputs
:
async
def
get_output_async
(
self
)
->
EngineCoreOutputs
:
...
@@ -121,7 +130,16 @@ class EngineCoreClient(ABC):
...
@@ -121,7 +130,16 @@ class EngineCoreClient(ABC):
async
def
abort_requests_async
(
self
,
request_ids
:
List
[
str
])
->
None
:
async
def
abort_requests_async
(
self
,
request_ids
:
List
[
str
])
->
None
:
raise
NotImplementedError
raise
NotImplementedError
async
def
add_lora_async
(
self
,
lora_request
:
LoRARequest
)
->
None
:
async
def
add_lora_async
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
raise
NotImplementedError
async def remove_lora_async(self, lora_id: int) -> bool:
    """Interface stub for ``remove_lora_async``.

    Raises:
        NotImplementedError: always; this implementation has no behavior.
    """
    raise NotImplementedError
async def list_loras_async(self) -> Set[int]:
    """Interface stub for ``list_loras_async``.

    Raises:
        NotImplementedError: always; this implementation has no behavior.
    """
    raise NotImplementedError
async def pin_lora_async(self, lora_id: int) -> bool:
    """Interface stub for ``pin_lora_async``.

    Raises:
        NotImplementedError: always; this implementation has no behavior.
    """
    raise NotImplementedError
raise
NotImplementedError
...
@@ -166,8 +184,17 @@ class InprocClient(EngineCoreClient):
...
@@ -166,8 +184,17 @@ class InprocClient(EngineCoreClient):
def
execute_dummy_batch
(
self
)
->
None
:
def
execute_dummy_batch
(
self
)
->
None
:
self
.
engine_core
.
execute_dummy_batch
()
self
.
engine_core
.
execute_dummy_batch
()
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
self
.
engine_core
.
add_lora
(
lora_request
)
return
self
.
engine_core
.
add_lora
(
lora_request
)
def remove_lora(self, lora_id: int) -> bool:
    """Call ``engine_core.remove_lora`` directly and return its result."""
    outcome = self.engine_core.remove_lora(lora_id)
    return outcome
def list_loras(self) -> Set[int]:
    """Call ``engine_core.list_loras`` directly and return its result."""
    registered = self.engine_core.list_loras()
    return registered
def pin_lora(self, lora_id: int) -> bool:
    """Call ``engine_core.pin_lora`` directly and return its result."""
    outcome = self.engine_core.pin_lora(lora_id)
    return outcome
@
dataclass
@
dataclass
...
@@ -356,8 +383,17 @@ class SyncMPClient(MPClient):
...
@@ -356,8 +383,17 @@ class SyncMPClient(MPClient):
def
reset_prefix_cache
(
self
)
->
None
:
def
reset_prefix_cache
(
self
)
->
None
:
self
.
_call_utility
(
"reset_prefix_cache"
)
self
.
_call_utility
(
"reset_prefix_cache"
)
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
self
.
_call_utility
(
"add_lora"
,
lora_request
)
return
self
.
_call_utility
(
"add_lora"
,
lora_request
)
def remove_lora(self, lora_id: int) -> bool:
    """Invoke the "remove_lora" utility with ``lora_id`` via ``_call_utility``.

    Returns the value produced by ``_call_utility``.
    """
    outcome = self._call_utility("remove_lora", lora_id)
    return outcome
def list_loras(self) -> Set[int]:
    """Invoke the "list_loras" utility via ``_call_utility``.

    Returns the value produced by ``_call_utility``.
    """
    registered = self._call_utility("list_loras")
    return registered
def pin_lora(self, lora_id: int) -> bool:
    """Invoke the "pin_lora" utility with ``lora_id`` via ``_call_utility``.

    Returns the value produced by ``_call_utility``.
    """
    outcome = self._call_utility("pin_lora", lora_id)
    return outcome
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
self
.
_call_utility
(
"sleep"
,
level
)
self
.
_call_utility
(
"sleep"
,
level
)
...
@@ -454,5 +490,14 @@ class AsyncMPClient(MPClient):
...
@@ -454,5 +490,14 @@ class AsyncMPClient(MPClient):
async
def
execute_dummy_batch_async
(
self
)
->
None
:
async
def
execute_dummy_batch_async
(
self
)
->
None
:
await
self
.
_call_utility_async
(
"execute_dummy_batch"
)
await
self
.
_call_utility_async
(
"execute_dummy_batch"
)
async
def
add_lora_async
(
self
,
lora_request
:
LoRARequest
)
->
None
:
async
def
add_lora_async
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
await
self
.
_call_utility_async
(
"add_lora"
,
lora_request
)
return
await
self
.
_call_utility_async
(
"add_lora"
,
lora_request
)
async def remove_lora_async(self, lora_id: int) -> bool:
    """Await the "remove_lora" utility call for ``lora_id``.

    Returns the value produced by ``_call_utility_async``.
    """
    outcome = await self._call_utility_async("remove_lora", lora_id)
    return outcome
async def list_loras_async(self) -> Set[int]:
    """Await the "list_loras" utility call.

    Returns the value produced by ``_call_utility_async``.
    """
    registered = await self._call_utility_async("list_loras")
    return registered
async def pin_lora_async(self, lora_id: int) -> bool:
    """Await the "pin_lora" utility call for ``lora_id``.

    Returns the value produced by ``_call_utility_async``.
    """
    outcome = await self._call_utility_async("pin_lora", lora_id)
    return outcome
vllm/v1/engine/llm_engine.py
View file @
03f48b3d
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Dict
,
List
,
Mapping
,
Optional
,
Type
,
Union
from
typing
import
Dict
,
List
,
Mapping
,
Optional
,
Set
,
Type
,
Union
from
typing_extensions
import
TypeVar
from
typing_extensions
import
TypeVar
...
@@ -254,3 +254,19 @@ class LLMEngine:
...
@@ -254,3 +254,19 @@ class LLMEngine:
f
"found type:
{
type
(
tokenizer_group
)
}
"
)
f
"found type:
{
type
(
tokenizer_group
)
}
"
)
return
tokenizer_group
return
tokenizer_group
def add_lora(self, lora_request: LoRARequest) -> bool:
    """Register a new LoRA adapter with the engine for later requests.

    Delegates to ``engine_core.add_lora`` and returns its result.
    """
    outcome = self.engine_core.add_lora(lora_request)
    return outcome
def remove_lora(self, lora_id: int) -> bool:
    """Unload the LoRA adapter identified by ``lora_id``.

    Delegates to ``engine_core.remove_lora`` and returns its result.
    """
    outcome = self.engine_core.remove_lora(lora_id)
    return outcome
def list_loras(self) -> Set[int]:
    """Return the ids of all registered adapters.

    Delegates to ``engine_core.list_loras`` and returns its result.
    """
    registered = self.engine_core.list_loras()
    return registered
def pin_lora(self, lora_id: int) -> bool:
    """Mark the adapter ``lora_id`` so that it is not evicted.

    Delegates to ``engine_core.pin_lora`` and returns its result.
    """
    outcome = self.engine_core.pin_lora(lora_id)
    return outcome
vllm/v1/worker/gpu_worker.py
View file @
03f48b3d
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
"""A GPU worker class."""
"""A GPU worker class."""
import
gc
import
gc
import
os
import
os
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
,
Set
import
torch
import
torch
import
torch.distributed
import
torch.distributed
...
@@ -240,6 +240,15 @@ class Worker(WorkerBase):
...
@@ -240,6 +240,15 @@ class Worker(WorkerBase):
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
return
self
.
model_runner
.
add_lora
(
lora_request
)
return
self
.
model_runner
.
add_lora
(
lora_request
)
def remove_lora(self, lora_id: int) -> bool:
    """Hand the removal of ``lora_id`` to the model runner.

    Returns whatever ``model_runner.remove_lora`` returns.
    """
    outcome = self.model_runner.remove_lora(lora_id)
    return outcome
def list_loras(self) -> Set[int]:
    """Return the adapter ids reported by ``model_runner.list_loras``."""
    registered = self.model_runner.list_loras()
    return registered
def pin_lora(self, lora_id: int) -> bool:
    """Hand the pin request for ``lora_id`` to the model runner.

    Returns whatever ``model_runner.pin_lora`` returns.
    """
    outcome = self.model_runner.pin_lora(lora_id)
    return outcome
def
check_health
(
self
)
->
None
:
def
check_health
(
self
)
->
None
:
# worker will always be healthy as long as it's running.
# worker will always be healthy as long as it's running.
return
return
...
...
vllm/v1/worker/lora_model_runner_mixin.py
View file @
03f48b3d
...
@@ -132,3 +132,18 @@ class LoRAModelRunnerMixin:
...
@@ -132,3 +132,18 @@ class LoRAModelRunnerMixin:
if
not
self
.
lora_manager
:
if
not
self
.
lora_manager
:
raise
RuntimeError
(
"LoRA is not enabled."
)
raise
RuntimeError
(
"LoRA is not enabled."
)
return
self
.
lora_manager
.
add_adapter
(
lora_request
)
return
self
.
lora_manager
.
add_adapter
(
lora_request
)
def remove_lora(self, lora_id: int) -> bool:
    """Remove adapter ``lora_id`` through the LoRA manager.

    Returns:
        The result of ``lora_manager.remove_adapter(lora_id)``.

    Raises:
        RuntimeError: if ``self.lora_manager`` is falsy (LoRA not enabled).
    """
    if self.lora_manager:
        return self.lora_manager.remove_adapter(lora_id)
    raise RuntimeError("LoRA is not enabled.")
def pin_lora(self, lora_id: int) -> bool:
    """Pin adapter ``lora_id`` through the LoRA manager.

    Returns:
        The result of ``lora_manager.pin_adapter(lora_id)``.

    Raises:
        RuntimeError: if ``self.lora_manager`` is falsy (LoRA not enabled).
    """
    if self.lora_manager:
        return self.lora_manager.pin_adapter(lora_id)
    raise RuntimeError("LoRA is not enabled.")
def list_loras(self) -> Set[int]:
    """List adapters through the LoRA manager.

    Returns:
        The result of ``lora_manager.list_adapters()``.

    Raises:
        RuntimeError: if ``self.lora_manager`` is falsy (LoRA not enabled).
    """
    if self.lora_manager:
        return self.lora_manager.list_adapters()
    raise RuntimeError("LoRA is not enabled.")
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment