Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f5dda63e
Unverified
Commit
f5dda63e
authored
Jun 21, 2024
by
rohithkrn
Committed by
GitHub
Jun 21, 2024
Browse files
[LoRA] Add support for pinning lora adapters in the LRU cache (#5603)
parent
71875073
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
171 additions
and
5 deletions
+171
-5
tests/lora/test_lora_manager.py
tests/lora/test_lora_manager.py
+64
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+3
-0
vllm/executor/cpu_executor.py
vllm/executor/cpu_executor.py
+3
-0
vllm/executor/distributed_gpu_executor.py
vllm/executor/distributed_gpu_executor.py
+7
-0
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+4
-0
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+4
-0
vllm/executor/neuron_executor.py
vllm/executor/neuron_executor.py
+3
-0
vllm/lora/models.py
vllm/lora/models.py
+26
-0
vllm/lora/worker_manager.py
vllm/lora/worker_manager.py
+3
-0
vllm/utils.py
vllm/utils.py
+38
-5
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+5
-0
vllm/worker/worker.py
vllm/worker/worker.py
+3
-0
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+8
-0
No files found.
tests/lora/test_lora_manager.py
View file @
f5dda63e
...
@@ -209,6 +209,34 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model):
...
@@ -209,6 +209,34 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model):
assert
manager
.
activate_lora
(
3
)
assert
manager
.
activate_lora
(
3
)
assert
manager
.
lora_index_to_id
[
0
]
==
2
assert
manager
.
lora_index_to_id
[
0
]
==
2
assert
manager
.
lora_index_to_id
[
1
]
==
3
assert
manager
.
lora_index_to_id
[
1
]
==
3
assert
manager
.
pin_lora
(
2
)
assert
manager
.
lora_index_to_id
[
0
]
==
2
assert
manager
.
lora_index_to_id
[
1
]
==
3
assert
manager
.
activate_lora
(
1
)
assert
manager
.
lora_index_to_id
[
0
]
==
2
assert
manager
.
lora_index_to_id
[
1
]
==
1
assert
manager
.
deactivate_lora
(
2
)
assert
manager
.
lora_index_to_id
[
0
]
is
None
assert
manager
.
lora_index_to_id
[
1
]
==
1
assert
manager
.
activate_lora
(
3
)
assert
manager
.
lora_index_to_id
[
0
]
==
3
assert
manager
.
lora_index_to_id
[
1
]
==
1
assert
manager
.
pin_lora
(
3
)
assert
manager
.
pin_lora
(
1
)
with
pytest
.
raises
(
RuntimeError
):
assert
manager
.
pin_lora
(
2
)
assert
manager
.
lora_index_to_id
[
0
]
==
3
assert
manager
.
lora_index_to_id
[
1
]
==
1
with
pytest
.
raises
(
RuntimeError
):
assert
manager
.
activate_lora
(
2
)
assert
manager
.
deactivate_lora
(
3
)
assert
manager
.
pin_lora
(
2
)
assert
manager
.
lora_index_to_id
[
0
]
==
2
assert
manager
.
lora_index_to_id
[
1
]
==
1
assert
manager
.
remove_lora
(
3
)
with
pytest
.
raises
(
ValueError
):
assert
manager
.
pin_lora
(
3
)
def
test_lru_lora_model_manager
(
dist_init
,
dummy_model
):
def
test_lru_lora_model_manager
(
dist_init
,
dummy_model
):
...
@@ -288,6 +316,42 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
...
@@ -288,6 +316,42 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
assert
set
(
manager
.
list_loras
())
==
set
()
assert
set
(
manager
.
list_loras
())
==
set
()
assert
all
(
x
is
None
for
x
in
manager
.
lora_index_to_id
)
assert
all
(
x
is
None
for
x
in
manager
.
lora_index_to_id
)
# pinning
assert
manager
.
add_lora
(
model_lora3
)
assert
manager
.
activate_lora
(
3
)
assert
manager
.
add_lora
(
model_lora4
)
assert
manager
.
activate_lora
(
4
)
assert
set
(
manager
.
list_loras
())
==
{
3
,
4
}
with
pytest
.
raises
(
ValueError
):
assert
manager
.
pin_lora
(
1
)
assert
manager
.
pin_lora
(
3
)
# Remove manually
assert
manager
.
remove_lora
(
3
)
assert
not
manager
.
remove_lora
(
3
)
assert
set
(
manager
.
list_loras
())
==
{
4
}
assert
manager
.
lora_index_to_id
[
0
]
is
None
assert
manager
.
lora_index_to_id
[
1
]
==
4
assert
manager
.
add_lora
(
model_lora1
)
assert
manager
.
pin_lora
(
1
)
assert
manager
.
add_lora
(
model_lora2
)
assert
manager
.
activate_lora
(
2
)
assert
set
(
manager
.
list_loras
())
==
{
1
,
2
}
assert
manager
.
lora_index_to_id
[
0
]
==
1
assert
manager
.
lora_index_to_id
[
1
]
==
2
assert
manager
.
remove_oldest_lora
()
assert
set
(
manager
.
list_loras
())
==
{
1
}
assert
manager
.
lora_index_to_id
[
0
]
==
1
assert
manager
.
lora_index_to_id
[
1
]
is
None
with
pytest
.
raises
(
RuntimeError
):
assert
manager
.
remove_oldest_lora
()
assert
set
(
manager
.
list_loras
())
==
{
1
}
def
test_lru_cache_worker_lora_manager
(
llama_2_7b_model_extra_embeddings
,
def
test_lru_cache_worker_lora_manager
(
llama_2_7b_model_extra_embeddings
,
sql_lora_files
):
sql_lora_files
):
...
...
vllm/engine/llm_engine.py
View file @
f5dda63e
...
@@ -1009,6 +1009,9 @@ class LLMEngine:
...
@@ -1009,6 +1009,9 @@ class LLMEngine:
def
list_loras
(
self
)
->
Set
[
int
]:
def
list_loras
(
self
)
->
Set
[
int
]:
return
self
.
model_executor
.
list_loras
()
return
self
.
model_executor
.
list_loras
()
def pin_lora(self, lora_id: int) -> bool:
    """Pin the LoRA adapter identified by ``lora_id``.

    The request is forwarded unchanged to ``self.model_executor.pin_lora``
    and that call's result is returned.
    """
    executor = self.model_executor
    return executor.pin_lora(lora_id)
def
check_health
(
self
)
->
None
:
def
check_health
(
self
)
->
None
:
self
.
model_executor
.
check_health
()
self
.
model_executor
.
check_health
()
...
...
vllm/executor/cpu_executor.py
View file @
f5dda63e
...
@@ -84,6 +84,9 @@ class CPUExecutor(ExecutorBase):
...
@@ -84,6 +84,9 @@ class CPUExecutor(ExecutorBase):
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
driver_worker
.
remove_lora
(
lora_id
)
return
self
.
driver_worker
.
remove_lora
(
lora_id
)
def pin_lora(self, lora_id: int) -> bool:
    """Pin the LoRA adapter ``lora_id`` via the driver worker.

    Delegates to ``self.driver_worker.pin_lora`` and returns its result.
    """
    worker = self.driver_worker
    return worker.pin_lora(lora_id)
def
list_loras
(
self
)
->
Set
[
int
]:
def
list_loras
(
self
)
->
Set
[
int
]:
return
self
.
driver_worker
.
list_loras
()
return
self
.
driver_worker
.
list_loras
()
...
...
vllm/executor/distributed_gpu_executor.py
View file @
f5dda63e
...
@@ -100,6 +100,13 @@ class DistributedGPUExecutor(GPUExecutor):
...
@@ -100,6 +100,13 @@ class DistributedGPUExecutor(GPUExecutor):
lora_id
=
lora_id
,
lora_id
=
lora_id
,
)
)
def pin_lora(self, lora_id: int) -> bool:
    """Ask the workers to pin the LoRA adapter ``lora_id``.

    Calls ``self._run_workers("pin_lora", lora_id=lora_id)`` and returns
    its result unchanged.
    NOTE(review): the annotation says ``bool``; what ``_run_workers``
    actually returns is not visible here - confirm.

    Raises:
        AssertionError: if ``lora_id`` is not greater than 0.
    """
    assert lora_id > 0, "lora_id must be greater than 0."
    worker_kwargs = {"lora_id": lora_id}
    return self._run_workers("pin_lora", **worker_kwargs)
def
list_loras
(
self
)
->
Set
[
int
]:
def
list_loras
(
self
)
->
Set
[
int
]:
return
self
.
_run_workers
(
"list_loras"
)
return
self
.
_run_workers
(
"list_loras"
)
...
...
vllm/executor/executor_base.py
View file @
f5dda63e
...
@@ -86,6 +86,10 @@ class ExecutorBase(ABC):
...
@@ -86,6 +86,10 @@ class ExecutorBase(ABC):
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
raise
NotImplementedError
raise
NotImplementedError
@abstractmethod
def pin_lora(self, lora_id: int) -> bool:
    """Pin the LoRA adapter ``lora_id``.

    Abstract hook: this base implementation always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError  # type: ignore
@
abstractmethod
@
abstractmethod
def
list_loras
(
self
)
->
Set
[
int
]:
def
list_loras
(
self
)
->
Set
[
int
]:
raise
NotImplementedError
raise
NotImplementedError
...
...
vllm/executor/gpu_executor.py
View file @
f5dda63e
...
@@ -99,6 +99,10 @@ class GPUExecutor(ExecutorBase):
...
@@ -99,6 +99,10 @@ class GPUExecutor(ExecutorBase):
assert
lora_id
>
0
,
"lora_id must be greater than 0."
assert
lora_id
>
0
,
"lora_id must be greater than 0."
return
self
.
driver_worker
.
remove_lora
(
lora_id
)
return
self
.
driver_worker
.
remove_lora
(
lora_id
)
def pin_lora(self, lora_id: int) -> bool:
    """Pin the LoRA adapter ``lora_id`` on the driver worker.

    Returns whatever ``self.driver_worker.pin_lora`` returns.

    Raises:
        AssertionError: if ``lora_id`` is not greater than 0.
    """
    assert lora_id > 0, "lora_id must be greater than 0."
    worker = self.driver_worker
    return worker.pin_lora(lora_id)
def
list_loras
(
self
)
->
Set
[
int
]:
def
list_loras
(
self
)
->
Set
[
int
]:
return
self
.
driver_worker
.
list_loras
()
return
self
.
driver_worker
.
list_loras
()
...
...
vllm/executor/neuron_executor.py
View file @
f5dda63e
...
@@ -65,6 +65,9 @@ class NeuronExecutor(ExecutorBase):
...
@@ -65,6 +65,9 @@ class NeuronExecutor(ExecutorBase):
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
driver_worker
.
remove_lora
(
lora_id
)
return
self
.
driver_worker
.
remove_lora
(
lora_id
)
def pin_lora(self, lora_id: int) -> bool:
    """Pin the LoRA adapter ``lora_id`` via the driver worker.

    Delegates to ``self.driver_worker.pin_lora`` and returns its result.
    """
    worker = self.driver_worker
    return worker.pin_lora(lora_id)
def
list_loras
(
self
)
->
Set
[
int
]:
def
list_loras
(
self
)
->
Set
[
int
]:
return
self
.
driver_worker
.
list_loras
()
return
self
.
driver_worker
.
list_loras
()
...
...
vllm/lora/models.py
View file @
f5dda63e
...
@@ -525,6 +525,12 @@ class LoRAModelManager:
...
@@ -525,6 +525,12 @@ class LoRAModelManager:
self
.
long_lora_context
.
offsets_by_lora_id
.
pop
(
lora_id
,
None
)
self
.
long_lora_context
.
offsets_by_lora_id
.
pop
(
lora_id
,
None
)
return
bool
(
self
.
_registered_loras
.
pop
(
lora_id
,
None
))
return
bool
(
self
.
_registered_loras
.
pop
(
lora_id
,
None
))
def pin_lora(self, lora_id: int) -> bool:
    """Pin a LoRAModel in the manager cache.

    This manager does not implement pinning, so the call always fails.

    Raises:
        NotImplementedError: always; the message points the caller to
            LRUCacheLoRAModelManager.
    """
    # The two literals are concatenated at compile time; the trailing space
    # on the first one keeps the sentences from running together.
    raise NotImplementedError(
        "Pinning is not supported in LoRAModelManager. "
        "Use LRUCacheLoRAModelManager for pinning")  # type: ignore
# TODO see if this can be vectorized
# TODO see if this can be vectorized
def
_set_lora_mapping
(
self
,
mapping
:
LoRAMapping
)
->
None
:
def
_set_lora_mapping
(
self
,
mapping
:
LoRAMapping
)
->
None
:
(
base_indices
,
sampler_indices
,
sampler_indices_padded
,
(
base_indices
,
sampler_indices
,
sampler_indices_padded
,
...
@@ -777,6 +783,26 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
...
@@ -777,6 +783,26 @@ class LRUCacheLoRAModelManager(LoRAModelManager):
return
True
return
True
return
False
return
False
def pin_lora(self, lora_id: int) -> bool:
    """Pin a LoRAModel in the manager cache.

    Runs the CPU-cache pin first and the GPU-cache pin second; an
    exception from either step propagates to the caller.

    Returns:
        True once both steps have completed.
    """
    pin_steps = (self._pin_lora_in_cpu_cache, self._pin_lora_in_gpu_cache)
    for pin_step in pin_steps:
        pin_step(lora_id)
    return True
def _pin_lora_in_cpu_cache(self, lora_id: int):
    """Pin ``lora_id`` in ``self._registered_loras``.

    Raises:
        ValueError: if the underlying ``pin`` call raises ``ValueError``;
            it is re-raised with a LoRA-specific message and the original
            error chained as the cause.
    """
    try:
        self._registered_loras.pin(lora_id)
    except ValueError as err:
        message = f"Pinning failed. LoRA {lora_id} is not registered."
        raise ValueError(message) from err
def _pin_lora_in_gpu_cache(self, lora_id: int):
    """Pin ``lora_id`` in ``self._active_loras``.

    If the adapter is not already in ``self._active_loras`` it is first
    activated through ``self.activate_lora``.
    """
    already_active = lora_id in self._active_loras
    if not already_active:
        # move lora to gpu if not already active
        self.activate_lora(lora_id)
    self._active_loras.pin(lora_id)
def
create_lora_manager
(
def
create_lora_manager
(
model
:
nn
.
Module
,
model
:
nn
.
Module
,
...
...
vllm/lora/worker_manager.py
View file @
f5dda63e
...
@@ -221,6 +221,9 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
...
@@ -221,6 +221,9 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
_lora_manager
.
remove_lora
(
lora_id
)
return
self
.
_lora_manager
.
remove_lora
(
lora_id
)
def pin_lora(self, lora_id: int) -> bool:
    """Pin the LoRA adapter ``lora_id``.

    Delegates to ``self._lora_manager.pin_lora`` and returns its result.
    """
    manager = self._lora_manager
    return manager.pin_lora(lora_id)
def
remove_all_loras
(
self
):
def
remove_all_loras
(
self
):
self
.
_lora_manager
.
remove_all_loras
()
self
.
_lora_manager
.
remove_all_loras
()
...
...
vllm/utils.py
View file @
f5dda63e
...
@@ -15,7 +15,7 @@ from collections import defaultdict
...
@@ -15,7 +15,7 @@ from collections import defaultdict
from
functools
import
lru_cache
,
partial
,
wraps
from
functools
import
lru_cache
,
partial
,
wraps
from
platform
import
uname
from
platform
import
uname
from
typing
import
(
Any
,
AsyncIterator
,
Awaitable
,
Callable
,
Dict
,
Generic
,
from
typing
import
(
Any
,
AsyncIterator
,
Awaitable
,
Callable
,
Dict
,
Generic
,
Hashable
,
List
,
Optional
,
OrderedDict
,
Tuple
,
TypeVar
,
Hashable
,
List
,
Optional
,
OrderedDict
,
Set
,
Tuple
,
TypeVar
,
Union
)
Union
)
import
numpy
as
np
import
numpy
as
np
...
@@ -44,6 +44,13 @@ K = TypeVar("K")
...
@@ -44,6 +44,13 @@ K = TypeVar("K")
T
=
TypeVar
(
"T"
)
T
=
TypeVar
(
"T"
)
class _Sentinel:
    """Marker type whose single module-level instance is compared by identity."""


# Unique marker object; LRUCache.remove_oldest passes it to next() as the
# fallback value for the case where no unpinned key is found.
ALL_PINNED_SENTINEL = _Sentinel()
class
Device
(
enum
.
Enum
):
class
Device
(
enum
.
Enum
):
GPU
=
enum
.
auto
()
GPU
=
enum
.
auto
()
CPU
=
enum
.
auto
()
CPU
=
enum
.
auto
()
...
@@ -67,6 +74,7 @@ class LRUCache(Generic[T]):
...
@@ -67,6 +74,7 @@ class LRUCache(Generic[T]):
def
__init__
(
self
,
capacity
:
int
):
def
__init__
(
self
,
capacity
:
int
):
self
.
cache
:
OrderedDict
[
Hashable
,
T
]
=
OrderedDict
()
self
.
cache
:
OrderedDict
[
Hashable
,
T
]
=
OrderedDict
()
self
.
pinned_items
:
Set
[
Hashable
]
=
set
()
self
.
capacity
=
capacity
self
.
capacity
=
capacity
def
__contains__
(
self
,
key
:
Hashable
)
->
bool
:
def
__contains__
(
self
,
key
:
Hashable
)
->
bool
:
...
@@ -102,14 +110,36 @@ class LRUCache(Generic[T]):
...
@@ -102,14 +110,36 @@ class LRUCache(Generic[T]):
self
.
cache
.
move_to_end
(
key
)
self
.
cache
.
move_to_end
(
key
)
self
.
_remove_old_if_needed
()
self
.
_remove_old_if_needed
()
def pin(self, key: Hashable) -> None:
    """
    Mark ``key`` as pinned so that LRU-order eviction skips it.

    Raises:
        ValueError: if ``key`` is not currently in the cache.
    """
    if key in self.cache:
        self.pinned_items.add(key)
        return
    raise ValueError(f"Cannot pin key: {key} not in cache.")
def _unpin(self, key: Hashable) -> None:
    """Remove ``key`` from ``self.pinned_items``.

    Uses ``remove`` (not ``discard``), so a key that is not pinned raises
    ``KeyError`` when ``pinned_items`` is a plain ``set``.
    """
    pinned = self.pinned_items
    pinned.remove(key)
def
_on_remove
(
self
,
key
:
Hashable
,
value
:
Optional
[
T
]):
def
_on_remove
(
self
,
key
:
Hashable
,
value
:
Optional
[
T
]):
pass
pass
def
remove_oldest
(
self
):
def
remove_oldest
(
self
,
remove_pinned
=
False
):
if
not
self
.
cache
:
if
not
self
.
cache
:
return
return
key
,
value
=
self
.
cache
.
popitem
(
last
=
False
)
self
.
_on_remove
(
key
,
value
)
if
not
remove_pinned
:
# pop the oldest item in the cache that is not pinned
lru_key
=
next
(
(
key
for
key
in
self
.
cache
if
key
not
in
self
.
pinned_items
),
ALL_PINNED_SENTINEL
)
if
lru_key
is
ALL_PINNED_SENTINEL
:
raise
RuntimeError
(
"All items are pinned, "
"cannot remove oldest from the cache."
)
else
:
lru_key
=
next
(
iter
(
self
.
cache
))
self
.
pop
(
lru_key
)
def
_remove_old_if_needed
(
self
)
->
None
:
def
_remove_old_if_needed
(
self
)
->
None
:
while
len
(
self
.
cache
)
>
self
.
capacity
:
while
len
(
self
.
cache
)
>
self
.
capacity
:
...
@@ -120,13 +150,16 @@ class LRUCache(Generic[T]):
...
@@ -120,13 +150,16 @@ class LRUCache(Generic[T]):
default_value
:
Optional
[
T
]
=
None
)
->
Optional
[
T
]:
default_value
:
Optional
[
T
]
=
None
)
->
Optional
[
T
]:
run_on_remove
=
key
in
self
.
cache
run_on_remove
=
key
in
self
.
cache
value
:
Optional
[
T
]
=
self
.
cache
.
pop
(
key
,
default_value
)
value
:
Optional
[
T
]
=
self
.
cache
.
pop
(
key
,
default_value
)
# remove from pinned items
if
key
in
self
.
pinned_items
:
self
.
_unpin
(
key
)
if
run_on_remove
:
if
run_on_remove
:
self
.
_on_remove
(
key
,
value
)
self
.
_on_remove
(
key
,
value
)
return
value
return
value
def
clear
(
self
):
def
clear
(
self
):
while
len
(
self
.
cache
)
>
0
:
while
len
(
self
.
cache
)
>
0
:
self
.
remove_oldest
()
self
.
remove_oldest
(
remove_pinned
=
True
)
self
.
cache
.
clear
()
self
.
cache
.
clear
()
...
...
vllm/worker/model_runner.py
View file @
f5dda63e
...
@@ -878,6 +878,11 @@ class ModelRunner:
...
@@ -878,6 +878,11 @@ class ModelRunner:
raise
RuntimeError
(
"LoRA is not enabled."
)
raise
RuntimeError
(
"LoRA is not enabled."
)
return
self
.
lora_manager
.
remove_lora
(
lora_id
)
return
self
.
lora_manager
.
remove_lora
(
lora_id
)
def pin_lora(self, lora_id: int) -> bool:
    """Pin the LoRA adapter ``lora_id`` through the LoRA manager.

    Returns:
        The result of ``self.lora_manager.pin_lora(lora_id)``.

    Raises:
        RuntimeError: if ``self.lora_manager`` is falsy (LoRA disabled).
    """
    if self.lora_manager:
        return self.lora_manager.pin_lora(lora_id)
    raise RuntimeError("LoRA is not enabled.")
def
list_loras
(
self
)
->
Set
[
int
]:
def
list_loras
(
self
)
->
Set
[
int
]:
if
not
self
.
lora_manager
:
if
not
self
.
lora_manager
:
raise
RuntimeError
(
"LoRA is not enabled."
)
raise
RuntimeError
(
"LoRA is not enabled."
)
...
...
vllm/worker/worker.py
View file @
f5dda63e
...
@@ -333,6 +333,9 @@ class Worker(WorkerBase):
...
@@ -333,6 +333,9 @@ class Worker(WorkerBase):
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
model_runner
.
remove_lora
(
lora_id
)
return
self
.
model_runner
.
remove_lora
(
lora_id
)
def pin_lora(self, lora_id: int) -> bool:
    """Pin the LoRA adapter ``lora_id``.

    Delegates to ``self.model_runner.pin_lora`` and returns its result.
    """
    runner = self.model_runner
    return runner.pin_lora(lora_id)
def
list_loras
(
self
)
->
Set
[
int
]:
def
list_loras
(
self
)
->
Set
[
int
]:
return
self
.
model_runner
.
list_loras
()
return
self
.
model_runner
.
list_loras
()
...
...
vllm/worker/worker_base.py
View file @
f5dda63e
...
@@ -70,6 +70,10 @@ class WorkerBase(ABC):
...
@@ -70,6 +70,10 @@ class WorkerBase(ABC):
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
raise
NotImplementedError
raise
NotImplementedError
@abstractmethod
def pin_lora(self, lora_id: int) -> bool:
    """Pin the LoRA adapter ``lora_id``.

    Abstract hook: this base implementation always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError
@
abstractmethod
@
abstractmethod
def
list_loras
(
self
)
->
Set
[
int
]:
def
list_loras
(
self
)
->
Set
[
int
]:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -86,6 +90,10 @@ class LoraNotSupportedWorkerBase(WorkerBase):
...
@@ -86,6 +90,10 @@ class LoraNotSupportedWorkerBase(WorkerBase):
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
raise
ValueError
(
f
"
{
type
(
self
)
}
does not support LoRA"
)
raise
ValueError
(
f
"
{
type
(
self
)
}
does not support LoRA"
)
def pin_lora(self, lora_id: int) -> bool:
    """Reject the request: this worker type does not support LoRA.

    Raises:
        ValueError: always, naming the concrete worker type.
    """
    # Raise, do not return, the error: a returned exception object is truthy
    # and would read as a successful pin. The sibling remove_lora and
    # list_loras methods raise the same message.
    raise ValueError(f"{type(self)} does not support LoRA")
def
list_loras
(
self
)
->
Set
[
int
]:
def
list_loras
(
self
)
->
Set
[
int
]:
raise
ValueError
(
f
"
{
type
(
self
)
}
does not support LoRA"
)
raise
ValueError
(
f
"
{
type
(
self
)
}
does not support LoRA"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment