Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1096717a
Unverified
Commit
1096717a
authored
Apr 12, 2024
by
Jee Li
Committed by
GitHub
Apr 11, 2024
Browse files
[Core] Support LoRA on quantized models (#4012)
parent
c2b4a1bc
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
234 additions
and
26 deletions
+234
-26
tests/lora/conftest.py
tests/lora/conftest.py
+5
-0
tests/lora/test_quant_model.py
tests/lora/test_quant_model.py
+179
-0
vllm/config.py
vllm/config.py
+6
-3
vllm/lora/layers.py
vllm/lora/layers.py
+44
-23
No files found.
tests/lora/conftest.py
View file @
1096717a
...
@@ -143,6 +143,11 @@ def baichuan_lora_files():
...
@@ -143,6 +143,11 @@ def baichuan_lora_files():
return
snapshot_download
(
repo_id
=
"jeeejeee/baichuan7b-text2sql-spider"
)
return
snapshot_download
(
repo_id
=
"jeeejeee/baichuan7b-text2sql-spider"
)
@pytest.fixture(scope="session")
def tinyllama_lora_files():
    """Session-scoped fixture providing the TinyLlama colorist LoRA adapter.

    Returns whatever ``snapshot_download`` yields for the
    ``jashing/tinyllama-colorist-lora`` repository (presumably a local
    path to the downloaded adapter files -- confirm against the helper
    imported at the top of this module).
    """
    adapter_repo = "jashing/tinyllama-colorist-lora"
    return snapshot_download(repo_id=adapter_repo)
@
pytest
.
fixture
@
pytest
.
fixture
def
llama_2_7b_engine_extra_embeddings
()
->
nn
.
Module
:
def
llama_2_7b_engine_extra_embeddings
()
->
nn
.
Module
:
cleanup
()
cleanup
()
...
...
tests/lora/test_quant_model.py
0 → 100644
View file @
1096717a
# Adapted from
# https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py
from
dataclasses
import
dataclass
from
typing
import
List
import
pytest
import
vllm
from
vllm.lora.request
import
LoRARequest
from
.conftest
import
cleanup
@dataclass
class ModelWithQuantization:
    """Pairs a model identifier with the quantization method used to load it."""

    # Model identifier; the tests forward it as ``vllm.LLM(model=...)``.
    model_path: str
    # Quantization method name; forwarded as ``vllm.LLM(quantization=...)``.
    quantization: str
# Quantized TinyLlama checkpoints that the LoRA tests are parametrized over,
# one entry per quantization method.
MODELS: List[ModelWithQuantization] = [
    ModelWithQuantization(model_path=checkpoint, quantization=method)
    for checkpoint, method in (
        ("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", "AWQ"),
        ("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", "GPTQ"),
    )
]
def do_sample(llm, lora_path: str, lora_id: int, max_tokens=256):
    """Run two fixed colour prompts through ``llm`` and collect the replies.

    Args:
        llm: Engine object exposing ``generate(prompts, sampling_params,
            lora_request=...)``.
        lora_path: Location of the LoRA adapter, forwarded to
            ``LoRARequest``.
        lora_id: Adapter id. A falsy value (``0``) disables LoRA: no
            ``LoRARequest`` is built and ``lora_request=None`` is passed.
        max_tokens: Generation cap forwarded to ``vllm.SamplingParams``.

    Returns:
        A list with the text of the first completion of every output, in
        the order the engine returned them.
    """

    def as_chat_prompt(prompt):
        # Wrap the raw request in the chat markers the prompts are sent with.
        return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

    prompts = list(
        map(
            as_chat_prompt,
            (
                "Give me an orange-ish brown color",
                "Give me a neon pink color",
            ),
        ))

    # Temperature 0 and a stop string on the end-of-turn marker.
    sampling_params = vllm.SamplingParams(temperature=0,
                                          max_tokens=max_tokens,
                                          stop=["<|im_end|>"])

    # The adapter id doubles as the adapter name.
    if lora_id:
        lora_request = LoRARequest(str(lora_id), lora_id, lora_path)
    else:
        lora_request = None

    outputs = llm.generate(prompts,
                           sampling_params,
                           lora_request=lora_request)

    # Echo every prompt/completion pair while gathering the completions.
    generated_texts = []
    for output in outputs:
        prompt, generated_text = output.prompt, output.outputs[0].text
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", [1])
def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
    """Check that toggling LoRA adapters on a quantized model changes output.

    The engine is sampled four times -- without LoRA, with adapter id 1,
    without LoRA again, and with adapter id 2 -- and each result is
    compared against the expectation recorded for the model's
    quantization method.
    """
    # A GPU-count based skip is deliberately not used here: calling
    # torch.cuda.device_count() would initialize torch.cuda too early.
    llm = vllm.LLM(model=model.model_path,
                   enable_lora=True,
                   max_num_seqs=16,
                   max_loras=4,
                   max_model_len=400,
                   tensor_parallel_size=tp_size,
                   quantization=model.quantization,
                   trust_remote_code=True)

    # Reference completions, recorded per quantization method.
    if model.quantization is None:
        expected_no_lora_output = [
            "Here are some examples of orange-brown colors",
            "I'm sorry, I don't have",
        ]
        expected_lora_output = [
            "#ff8050",
            "#ff8080",
        ]
    elif model.quantization == "AWQ":
        expected_no_lora_output = [
            "I'm sorry, I don't understand",
            "I'm sorry, I don't understand",
        ]
        expected_lora_output = [
            "#f07700: A v",
            "#f00000: A v",
        ]
    elif model.quantization == "GPTQ":
        expected_no_lora_output = [
            "I'm sorry, I don't have",
            "I'm sorry, I don't have",
        ]
        expected_lora_output = [
            "#f08800: This is",
            "#f07788 \n#",
        ]

    def expect_match(output, expected_output):
        # The identity check (``is``) tells the LoRA expectation apart
        # from the no-LoRA one.
        relaxed_gptq_check = (model.quantization == "GPTQ"
                              and expected_output is expected_lora_output)
        if not relaxed_gptq_check:
            assert output == expected_output
            return

        # HACK: GPTQ lora outputs are just incredibly unstable.
        # Assert that the outputs changed.
        assert output != expected_no_lora_output
        for i, o in enumerate(output):
            assert o.startswith(
                '#'), f"Expected example {i} to start with # but got {o}"

    max_tokens = 10

    # (progress message, lora_id, expected completions); lora_id 0 means
    # "no adapter". The very same list objects are passed on so that the
    # identity check inside expect_match keeps working.
    steps = (
        ("lora adapter created", 0, expected_no_lora_output),
        ("lora 1", 1, expected_lora_output),
        ("no lora", 0, expected_no_lora_output),
        ("lora 2", 2, expected_lora_output),
    )
    for banner, lora_id, expected in steps:
        print(banner)
        output = do_sample(llm,
                           tinyllama_lora_files,
                           lora_id=lora_id,
                           max_tokens=max_tokens)
        expect_match(output, expected)

    print("removing lora")

    # Drop the engine reference before running the shared cleanup helper.
    del llm
    cleanup()
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.skip("Requires multiple GPUs")
def test_quant_model_tp_equality(tinyllama_lora_files, model):
    """Compare LoRA output between tensor-parallel sizes 1 and 2.

    Each engine is built, sampled with adapter id 1, and torn down before
    the next one is created; the two lists of completions must be equal.
    """
    # A GPU-count based skip is deliberately not used here: calling
    # torch.cuda.device_count() would initialize torch.cuda too early.
    shared_kwargs = dict(
        model=model.model_path,
        enable_lora=True,
        max_num_seqs=16,
        max_loras=4,
        quantization=model.quantization,
    )

    def sample_with(tp_size, **extra_kwargs):
        # Build an engine, sample once with adapter 1, then release it.
        engine = vllm.LLM(tensor_parallel_size=tp_size,
                          **shared_kwargs,
                          **extra_kwargs)
        texts = do_sample(engine, tinyllama_lora_files, lora_id=1)
        del engine
        cleanup()
        return texts

    output_tp1 = sample_with(1, trust_remote_code=True)
    # NOTE(review): the tp=2 engine is created without trust_remote_code,
    # unlike the tp=1 engine -- confirm whether that is intentional.
    output_tp2 = sample_with(2)

    assert output_tp1 == output_tp2
vllm/config.py
View file @
1096717a
...
@@ -822,9 +822,12 @@ class LoRAConfig:
...
@@ -822,9 +822,12 @@ class LoRAConfig:
self
.
lora_dtype
=
model_config
.
dtype
self
.
lora_dtype
=
model_config
.
dtype
elif
isinstance
(
self
.
lora_dtype
,
str
):
elif
isinstance
(
self
.
lora_dtype
,
str
):
self
.
lora_dtype
=
getattr
(
torch
,
self
.
lora_dtype
)
self
.
lora_dtype
=
getattr
(
torch
,
self
.
lora_dtype
)
if
model_config
.
quantization
is
not
None
:
if
model_config
.
quantization
and
model_config
.
quantization
not
in
[
raise
ValueError
(
"awq"
,
"gptq"
"LoRA is not supported with quantized models yet."
)
]:
# TODO support marlin and squeezellm
logger
.
warning
(
f
"
{
model_config
.
quantization
}
quantization is not "
"tested with LoRA yet."
)
def
verify_with_scheduler_config
(
self
,
scheduler_config
:
SchedulerConfig
):
def
verify_with_scheduler_config
(
self
,
scheduler_config
:
SchedulerConfig
):
if
scheduler_config
.
max_num_batched_tokens
>
65528
:
if
scheduler_config
.
max_num_batched_tokens
>
65528
:
...
...
vllm/lora/layers.py
View file @
1096717a
...
@@ -29,6 +29,19 @@ if TYPE_CHECKING:
...
@@ -29,6 +29,19 @@ if TYPE_CHECKING:
pass
pass
def _get_lora_device(base_layer: nn.Module) -> torch.device:
    """Returns the device for where to place the LoRA tensors.

    Lookup order:
      1. ``base_layer.weight.device`` when the layer has a ``weight``
         attribute.
      2. Otherwise, when ``base_layer.linear_weights`` is a dict, the
         device of the first ``torch.Tensor`` found among its values.
         Non-tensor entries are skipped, so the result does not depend
         on a tensor happening to be the first entry of the dict.

    Args:
        base_layer: The layer the LoRA tensors will be attached to.

    Returns:
        The device of the tensor located by the lookup above.

    Raises:
        ValueError: If neither lookup yields a tensor.
    """
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    if hasattr(base_layer, "weight"):
        return base_layer.weight.device

    linear_weights = getattr(base_layer, "linear_weights", None)
    if isinstance(linear_weights, dict):
        for value in linear_weights.values():
            if isinstance(value, torch.Tensor):
                return value.device

    raise ValueError(f"Unsupported base layer: {base_layer}")
def
_apply_lora
(
def
_apply_lora
(
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_a_stacked
:
torch
.
Tensor
,
lora_a_stacked
:
torch
.
Tensor
,
...
@@ -302,6 +315,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -302,6 +315,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
super
().
__init__
()
super
().
__init__
()
self
.
base_layer
=
base_layer
self
.
base_layer
=
base_layer
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
input_size
=
self
.
base_layer
.
input_size
self
.
output_size
=
self
.
base_layer
.
output_size_per_partition
self
.
device
=
_get_lora_device
(
self
.
base_layer
)
def
create_lora_weights
(
def
create_lora_weights
(
self
,
self
,
...
@@ -312,17 +328,17 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -312,17 +328,17 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
max_loras
,
max_loras
,
1
,
1
,
lora_config
.
max_lora_rank
,
lora_config
.
max_lora_rank
,
self
.
base_layer
.
weight
.
shape
[
1
]
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
base_layer
.
weight
.
device
,
device
=
self
.
device
,
)
)
self
.
lora_b_stacked
=
torch
.
zeros
(
self
.
lora_b_stacked
=
torch
.
zeros
(
max_loras
,
max_loras
,
1
,
1
,
self
.
base_layer
.
weight
.
shape
[
0
]
,
self
.
output_size
,
lora_config
.
max_lora_rank
,
lora_config
.
max_lora_rank
,
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
base_layer
.
weight
.
device
,
device
=
self
.
device
,
)
)
self
.
indices
:
Optional
[
torch
.
Tensor
]
=
None
self
.
indices
:
Optional
[
torch
.
Tensor
]
=
None
...
@@ -442,18 +458,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -442,18 +458,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
max_loras
,
max_loras
,
1
,
1
,
lora_config
.
max_lora_rank
,
lora_config
.
max_lora_rank
,
self
.
base_layer
.
weight
.
shape
[
1
]
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
base_layer
.
weight
.
device
,
device
=
self
.
device
,
)
for
_
in
range
(
n_slices
))
)
for
_
in
range
(
n_slices
))
self
.
lora_b_stacked
=
tuple
(
self
.
lora_b_stacked
=
tuple
(
torch
.
zeros
(
torch
.
zeros
(
max_loras
,
max_loras
,
1
,
1
,
self
.
base_layer
.
weight
.
shape
[
0
]
//
2
,
self
.
output_size
//
2
,
lora_config
.
max_lora_rank
,
lora_config
.
max_lora_rank
,
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
base_layer
.
weight
.
device
,
device
=
self
.
device
,
)
for
_
in
range
(
n_slices
))
)
for
_
in
range
(
n_slices
))
self
.
indices
:
Optional
[
torch
.
Tensor
]
=
None
self
.
indices
:
Optional
[
torch
.
Tensor
]
=
None
...
@@ -619,25 +635,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -619,25 +635,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
max_loras
,
max_loras
,
1
,
1
,
lora_config
.
max_lora_rank
,
lora_config
.
max_lora_rank
,
self
.
base_layer
.
weight
.
shape
[
1
]
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
base_layer
.
weight
.
device
,
device
=
self
.
device
,
),
),
torch
.
zeros
(
torch
.
zeros
(
max_loras
,
max_loras
,
1
,
1
,
lora_config
.
max_lora_rank
,
lora_config
.
max_lora_rank
,
self
.
base_layer
.
weight
.
shape
[
1
]
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
base_layer
.
weight
.
device
,
device
=
self
.
device
,
),
),
torch
.
zeros
(
torch
.
zeros
(
max_loras
,
max_loras
,
1
,
1
,
lora_config
.
max_lora_rank
,
lora_config
.
max_lora_rank
,
self
.
base_layer
.
weight
.
shape
[
1
]
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
base_layer
.
weight
.
device
,
device
=
self
.
device
,
),
),
)
)
self
.
lora_b_stacked
=
(
self
.
lora_b_stacked
=
(
...
@@ -647,7 +663,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -647,7 +663,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self
.
q_proj_shard_size
,
self
.
q_proj_shard_size
,
lora_config
.
max_lora_rank
,
lora_config
.
max_lora_rank
,
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
base_layer
.
weight
.
device
,
device
=
self
.
device
,
),
),
torch
.
zeros
(
torch
.
zeros
(
max_loras
,
max_loras
,
...
@@ -655,7 +671,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -655,7 +671,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self
.
kv_proj_shard_size
,
self
.
kv_proj_shard_size
,
lora_config
.
max_lora_rank
,
lora_config
.
max_lora_rank
,
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
base_layer
.
weight
.
device
,
device
=
self
.
device
,
),
),
torch
.
zeros
(
torch
.
zeros
(
max_loras
,
max_loras
,
...
@@ -663,7 +679,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -663,7 +679,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self
.
kv_proj_shard_size
,
self
.
kv_proj_shard_size
,
lora_config
.
max_lora_rank
,
lora_config
.
max_lora_rank
,
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
base_layer
.
weight
.
device
,
device
=
self
.
device
,
),
),
)
)
...
@@ -766,6 +782,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -766,6 +782,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def
__init__
(
self
,
base_layer
:
RowParallelLinear
)
->
None
:
def
__init__
(
self
,
base_layer
:
RowParallelLinear
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
base_layer
=
base_layer
self
.
base_layer
=
base_layer
self
.
input_size
=
self
.
base_layer
.
input_size_per_partition
self
.
output_size
=
self
.
base_layer
.
output_size
self
.
device
=
_get_lora_device
(
self
.
base_layer
)
def
create_lora_weights
(
def
create_lora_weights
(
self
,
self
,
...
@@ -777,20 +796,20 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -777,20 +796,20 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
max_loras
,
max_loras
,
1
,
1
,
lora_config
.
max_lora_rank
,
lora_config
.
max_lora_rank
,
self
.
base_layer
.
weight
.
shape
[
1
]
,
self
.
input_size
,
),
),
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
base_layer
.
weight
.
device
,
device
=
self
.
device
,
)
)
self
.
lora_b_stacked
=
torch
.
zeros
(
self
.
lora_b_stacked
=
torch
.
zeros
(
(
(
max_loras
,
max_loras
,
1
,
1
,
self
.
base_layer
.
weight
.
shape
[
0
]
,
self
.
output_size
,
lora_config
.
max_lora_rank
,
lora_config
.
max_lora_rank
,
),
),
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
base_layer
.
weight
.
device
,
device
=
self
.
device
,
)
)
self
.
indices
:
Optional
[
torch
.
Tensor
]
=
None
self
.
indices
:
Optional
[
torch
.
Tensor
]
=
None
self
.
indices_len
:
Optional
[
List
[
int
]]
=
None
self
.
indices_len
:
Optional
[
List
[
int
]]
=
None
...
@@ -809,7 +828,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -809,7 +828,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self
.
reset_lora
(
index
)
self
.
reset_lora
(
index
)
if
self
.
base_layer
.
tp_size
>
1
:
if
self
.
base_layer
.
tp_size
>
1
:
tensor_model_parallel_rank
=
get_tensor_model_parallel_rank
()
tensor_model_parallel_rank
=
get_tensor_model_parallel_rank
()
shard_size
=
self
.
base_layer
.
weight
.
shape
[
1
]
shard_size
=
self
.
input_size
start_idx
=
tensor_model_parallel_rank
*
shard_size
start_idx
=
tensor_model_parallel_rank
*
shard_size
end_idx
=
(
tensor_model_parallel_rank
+
1
)
*
shard_size
end_idx
=
(
tensor_model_parallel_rank
+
1
)
*
shard_size
lora_a
=
lora_a
[
start_idx
:
end_idx
,
:]
lora_a
=
lora_a
[
start_idx
:
end_idx
,
:]
...
@@ -884,7 +903,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -884,7 +903,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
@
property
@
property
def
weight
(
self
):
def
weight
(
self
):
return
self
.
base_layer
.
weight
return
self
.
base_layer
.
weight
if
hasattr
(
self
.
base_layer
,
"weight"
)
else
self
.
base_layer
.
qweight
@
classmethod
@
classmethod
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment