Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7f5edb59
Unverified
Commit
7f5edb59
authored
Nov 12, 2024
by
Jee Jee Li
Committed by
GitHub
Nov 12, 2024
Browse files
[Misc][LoRA] Replace hardcoded cuda device with configurable argument (#10223)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
eea55cca
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
174 additions
and
80 deletions
+174
-80
tests/lora/test_layers.py
tests/lora/test_layers.py
+36
-20
tests/lora/test_lora_manager.py
tests/lora/test_lora_manager.py
+113
-40
tests/lora/utils.py
tests/lora/utils.py
+5
-4
vllm/lora/models.py
vllm/lora/models.py
+9
-10
vllm/lora/punica.py
vllm/lora/punica.py
+9
-6
vllm/lora/worker_manager.py
vllm/lora/worker_manager.py
+2
-0
No files found.
tests/lora/test_layers.py
View file @
7f5edb59
...
@@ -51,6 +51,7 @@ TOLERANCES = {
...
@@ -51,6 +51,7 @@ TOLERANCES = {
CUDA_DEVICES
=
[
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
]
# We will launch different triton kernels between the prefill and decode
# We will launch different triton kernels between the prefill and decode
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
STAGES
=
[
True
,
False
]
STAGES
=
[
True
,
False
]
...
@@ -120,7 +121,8 @@ def populate_loras(
...
@@ -120,7 +121,8 @@ def populate_loras(
subloras
:
List
[
LoRALayerWeights
]
=
[]
subloras
:
List
[
LoRALayerWeights
]
=
[]
sublora_len
=
layer_weights
.
shape
[
0
]
//
repeats
sublora_len
=
layer_weights
.
shape
[
0
]
//
repeats
for
i
in
range
(
repeats
):
for
i
in
range
(
repeats
):
sublora
=
DummyLoRAManager
().
init_random_lora
(
sublora
=
DummyLoRAManager
(
layer_weights
.
device
).
init_random_lora
(
module_name
=
f
"fake_
{
i
}
"
,
module_name
=
f
"fake_
{
i
}
"
,
weight
=
layer_weights
,
weight
=
layer_weights
,
generate_embeddings_tensor
=
generate_embeddings_tensor
,
generate_embeddings_tensor
=
generate_embeddings_tensor
,
...
@@ -152,6 +154,7 @@ def create_random_inputs(
...
@@ -152,6 +154,7 @@ def create_random_inputs(
input_size
:
Tuple
[
int
,
...],
input_size
:
Tuple
[
int
,
...],
input_range
:
Tuple
[
float
,
float
],
input_range
:
Tuple
[
float
,
float
],
input_type
:
torch
.
dtype
=
torch
.
int
,
input_type
:
torch
.
dtype
=
torch
.
int
,
device
:
torch
.
device
=
"cuda"
)
->
Tuple
[
List
[
torch
.
Tensor
],
List
[
int
],
List
[
int
]]:
)
->
Tuple
[
List
[
torch
.
Tensor
],
List
[
int
],
List
[
int
]]:
"""Creates random inputs.
"""Creates random inputs.
...
@@ -173,10 +176,14 @@ def create_random_inputs(
...
@@ -173,10 +176,14 @@ def create_random_inputs(
for
_
in
range
(
num_inputs
):
for
_
in
range
(
num_inputs
):
if
input_type
==
torch
.
int
:
if
input_type
==
torch
.
int
:
inputs
.
append
(
inputs
.
append
(
torch
.
randint
(
low
=
int
(
low
),
high
=
int
(
high
),
size
=
input_size
))
torch
.
randint
(
low
=
int
(
low
),
high
=
int
(
high
),
size
=
input_size
,
device
=
device
))
else
:
else
:
inputs
.
append
(
inputs
.
append
(
torch
.
rand
(
size
=
input_size
,
dtype
=
input_type
)
*
high
+
low
)
torch
.
rand
(
size
=
input_size
,
dtype
=
input_type
,
device
=
device
)
*
high
+
low
)
lora_id
=
random
.
choice
(
active_lora_ids
)
lora_id
=
random
.
choice
(
active_lora_ids
)
index_mapping
+=
[
lora_id
]
*
input_size
[
0
]
index_mapping
+=
[
lora_id
]
*
input_size
[
0
]
...
@@ -191,6 +198,10 @@ def create_random_inputs(
...
@@ -191,6 +198,10 @@ def create_random_inputs(
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
128000
])
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
512
,
32000
,
64000
,
128000
])
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
def
test_embeddings
(
dist_init
,
num_loras
,
device
,
vocab_size
,
stage
)
->
None
:
def
test_embeddings
(
dist_init
,
num_loras
,
device
,
vocab_size
,
stage
)
->
None
:
# For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
# device, see: https://github.com/triton-lang/triton/issues/2925
# Same below.
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
max_loras
=
8
max_loras
=
8
...
@@ -225,7 +236,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
...
@@ -225,7 +236,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
num_inputs
=
num_loras
*
3
,
num_inputs
=
num_loras
*
3
,
input_size
=
(
200
,
),
input_size
=
(
200
,
),
input_range
=
(
1
,
vocab_size
),
input_range
=
(
1
,
vocab_size
),
)
device
=
device
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
,
prompt_mapping
,
is_prefill
=
stage
)
is_prefill
=
stage
)
...
@@ -263,7 +274,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
...
@@ -263,7 +274,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
num_inputs
=
num_loras
*
3
,
num_inputs
=
num_loras
*
3
,
input_size
=
(
200
,
),
input_size
=
(
200
,
),
input_range
=
(
1
,
vocab_size
),
input_range
=
(
1
,
vocab_size
),
)
device
=
device
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
,
prompt_mapping
,
is_prefill
=
stage
)
is_prefill
=
stage
)
...
@@ -291,6 +302,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
...
@@ -291,6 +302,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
def
test_embeddings_with_new_embeddings
(
dist_init
,
num_loras
,
device
,
def
test_embeddings_with_new_embeddings
(
dist_init
,
num_loras
,
device
,
vocab_size
,
stage
)
->
None
:
vocab_size
,
stage
)
->
None
:
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
max_loras
=
8
max_loras
=
8
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
...
@@ -345,7 +357,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
...
@@ -345,7 +357,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
num_inputs
=
num_loras
*
3
,
num_inputs
=
num_loras
*
3
,
input_size
=
(
200
,
),
input_size
=
(
200
,
),
input_range
=
(
1
,
vocab_size
),
input_range
=
(
1
,
vocab_size
),
)
device
=
device
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
,
prompt_mapping
,
is_prefill
=
stage
)
is_prefill
=
stage
)
...
@@ -400,7 +412,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
...
@@ -400,7 +412,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
num_inputs
=
num_loras
*
3
,
num_inputs
=
num_loras
*
3
,
input_size
=
(
200
,
),
input_size
=
(
200
,
),
input_range
=
(
1
,
vocab_size
),
input_range
=
(
1
,
vocab_size
),
)
device
=
device
)
original_inputs
=
deepcopy
(
inputs
)
original_inputs
=
deepcopy
(
inputs
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
,
prompt_mapping
,
...
@@ -426,6 +438,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
...
@@ -426,6 +438,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
def
test_lm_head_logits_processor
(
dist_init
,
num_loras
,
device
,
vocab_size
,
def
test_lm_head_logits_processor
(
dist_init
,
num_loras
,
device
,
vocab_size
,
stage
)
->
None
:
stage
)
->
None
:
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
max_loras
=
8
max_loras
=
8
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
...
@@ -471,7 +484,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
...
@@ -471,7 +484,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
input_size
=
(
1
,
1024
),
input_size
=
(
1
,
1024
),
input_range
=
(
0
,
1
),
input_range
=
(
0
,
1
),
input_type
=
torch
.
float16
,
input_type
=
torch
.
float16
,
)
device
=
device
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
,
prompt_mapping
,
is_prefill
=
stage
)
is_prefill
=
stage
)
...
@@ -520,7 +533,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
...
@@ -520,7 +533,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
input_size
=
(
1
,
1024
),
input_size
=
(
1
,
1024
),
input_range
=
(
0
,
1
),
input_range
=
(
0
,
1
),
input_type
=
torch
.
float16
,
input_type
=
torch
.
float16
,
)
device
=
device
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
,
prompt_mapping
,
is_prefill
=
stage
)
is_prefill
=
stage
)
...
@@ -554,6 +567,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
...
@@ -554,6 +567,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
@
pytest
.
mark
.
parametrize
(
"stage"
,
STAGES
)
def
test_linear_replicated
(
dist_init
,
num_loras
,
device
,
stage
)
->
None
:
def
test_linear_replicated
(
dist_init
,
num_loras
,
device
,
stage
)
->
None
:
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
max_loras
=
8
max_loras
=
8
...
@@ -592,7 +606,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
...
@@ -592,7 +606,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
input_size
=
(
1
,
4096
),
input_size
=
(
1
,
4096
),
input_range
=
(
0
,
1
),
input_range
=
(
0
,
1
),
input_type
=
torch
.
float16
,
input_type
=
torch
.
float16
,
)
device
=
device
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
,
prompt_mapping
,
is_prefill
=
stage
)
is_prefill
=
stage
)
...
@@ -631,7 +645,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
...
@@ -631,7 +645,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
input_size
=
(
1
,
4096
),
input_size
=
(
1
,
4096
),
input_range
=
(
0
,
1
),
input_range
=
(
0
,
1
),
input_type
=
torch
.
float16
,
input_type
=
torch
.
float16
,
)
device
=
device
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
,
prompt_mapping
,
is_prefill
=
stage
)
is_prefill
=
stage
)
...
@@ -658,6 +672,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
...
@@ -658,6 +672,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
def
test_linear_parallel
(
dist_init
,
num_loras
,
orientation
,
fully_shard
,
def
test_linear_parallel
(
dist_init
,
num_loras
,
orientation
,
fully_shard
,
device
,
stage
)
->
None
:
device
,
stage
)
->
None
:
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
max_loras
=
8
max_loras
=
8
...
@@ -706,7 +721,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
...
@@ -706,7 +721,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
input_size
=
(
1
,
4096
),
input_size
=
(
1
,
4096
),
input_range
=
(
0
,
1
),
input_range
=
(
0
,
1
),
input_type
=
torch
.
float16
,
input_type
=
torch
.
float16
,
)
device
=
device
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
,
prompt_mapping
,
is_prefill
=
stage
)
is_prefill
=
stage
)
...
@@ -745,7 +760,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
...
@@ -745,7 +760,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
input_size
=
(
1
,
4096
),
input_size
=
(
1
,
4096
),
input_range
=
(
0
,
1
),
input_range
=
(
0
,
1
),
input_type
=
torch
.
float16
,
input_type
=
torch
.
float16
,
)
device
=
device
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
,
prompt_mapping
,
is_prefill
=
stage
)
is_prefill
=
stage
)
...
@@ -772,6 +787,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
...
@@ -772,6 +787,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
def
test_column_parallel_packed
(
dist_init
,
num_loras
,
repeats
,
fully_shard
,
def
test_column_parallel_packed
(
dist_init
,
num_loras
,
repeats
,
fully_shard
,
device
,
stage
)
->
None
:
device
,
stage
)
->
None
:
torch
.
cuda
.
set_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
punica_wrapper
=
PunicaWrapper
(
8192
,
256
,
device
)
max_loras
=
8
max_loras
=
8
...
@@ -842,7 +858,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
...
@@ -842,7 +858,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
input_size
=
(
1
,
4096
),
input_size
=
(
1
,
4096
),
input_range
=
(
0
,
1
),
input_range
=
(
0
,
1
),
input_type
=
torch
.
float16
,
input_type
=
torch
.
float16
,
)
device
=
device
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
,
prompt_mapping
,
is_prefill
=
stage
)
is_prefill
=
stage
)
...
@@ -883,7 +899,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
...
@@ -883,7 +899,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
input_size
=
(
1
,
4096
),
input_size
=
(
1
,
4096
),
input_range
=
(
0
,
1
),
input_range
=
(
0
,
1
),
input_type
=
torch
.
float16
,
input_type
=
torch
.
float16
,
)
device
=
device
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
,
prompt_mapping
,
is_prefill
=
stage
)
is_prefill
=
stage
)
...
@@ -962,7 +978,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
...
@@ -962,7 +978,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
input_size
=
(
1
,
max_position
),
input_size
=
(
1
,
max_position
),
input_range
=
(
0
,
lora_config
.
lora_extra_vocab_size
),
input_range
=
(
0
,
lora_config
.
lora_extra_vocab_size
),
input_type
=
torch
.
float16
,
input_type
=
torch
.
float16
,
)
device
=
device
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
lora_mapping
=
LoRAMapping
(
index_mapping
,
prompt_mapping
)
long_lora_context
=
LongContextLoRAContext
(
list
(
scaling_factors
),
long_lora_context
=
LongContextLoRAContext
(
list
(
scaling_factors
),
...
...
tests/lora/test_lora_manager.py
View file @
7f5edb59
...
@@ -25,8 +25,13 @@ EMBEDDING_MODULES = {
...
@@ -25,8 +25,13 @@ EMBEDDING_MODULES = {
EMBEDDING_PADDING_MODULES
=
[
"lm_head"
]
EMBEDDING_PADDING_MODULES
=
[
"lm_head"
]
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
def
test_from_lora_tensors
(
sql_lora_files
):
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_from_lora_tensors
(
sql_lora_files
,
device
):
tensors
=
load_file
(
tensors
=
load_file
(
os
.
path
.
join
(
sql_lora_files
,
"adapter_model.safetensors"
))
os
.
path
.
join
(
sql_lora_files
,
"adapter_model.safetensors"
))
new_embeddings
=
load_file
(
new_embeddings
=
load_file
(
...
@@ -36,7 +41,7 @@ def test_from_lora_tensors(sql_lora_files):
...
@@ -36,7 +41,7 @@ def test_from_lora_tensors(sql_lora_files):
8
,
8
,
16
,
16
,
tensors
,
tensors
,
"cuda"
,
device
,
embeddings
=
new_embeddings
,
embeddings
=
new_embeddings
,
embedding_modules
=
EMBEDDING_MODULES
,
embedding_modules
=
EMBEDDING_MODULES
,
embedding_padding_modules
=
EMBEDDING_PADDING_MODULES
)
embedding_padding_modules
=
EMBEDDING_PADDING_MODULES
)
...
@@ -46,6 +51,8 @@ def test_from_lora_tensors(sql_lora_files):
...
@@ -46,6 +51,8 @@ def test_from_lora_tensors(sql_lora_files):
assert
lora
.
lora_alpha
==
16
assert
lora
.
lora_alpha
==
16
assert
lora
.
lora_a
is
not
None
assert
lora
.
lora_a
is
not
None
assert
lora
.
lora_b
is
not
None
assert
lora
.
lora_b
is
not
None
assert
lora
.
lora_a
.
device
==
torch
.
device
(
device
)
assert
lora
.
lora_b
.
device
==
torch
.
device
(
device
)
assert
(
lora
.
lora_a
.
shape
[
1
]
==
lora
.
lora_b
.
shape
[
0
]
assert
(
lora
.
lora_a
.
shape
[
1
]
==
lora
.
lora_b
.
shape
[
0
]
),
f
"
{
lora
.
lora_a
.
shape
=
}
,
{
lora
.
lora_b
.
shape
=
}
"
),
f
"
{
lora
.
lora_a
.
shape
=
}
,
{
lora
.
lora_b
.
shape
=
}
"
assert
lora
.
lora_a
.
shape
[
1
]
==
8
assert
lora
.
lora_a
.
shape
[
1
]
==
8
...
@@ -60,8 +67,8 @@ def test_from_lora_tensors(sql_lora_files):
...
@@ -60,8 +67,8 @@ def test_from_lora_tensors(sql_lora_files):
assert
lora
.
embeddings_tensor
is
None
assert
lora
.
embeddings_tensor
is
None
def
create_lora
(
lora_id
:
int
,
model
:
nn
.
Module
,
def
create_lora
(
lora_id
:
int
,
model
:
nn
.
Module
,
sub_modules
:
List
[
str
],
sub_modules
:
List
[
str
]
)
->
LoRAModel
:
device
:
torch
.
device
)
->
LoRAModel
:
loras
:
Dict
[
str
,
LoRALayerWeights
]
=
{}
loras
:
Dict
[
str
,
LoRALayerWeights
]
=
{}
for
name
in
sub_modules
:
for
name
in
sub_modules
:
w
=
model
.
get_submodule
(
name
).
weight
w
=
model
.
get_submodule
(
name
).
weight
...
@@ -69,8 +76,8 @@ def create_lora(lora_id: int, model: nn.Module,
...
@@ -69,8 +76,8 @@ def create_lora(lora_id: int, model: nn.Module,
name
,
name
,
8
,
8
,
16
,
16
,
torch
.
rand
([
w
.
shape
[
1
],
8
],
device
=
"cuda"
),
torch
.
rand
([
w
.
shape
[
1
],
8
],
device
=
device
),
torch
.
rand
([
8
,
w
.
shape
[
0
]],
device
=
"cuda"
),
torch
.
rand
([
8
,
w
.
shape
[
0
]],
device
=
device
),
)
)
return
LoRAModel
(
lora_id
,
8
,
loras
)
return
LoRAModel
(
lora_id
,
8
,
loras
)
...
@@ -80,6 +87,7 @@ def create_packed_lora(
...
@@ -80,6 +87,7 @@ def create_packed_lora(
model
:
nn
.
Module
,
model
:
nn
.
Module
,
module_name
,
module_name
,
replaced_module_names
,
replaced_module_names
,
device
:
torch
.
device
,
empty_replaced_module_name
=
None
,
empty_replaced_module_name
=
None
,
)
->
LoRAModel
:
)
->
LoRAModel
:
w
=
model
.
get_submodule
(
module_name
).
weight
w
=
model
.
get_submodule
(
module_name
).
weight
...
@@ -91,9 +99,9 @@ def create_packed_lora(
...
@@ -91,9 +99,9 @@ def create_packed_lora(
replaced_module_name
,
replaced_module_name
,
8
,
8
,
16
,
16
,
torch
.
rand
([
w
.
shape
[
1
],
8
],
device
=
"cuda"
),
torch
.
rand
([
w
.
shape
[
1
],
8
],
device
=
device
),
torch
.
rand
([
8
,
w
.
shape
[
0
]
//
len
(
replaced_module_names
)],
torch
.
rand
([
8
,
w
.
shape
[
0
]
//
len
(
replaced_module_names
)],
device
=
"cuda"
),
device
=
device
),
)
)
return
LoRAModel
(
lora_id
,
8
,
loras
)
return
LoRAModel
(
lora_id
,
8
,
loras
)
...
@@ -104,7 +112,8 @@ def test_replace_submodules(dist_init, dummy_model):
...
@@ -104,7 +112,8 @@ def test_replace_submodules(dist_init, dummy_model):
model
.
packed_modules_mapping
=
{}
model
.
packed_modules_mapping
=
{}
manager
=
LoRAModelManager
(
manager
=
LoRAModelManager
(
model
,
1
,
1
,
1
,
model
,
1
,
1
,
1
,
LoRAConfig
(
max_lora_rank
=
8
,
max_cpu_loras
=
8
,
max_loras
=
8
))
LoRAConfig
(
max_lora_rank
=
8
,
max_cpu_loras
=
8
,
max_loras
=
8
),
torch
.
device
(
"cuda"
))
model
=
manager
.
model
model
=
manager
.
model
assert
isinstance
(
model
.
get_submodule
(
"dense1"
),
assert
isinstance
(
model
.
get_submodule
(
"dense1"
),
...
@@ -116,16 +125,28 @@ def test_replace_submodules(dist_init, dummy_model):
...
@@ -116,16 +125,28 @@ def test_replace_submodules(dist_init, dummy_model):
RowParallelLinearWithLoRA
)
RowParallelLinearWithLoRA
)
def
test_lora_model_manager
(
dist_init
,
dummy_model
):
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_lora_model_manager
(
dist_init
,
dummy_model
,
device
):
model
=
dummy_model
model
=
dummy_model
model
.
supported_lora_modules
=
[
"dense1"
,
"dense2"
,
"lm_head"
]
model
.
supported_lora_modules
=
[
"dense1"
,
"dense2"
,
"lm_head"
]
model
.
packed_modules_mapping
=
{}
model
.
packed_modules_mapping
=
{}
model_lora1
=
create_lora
(
1
,
model
,
[
"layer1.dense1"
,
"dense2"
,
"lm_head"
])
model_lora1
=
create_lora
(
1
,
model_lora2
=
create_lora
(
2
,
model
,
[
"dense1"
,
"dense2"
,
"lm_head"
])
model
,
[
"layer1.dense1"
,
"dense2"
,
"lm_head"
],
model_lora3
=
create_lora
(
3
,
model
,
[
"dense1"
,
"dense2"
,
"lm_head"
])
device
=
device
)
manager
=
LoRAModelManager
(
model_lora2
=
create_lora
(
2
,
model
,
2
,
2
,
2
,
model
,
[
"dense1"
,
"dense2"
,
"lm_head"
],
LoRAConfig
(
max_lora_rank
=
8
,
max_cpu_loras
=
3
,
max_loras
=
2
))
device
=
device
)
model_lora3
=
create_lora
(
3
,
model
,
[
"dense1"
,
"dense2"
,
"lm_head"
],
device
=
device
)
manager
=
LoRAModelManager
(
model
,
2
,
2
,
2
,
LoRAConfig
(
max_lora_rank
=
8
,
max_cpu_loras
=
3
,
max_loras
=
2
),
device
=
device
)
assert
all
(
x
is
None
for
x
in
manager
.
lora_index_to_id
)
assert
all
(
x
is
None
for
x
in
manager
.
lora_index_to_id
)
assert
manager
.
add_adapter
(
model_lora1
)
assert
manager
.
add_adapter
(
model_lora1
)
assert
manager
.
activate_adapter
(
1
)
assert
manager
.
activate_adapter
(
1
)
...
@@ -161,17 +182,32 @@ def test_lora_model_manager(dist_init, dummy_model):
...
@@ -161,17 +182,32 @@ def test_lora_model_manager(dist_init, dummy_model):
assert
manager
.
lora_index_to_id
[
0
]
==
3
assert
manager
.
lora_index_to_id
[
0
]
==
3
assert
manager
.
lora_index_to_id
[
1
]
==
2
assert
manager
.
lora_index_to_id
[
1
]
==
2
assert
manager
.
device
==
device
assert
manager
.
punica_wrapper
.
device
==
device
def
test_lora_lru_cache_model_manager
(
dist_init
,
dummy_model
):
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_lora_lru_cache_model_manager
(
dist_init
,
dummy_model
,
device
):
model
=
dummy_model
model
=
dummy_model
model
.
supported_lora_modules
=
[
"dense1"
,
"dense2"
,
"lm_head"
]
model
.
supported_lora_modules
=
[
"dense1"
,
"dense2"
,
"lm_head"
]
model
.
packed_modules_mapping
=
{}
model
.
packed_modules_mapping
=
{}
model_lora1
=
create_lora
(
1
,
model
,
[
"layer1.dense1"
,
"dense2"
,
"lm_head"
])
model_lora1
=
create_lora
(
1
,
model_lora2
=
create_lora
(
2
,
model
,
[
"dense1"
,
"dense2"
,
"lm_head"
])
model
,
[
"layer1.dense1"
,
"dense2"
,
"lm_head"
],
model_lora3
=
create_lora
(
3
,
model
,
[
"dense1"
,
"dense2"
,
"lm_head"
])
device
=
device
)
manager
=
LRUCacheLoRAModelManager
(
model_lora2
=
create_lora
(
2
,
model
,
2
,
2
,
2
,
model
,
[
"dense1"
,
"dense2"
,
"lm_head"
],
LoRAConfig
(
max_lora_rank
=
8
,
max_cpu_loras
=
3
,
max_loras
=
2
))
device
=
device
)
model_lora3
=
create_lora
(
3
,
model
,
[
"dense1"
,
"dense2"
,
"lm_head"
],
device
=
device
)
manager
=
LRUCacheLoRAModelManager
(
model
,
2
,
2
,
2
,
LoRAConfig
(
max_lora_rank
=
8
,
max_cpu_loras
=
3
,
max_loras
=
2
),
device
=
device
)
assert
all
(
x
is
None
for
x
in
manager
.
lora_index_to_id
)
assert
all
(
x
is
None
for
x
in
manager
.
lora_index_to_id
)
assert
manager
.
add_adapter
(
model_lora1
)
assert
manager
.
add_adapter
(
model_lora1
)
assert
manager
.
activate_adapter
(
1
)
assert
manager
.
activate_adapter
(
1
)
...
@@ -238,20 +274,37 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model):
...
@@ -238,20 +274,37 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model):
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
assert
manager
.
pin_adapter
(
3
)
assert
manager
.
pin_adapter
(
3
)
assert
manager
.
punica_wrapper
.
device
==
device
assert
manager
.
device
==
device
def
test_lru_lora_model_manager
(
dist_init
,
dummy_model
):
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_lru_lora_model_manager
(
dist_init
,
dummy_model
,
device
):
# This tests just the LRU cache functionality, everything else is
# This tests just the LRU cache functionality, everything else is
# tested in test_lora_model_manager
# tested in test_lora_model_manager
model
=
dummy_model
model
=
dummy_model
model
.
supported_lora_modules
=
[
"dense1"
,
"dense2"
,
"lm_head"
]
model
.
supported_lora_modules
=
[
"dense1"
,
"dense2"
,
"lm_head"
]
model
.
packed_modules_mapping
=
{}
model
.
packed_modules_mapping
=
{}
model_lora1
=
create_lora
(
1
,
model
,
[
"layer1.dense1"
,
"dense2"
,
"lm_head"
])
model_lora1
=
create_lora
(
1
,
model_lora2
=
create_lora
(
2
,
model
,
[
"dense1"
,
"dense2"
,
"lm_head"
])
model
,
[
"layer1.dense1"
,
"dense2"
,
"lm_head"
],
model_lora3
=
create_lora
(
3
,
model
,
[
"dense1"
,
"dense2"
,
"lm_head"
])
device
=
device
)
model_lora4
=
create_lora
(
4
,
model
,
[
"dense1"
,
"dense2"
,
"lm_head"
])
model_lora2
=
create_lora
(
2
,
manager
=
LRUCacheLoRAModelManager
(
model
,
[
"dense1"
,
"dense2"
,
"lm_head"
],
model
,
2
,
2
,
2
,
device
=
device
)
LoRAConfig
(
max_lora_rank
=
8
,
max_cpu_loras
=
2
,
max_loras
=
2
))
model_lora3
=
create_lora
(
3
,
model
,
[
"dense1"
,
"dense2"
,
"lm_head"
],
device
=
device
)
model_lora4
=
create_lora
(
4
,
model
,
[
"dense1"
,
"dense2"
,
"lm_head"
],
device
=
device
)
manager
=
LRUCacheLoRAModelManager
(
model
,
2
,
2
,
2
,
LoRAConfig
(
max_lora_rank
=
8
,
max_cpu_loras
=
2
,
max_loras
=
2
),
device
=
device
)
assert
all
(
x
is
None
for
x
in
manager
.
lora_index_to_id
)
assert
all
(
x
is
None
for
x
in
manager
.
lora_index_to_id
)
...
@@ -351,14 +404,17 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
...
@@ -351,14 +404,17 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
assert
manager
.
remove_oldest_adapter
()
assert
manager
.
remove_oldest_adapter
()
assert
set
(
manager
.
list_adapters
())
==
{
1
}
assert
set
(
manager
.
list_adapters
())
==
{
1
}
assert
manager
.
punica_wrapper
.
device
==
device
assert
manager
.
device
==
device
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_lru_cache_worker_adapter_manager
(
llama_2_7b_model_extra_embeddings
,
def
test_lru_cache_worker_adapter_manager
(
llama_2_7b_model_extra_embeddings
,
sql_lora_files
):
sql_lora_files
,
device
):
lora_config
=
LoRAConfig
(
max_lora_rank
=
8
,
max_cpu_loras
=
4
,
max_loras
=
4
)
lora_config
=
LoRAConfig
(
max_lora_rank
=
8
,
max_cpu_loras
=
4
,
max_loras
=
4
)
worker_adapter_manager
=
LRUCacheWorkerLoRAManager
(
worker_adapter_manager
=
LRUCacheWorkerLoRAManager
(
4
,
2
,
llama_2_7b_model_extra_embeddings
.
unpadded_vocab_size
-
4
,
2
,
llama_2_7b_model_extra_embeddings
.
unpadded_vocab_size
-
lora_config
.
lora_extra_vocab_size
,
lora_config
,
torch
.
device
(
"cuda"
)
,
lora_config
.
lora_extra_vocab_size
,
lora_config
,
device
,
EMBEDDING_MODULES
,
EMBEDDING_PADDING_MODULES
)
EMBEDDING_MODULES
,
EMBEDDING_PADDING_MODULES
)
worker_adapter_manager
.
create_lora_manager
(
worker_adapter_manager
.
create_lora_manager
(
llama_2_7b_model_extra_embeddings
)
llama_2_7b_model_extra_embeddings
)
...
@@ -426,14 +482,19 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
...
@@ -426,14 +482,19 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
LoRARequest
(
"14"
,
14
,
sql_lora_files
)
LoRARequest
(
"14"
,
14
,
sql_lora_files
)
],
mapping
)
],
mapping
)
assert
worker_adapter_manager
.
device
==
device
assert
(
worker_adapter_manager
.
_adapter_manager
.
punica_wrapper
.
device
==
device
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_worker_adapter_manager
(
llama_2_7b_model_extra_embeddings
,
def
test_worker_adapter_manager
(
llama_2_7b_model_extra_embeddings
,
sql_lora_files
):
sql_lora_files
,
device
):
# Should remove every LoRA not specified in the request.
# Should remove every LoRA not specified in the request.
lora_config
=
LoRAConfig
(
max_lora_rank
=
8
,
max_cpu_loras
=
4
,
max_loras
=
4
)
lora_config
=
LoRAConfig
(
max_lora_rank
=
8
,
max_cpu_loras
=
4
,
max_loras
=
4
)
worker_adapter_manager
=
WorkerLoRAManager
(
worker_adapter_manager
=
WorkerLoRAManager
(
4
,
2
,
llama_2_7b_model_extra_embeddings
.
unpadded_vocab_size
-
4
,
2
,
llama_2_7b_model_extra_embeddings
.
unpadded_vocab_size
-
lora_config
.
lora_extra_vocab_size
,
lora_config
,
torch
.
device
(
"cuda"
)
,
lora_config
.
lora_extra_vocab_size
,
lora_config
,
device
,
EMBEDDING_MODULES
,
EMBEDDING_PADDING_MODULES
)
EMBEDDING_MODULES
,
EMBEDDING_PADDING_MODULES
)
worker_adapter_manager
.
create_lora_manager
(
worker_adapter_manager
.
create_lora_manager
(
llama_2_7b_model_extra_embeddings
)
llama_2_7b_model_extra_embeddings
)
...
@@ -497,8 +558,13 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
...
@@ -497,8 +558,13 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
LoRARequest
(
"14"
,
14
,
sql_lora_files
)
LoRARequest
(
"14"
,
14
,
sql_lora_files
)
],
mapping
)
],
mapping
)
assert
worker_adapter_manager
.
device
==
device
assert
(
worker_adapter_manager
.
_adapter_manager
.
punica_wrapper
.
device
==
device
)
def
test_packed_loras
(
dist_init
,
dummy_model_gate_up
):
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_packed_loras
(
dist_init
,
dummy_model_gate_up
,
device
):
model
=
dummy_model_gate_up
model
=
dummy_model_gate_up
model
.
supported_lora_modules
=
[
"gate_up_proj"
]
model
.
supported_lora_modules
=
[
"gate_up_proj"
]
model
.
packed_modules_mapping
=
{
model
.
packed_modules_mapping
=
{
...
@@ -511,18 +577,25 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
...
@@ -511,18 +577,25 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
1
,
1
,
model
,
model
,
module_name
=
"gate_up_proj"
,
module_name
=
"gate_up_proj"
,
replaced_module_names
=
[
"gate_proj"
,
"up_proj"
])
replaced_module_names
=
[
"gate_proj"
,
"up_proj"
],
device
=
device
)
model_lora1
=
create_packed_lora
(
model_lora1
=
create_packed_lora
(
2
,
2
,
model
,
model
,
module_name
=
"gate_up_proj"
,
module_name
=
"gate_up_proj"
,
replaced_module_names
=
[
"gate_proj"
,
"up_proj"
],
replaced_module_names
=
[
"gate_proj"
,
"up_proj"
],
device
=
device
,
empty_replaced_module_name
=
"gate_proj"
,
empty_replaced_module_name
=
"gate_proj"
,
)
)
manager
=
LoRAModelManager
(
manager
=
LoRAModelManager
(
model
,
model
,
2
,
2
,
2
,
2
,
LoRAConfig
(
max_lora_rank
=
8
,
max_cpu_loras
=
2
,
max_loras
=
2
))
2
,
2
,
LoRAConfig
(
max_lora_rank
=
8
,
max_cpu_loras
=
2
,
max_loras
=
2
),
device
=
device
)
model
=
manager
.
model
model
=
manager
.
model
assert
isinstance
(
model
.
get_submodule
(
"gate_up_proj"
),
assert
isinstance
(
model
.
get_submodule
(
"gate_up_proj"
),
...
...
tests/lora/utils.py
View file @
7f5edb59
...
@@ -7,9 +7,10 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
...
@@ -7,9 +7,10 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
class
DummyLoRAManager
:
class
DummyLoRAManager
:
def
__init__
(
self
):
def
__init__
(
self
,
device
:
torch
.
device
=
"cuda:0"
):
super
().
__init__
()
super
().
__init__
()
self
.
_loras
:
Dict
[
str
,
LoRALayerWeights
]
=
{}
self
.
_loras
:
Dict
[
str
,
LoRALayerWeights
]
=
{}
self
.
_device
=
device
def
set_module_lora
(
self
,
module_name
:
str
,
lora
:
LoRALayerWeights
):
def
set_module_lora
(
self
,
module_name
:
str
,
lora
:
LoRALayerWeights
):
self
.
_loras
[
module_name
]
=
lora
self
.
_loras
[
module_name
]
=
lora
...
@@ -28,16 +29,16 @@ class DummyLoRAManager:
...
@@ -28,16 +29,16 @@ class DummyLoRAManager:
lora_alpha
=
1
,
lora_alpha
=
1
,
lora_a
=
torch
.
rand
([
weight
.
shape
[
1
],
rank
],
lora_a
=
torch
.
rand
([
weight
.
shape
[
1
],
rank
],
dtype
=
weight
.
dtype
,
dtype
=
weight
.
dtype
,
device
=
"cuda"
),
device
=
self
.
_device
),
lora_b
=
torch
.
rand
([
rank
,
weight
.
shape
[
0
]],
lora_b
=
torch
.
rand
([
rank
,
weight
.
shape
[
0
]],
dtype
=
weight
.
dtype
,
dtype
=
weight
.
dtype
,
device
=
"cuda"
),
device
=
self
.
_device
),
)
)
if
generate_embeddings_tensor
:
if
generate_embeddings_tensor
:
lora
.
embeddings_tensor
=
torch
.
rand
(
5
,
lora
.
embeddings_tensor
=
torch
.
rand
(
5
,
generate_embeddings_tensor
,
generate_embeddings_tensor
,
dtype
=
weight
.
dtype
,
dtype
=
weight
.
dtype
,
device
=
"cuda"
)
device
=
self
.
_device
)
self
.
set_module_lora
(
module_name
,
lora
)
self
.
set_module_lora
(
module_name
,
lora
)
return
lora
return
lora
...
...
vllm/lora/models.py
View file @
7f5edb59
...
@@ -301,6 +301,7 @@ class LoRAModelManager(AdapterModelManager):
...
@@ -301,6 +301,7 @@ class LoRAModelManager(AdapterModelManager):
max_num_batched_tokens
:
int
,
max_num_batched_tokens
:
int
,
vocab_size
:
int
,
vocab_size
:
int
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
device
:
torch
.
device
,
):
):
"""Create a LoRAModelManager and adapter for a given model.
"""Create a LoRAModelManager and adapter for a given model.
...
@@ -314,6 +315,7 @@ class LoRAModelManager(AdapterModelManager):
...
@@ -314,6 +315,7 @@ class LoRAModelManager(AdapterModelManager):
lora_config: the LoRA configuration.
lora_config: the LoRA configuration.
"""
"""
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
device
=
device
self
.
max_num_seqs
=
max_num_seqs
self
.
max_num_seqs
=
max_num_seqs
assert
self
.
capacity
>=
self
.
lora_slots
assert
self
.
capacity
>=
self
.
lora_slots
self
.
max_num_batched_tokens
=
math
.
ceil
(
max_num_batched_tokens
/
8
)
*
8
self
.
max_num_batched_tokens
=
math
.
ceil
(
max_num_batched_tokens
/
8
)
*
8
...
@@ -322,7 +324,7 @@ class LoRAModelManager(AdapterModelManager):
...
@@ -322,7 +324,7 @@ class LoRAModelManager(AdapterModelManager):
self
.
long_lora_context
:
Optional
[
LongContextLoRAContext
]
=
None
self
.
long_lora_context
:
Optional
[
LongContextLoRAContext
]
=
None
self
.
punica_wrapper
=
PunicaWrapper
(
max_num_batched_tokens
,
self
.
punica_wrapper
=
PunicaWrapper
(
max_num_batched_tokens
,
max_batches
=
self
.
max_num_seqs
,
max_batches
=
self
.
max_num_seqs
,
device
=
"cuda"
)
device
=
self
.
device
)
# Scaling factor -> offset to the sin_cos_cache to it.
# Scaling factor -> offset to the sin_cos_cache to it.
# Used for long context lora.
# Used for long context lora.
self
.
scaling_factor_to_offset
:
Dict
[
float
,
int
]
=
{}
self
.
scaling_factor_to_offset
:
Dict
[
float
,
int
]
=
{}
...
@@ -653,16 +655,11 @@ class LoRALRUCache(AdapterLRUCache[LoRAModel]):
...
@@ -653,16 +655,11 @@ class LoRALRUCache(AdapterLRUCache[LoRAModel]):
class
LRUCacheLoRAModelManager
(
LoRAModelManager
):
class
LRUCacheLoRAModelManager
(
LoRAModelManager
):
"""A model manager that manages multiple LoRAs with LRU cache."""
"""A model manager that manages multiple LoRAs with LRU cache."""
def
__init__
(
def
__init__
(
self
,
model
:
nn
.
Module
,
max_num_seqs
:
int
,
self
,
max_num_batched_tokens
:
int
,
vocab_size
:
int
,
model
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
device
:
torch
.
device
):
max_num_seqs
:
int
,
max_num_batched_tokens
:
int
,
vocab_size
:
int
,
lora_config
:
LoRAConfig
,
):
super
().
__init__
(
model
,
max_num_seqs
,
max_num_batched_tokens
,
super
().
__init__
(
model
,
max_num_seqs
,
max_num_batched_tokens
,
vocab_size
,
lora_config
)
vocab_size
,
lora_config
,
device
)
self
.
_registered_adapters
:
LoRALRUCache
=
LoRALRUCache
(
self
.
_registered_adapters
:
LoRALRUCache
=
LoRALRUCache
(
self
.
capacity
,
self
.
deactivate_adapter
)
self
.
capacity
,
self
.
deactivate_adapter
)
self
.
_active_adapters
:
LoRALRUCache
=
LoRALRUCache
(
self
.
_active_adapters
:
LoRALRUCache
=
LoRALRUCache
(
...
@@ -732,6 +729,7 @@ def create_lora_manager(
...
@@ -732,6 +729,7 @@ def create_lora_manager(
max_num_batched_tokens
:
int
,
max_num_batched_tokens
:
int
,
vocab_size
:
int
,
vocab_size
:
int
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
device
:
torch
.
device
,
lora_manager_cls
:
Type
[
LoRAModelManager
]
=
LoRAModelManager
,
lora_manager_cls
:
Type
[
LoRAModelManager
]
=
LoRAModelManager
,
**
kwargs
)
->
LoRAModelManager
:
**
kwargs
)
->
LoRAModelManager
:
"""Create a LoRA adapter for a given model."""
"""Create a LoRA adapter for a given model."""
...
@@ -743,5 +741,6 @@ def create_lora_manager(
...
@@ -743,5 +741,6 @@ def create_lora_manager(
max_num_batched_tokens
=
max_num_batched_tokens
,
max_num_batched_tokens
=
max_num_batched_tokens
,
vocab_size
=
vocab_size
,
vocab_size
=
vocab_size
,
lora_config
=
lora_config
,
lora_config
=
lora_config
,
device
=
device
,
**
kwargs
)
**
kwargs
)
return
lora_manager
return
lora_manager
vllm/lora/punica.py
View file @
7f5edb59
...
@@ -62,6 +62,7 @@ def convert_mapping(
...
@@ -62,6 +62,7 @@ def convert_mapping(
max_loras
:
int
,
max_loras
:
int
,
vocab_size
:
int
,
vocab_size
:
int
,
extra_vocab_size
:
int
,
extra_vocab_size
:
int
,
device
:
torch
.
device
,
long_lora_context
:
Optional
[
"LongContextLoRAContext"
]
=
None
,
long_lora_context
:
Optional
[
"LongContextLoRAContext"
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
List
[
int
]]:
Optional
[
torch
.
Tensor
],
List
[
int
]]:
...
@@ -104,7 +105,7 @@ def convert_mapping(
...
@@ -104,7 +105,7 @@ def convert_mapping(
long_lora_offsets
:
Optional
[
torch
.
Tensor
]
=
None
long_lora_offsets
:
Optional
[
torch
.
Tensor
]
=
None
if
long_lora_context
:
if
long_lora_context
:
long_lora_offsets
=
torch
.
zeros
(
len
(
index_mapping_indices
),
long_lora_offsets
=
torch
.
zeros
(
len
(
index_mapping_indices
),
device
=
"cuda"
,
device
=
device
,
dtype
=
torch
.
long
)
dtype
=
torch
.
long
)
prompt_mapping
:
List
[
int
]
=
[
prompt_mapping
:
List
[
int
]
=
[
lora_index_to_id
.
index
(
x
)
if
x
>
0
else
-
1
lora_index_to_id
.
index
(
x
)
if
x
>
0
else
-
1
...
@@ -131,10 +132,10 @@ def convert_mapping(
...
@@ -131,10 +132,10 @@ def convert_mapping(
if
long_lora_context
:
if
long_lora_context
:
assert
long_lora_offsets
is
not
None
assert
long_lora_offsets
is
not
None
indices_list
.
append
(
long_lora_offsets
)
indices_list
.
append
(
long_lora_offsets
)
indices
=
torch
.
tensor
(
indices_list
,
dtype
=
torch
.
long
,
device
=
"cuda"
)
indices
=
torch
.
tensor
(
indices_list
,
dtype
=
torch
.
long
,
device
=
device
)
prompt_mapping_tensor
=
torch
.
tensor
(
prompt_mapping
,
prompt_mapping_tensor
=
torch
.
tensor
(
prompt_mapping
,
d
evice
=
"cuda"
,
d
type
=
torch
.
long
,
d
type
=
torch
.
long
)
d
evice
=
device
)
embeddings_indices
=
torch
.
stack
([
embeddings_indices
=
torch
.
stack
([
indices
[
2
]
*
extra_vocab_size
,
indices
[
2
]
*
extra_vocab_size
,
indices
[
2
]
*
(
vocab_size
+
extra_vocab_size
),
indices
[
2
]
*
(
vocab_size
+
extra_vocab_size
),
...
@@ -145,7 +146,7 @@ def convert_mapping(
...
@@ -145,7 +146,7 @@ def convert_mapping(
sampler_indices_padded
=
sampler_indices
.
clone
()
sampler_indices_padded
=
sampler_indices
.
clone
()
sampler_indices_padded
[
sampler_indices_padded
==
-
1
]
=
max_loras
-
1
sampler_indices_padded
[
sampler_indices_padded
==
-
1
]
=
max_loras
-
1
sampler_indices_padded
=
torch
.
arange
(
sampler_indices_padded
=
torch
.
arange
(
0
,
len
(
sampler_indices_padded
),
device
=
"cuda"
,
dtype
=
torch
.
long
)
+
(
0
,
len
(
sampler_indices_padded
),
device
=
device
,
dtype
=
torch
.
long
)
+
(
sampler_indices_padded
*
len
(
sampler_indices_padded
))
sampler_indices_padded
*
len
(
sampler_indices_padded
))
long_lora_indices
=
None
long_lora_indices
=
None
long_lora_indices_len
:
Optional
[
int
]
=
None
long_lora_indices_len
:
Optional
[
int
]
=
None
...
@@ -183,7 +184,7 @@ class PunicaWrapper:
...
@@ -183,7 +184,7 @@ class PunicaWrapper:
"""
"""
def
__init__
(
self
,
max_num_batched_tokens
:
int
,
max_batches
:
int
,
def
__init__
(
self
,
max_num_batched_tokens
:
int
,
max_batches
:
int
,
device
:
str
):
device
:
Union
[
torch
.
device
,
str
]
):
self
.
_token_lora_indices
=
torch
.
empty
(
max_num_batched_tokens
,
self
.
_token_lora_indices
=
torch
.
empty
(
max_num_batched_tokens
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
device
)
device
=
device
)
...
@@ -215,6 +216,7 @@ class PunicaWrapper:
...
@@ -215,6 +216,7 @@ class PunicaWrapper:
self
.
_lora_indices_per_batch
=
torch
.
empty
(
max_batches
,
self
.
_lora_indices_per_batch
=
torch
.
empty
(
max_batches
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
device
)
device
=
device
)
self
.
device
:
torch
.
device
=
device
self
.
max_length
:
int
=
0
self
.
max_length
:
int
=
0
self
.
token_nums
:
int
=
0
self
.
token_nums
:
int
=
0
self
.
batch_size
:
int
=
-
1
self
.
batch_size
:
int
=
-
1
...
@@ -263,6 +265,7 @@ class PunicaWrapper:
...
@@ -263,6 +265,7 @@ class PunicaWrapper:
max_loras
,
max_loras
,
vocab_size
,
vocab_size
,
extra_vocab_size
,
extra_vocab_size
,
self
.
device
,
long_lora_context
,
long_lora_context
,
)
)
self
.
_token_lora_indices
[:
base_indices
.
shape
[
0
]].
copy_
(
base_indices
)
self
.
_token_lora_indices
[:
base_indices
.
shape
[
0
]].
copy_
(
base_indices
)
...
...
vllm/lora/worker_manager.py
View file @
7f5edb59
...
@@ -73,6 +73,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
...
@@ -73,6 +73,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
max_num_batched_tokens
=
self
.
max_num_batched_tokens
,
max_num_batched_tokens
=
self
.
max_num_batched_tokens
,
vocab_size
=
self
.
vocab_size
,
vocab_size
=
self
.
vocab_size
,
lora_config
=
self
.
lora_config
,
lora_config
=
self
.
lora_config
,
device
=
self
.
device
,
lora_manager_cls
=
self
.
_manager_cls
,
lora_manager_cls
=
self
.
_manager_cls
,
)
)
self
.
_adapter_manager
=
lora_manager
self
.
_adapter_manager
=
lora_manager
...
@@ -176,6 +177,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
...
@@ -176,6 +177,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
max_num_seqs
=
self
.
max_num_seqs
,
max_num_seqs
=
self
.
max_num_seqs
,
vocab_size
=
self
.
vocab_size
,
vocab_size
=
self
.
vocab_size
,
lora_config
=
self
.
lora_config
,
lora_config
=
self
.
lora_config
,
device
=
self
.
device
,
max_num_batched_tokens
=
self
.
max_num_batched_tokens
,
max_num_batched_tokens
=
self
.
max_num_batched_tokens
,
)
)
self
.
_adapter_manager
=
lora_manager
self
.
_adapter_manager
=
lora_manager
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment