Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d5c5154f
Unverified
Commit
d5c5154f
authored
Dec 10, 2024
by
Aurick Qiao
Committed by
GitHub
Dec 11, 2024
Browse files
[Misc] LoRA + Chunked Prefill (#9057)
parent
9a939737
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
49 additions
and
20 deletions
+49
-20
tests/lora/test_chatglm3_tp.py
tests/lora/test_chatglm3_tp.py
+6
-3
tests/lora/test_gemma.py
tests/lora/test_gemma.py
+2
-1
tests/lora/test_llama_tp.py
tests/lora/test_llama_tp.py
+5
-1
tests/lora/test_long_context.py
tests/lora/test_long_context.py
+2
-1
tests/lora/test_minicpmv.py
tests/lora/test_minicpmv.py
+2
-1
tests/lora/test_minicpmv_tp.py
tests/lora/test_minicpmv_tp.py
+2
-0
tests/lora/test_mixtral.py
tests/lora/test_mixtral.py
+1
-0
tests/lora/test_phi.py
tests/lora/test_phi.py
+2
-1
tests/lora/test_quant_model.py
tests/lora/test_quant_model.py
+6
-3
vllm/config.py
vllm/config.py
+2
-1
vllm/core/scheduler.py
vllm/core/scheduler.py
+12
-3
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+7
-5
No files found.
tests/lora/test_chatglm3_tp.py
View file @
d5c5154f
...
@@ -53,7 +53,8 @@ def test_chatglm3_lora(chatglm3_lora_files):
...
@@ -53,7 +53,8 @@ def test_chatglm3_lora(chatglm3_lora_files):
max_loras
=
4
,
max_loras
=
4
,
max_lora_rank
=
64
,
max_lora_rank
=
64
,
tensor_parallel_size
=
1
,
tensor_parallel_size
=
1
,
trust_remote_code
=
True
)
trust_remote_code
=
True
,
enable_chunked_prefill
=
True
)
output1
=
do_sample
(
llm
,
chatglm3_lora_files
,
lora_id
=
1
)
output1
=
do_sample
(
llm
,
chatglm3_lora_files
,
lora_id
=
1
)
for
i
in
range
(
len
(
EXPECTED_LORA_OUTPUT
)):
for
i
in
range
(
len
(
EXPECTED_LORA_OUTPUT
)):
...
@@ -73,7 +74,8 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files):
...
@@ -73,7 +74,8 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files):
max_lora_rank
=
64
,
max_lora_rank
=
64
,
tensor_parallel_size
=
4
,
tensor_parallel_size
=
4
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
fully_sharded_loras
=
False
)
fully_sharded_loras
=
False
,
enable_chunked_prefill
=
True
)
output1
=
do_sample
(
llm
,
chatglm3_lora_files
,
lora_id
=
1
)
output1
=
do_sample
(
llm
,
chatglm3_lora_files
,
lora_id
=
1
)
for
i
in
range
(
len
(
EXPECTED_LORA_OUTPUT
)):
for
i
in
range
(
len
(
EXPECTED_LORA_OUTPUT
)):
...
@@ -93,7 +95,8 @@ def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
...
@@ -93,7 +95,8 @@ def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
max_lora_rank
=
64
,
max_lora_rank
=
64
,
tensor_parallel_size
=
4
,
tensor_parallel_size
=
4
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
fully_sharded_loras
=
True
)
fully_sharded_loras
=
True
,
enable_chunked_prefill
=
True
)
output1
=
do_sample
(
llm
,
chatglm3_lora_files
,
lora_id
=
1
)
output1
=
do_sample
(
llm
,
chatglm3_lora_files
,
lora_id
=
1
)
for
i
in
range
(
len
(
EXPECTED_LORA_OUTPUT
)):
for
i
in
range
(
len
(
EXPECTED_LORA_OUTPUT
)):
assert
output1
[
i
]
==
EXPECTED_LORA_OUTPUT
[
i
]
assert
output1
[
i
]
==
EXPECTED_LORA_OUTPUT
[
i
]
...
...
tests/lora/test_gemma.py
View file @
d5c5154f
...
@@ -37,7 +37,8 @@ def test_gemma_lora(gemma_lora_files):
...
@@ -37,7 +37,8 @@ def test_gemma_lora(gemma_lora_files):
llm
=
vllm
.
LLM
(
MODEL_PATH
,
llm
=
vllm
.
LLM
(
MODEL_PATH
,
max_model_len
=
1024
,
max_model_len
=
1024
,
enable_lora
=
True
,
enable_lora
=
True
,
max_loras
=
4
)
max_loras
=
4
,
enable_chunked_prefill
=
True
)
expected_lora_output
=
[
expected_lora_output
=
[
"more important than knowledge.
\n
Author: Albert Einstein
\n
"
,
"more important than knowledge.
\n
Author: Albert Einstein
\n
"
,
...
...
tests/lora/test_llama_tp.py
View file @
d5c5154f
...
@@ -78,7 +78,8 @@ def test_llama_lora(sql_lora_files):
...
@@ -78,7 +78,8 @@ def test_llama_lora(sql_lora_files):
enable_lora
=
True
,
enable_lora
=
True
,
max_num_seqs
=
16
,
max_num_seqs
=
16
,
max_loras
=
4
,
max_loras
=
4
,
tensor_parallel_size
=
1
)
tensor_parallel_size
=
1
,
enable_chunked_prefill
=
True
)
generate_and_test
(
llm
,
sql_lora_files
)
generate_and_test
(
llm
,
sql_lora_files
)
...
@@ -120,6 +121,7 @@ def test_llama_lora_tp4(sql_lora_files):
...
@@ -120,6 +121,7 @@ def test_llama_lora_tp4(sql_lora_files):
max_num_seqs
=
16
,
max_num_seqs
=
16
,
max_loras
=
4
,
max_loras
=
4
,
tensor_parallel_size
=
4
,
tensor_parallel_size
=
4
,
enable_chunked_prefill
=
True
,
)
)
generate_and_test
(
llm
,
sql_lora_files
)
generate_and_test
(
llm
,
sql_lora_files
)
...
@@ -135,6 +137,7 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
...
@@ -135,6 +137,7 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
max_loras
=
4
,
max_loras
=
4
,
tensor_parallel_size
=
4
,
tensor_parallel_size
=
4
,
fully_sharded_loras
=
True
,
fully_sharded_loras
=
True
,
enable_chunked_prefill
=
True
,
)
)
generate_and_test
(
llm
,
sql_lora_files
)
generate_and_test
(
llm
,
sql_lora_files
)
...
@@ -151,5 +154,6 @@ def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files):
...
@@ -151,5 +154,6 @@ def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files):
tensor_parallel_size
=
4
,
tensor_parallel_size
=
4
,
fully_sharded_loras
=
True
,
fully_sharded_loras
=
True
,
enable_lora_bias
=
True
,
enable_lora_bias
=
True
,
enable_chunked_prefill
=
True
,
)
)
generate_and_test
(
llm
,
sql_lora_files
)
generate_and_test
(
llm
,
sql_lora_files
)
tests/lora/test_long_context.py
View file @
d5c5154f
...
@@ -124,7 +124,8 @@ def lora_llm(long_context_infos):
...
@@ -124,7 +124,8 @@ def lora_llm(long_context_infos):
tensor_parallel_size
=
4
,
tensor_parallel_size
=
4
,
# FIXME enable async output processor
# FIXME enable async output processor
disable_async_output_proc
=
True
,
disable_async_output_proc
=
True
,
distributed_executor_backend
=
"mp"
)
distributed_executor_backend
=
"mp"
,
enable_chunked_prefill
=
True
)
yield
llm
yield
llm
del
llm
del
llm
...
...
tests/lora/test_minicpmv.py
View file @
d5c5154f
...
@@ -67,7 +67,8 @@ def test_minicpmv_lora(minicpmv_lora_files):
...
@@ -67,7 +67,8 @@ def test_minicpmv_lora(minicpmv_lora_files):
max_loras
=
4
,
max_loras
=
4
,
max_lora_rank
=
64
,
max_lora_rank
=
64
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
gpu_memory_utilization
=
0.97
# This model is pretty big for CI gpus
gpu_memory_utilization
=
0.97
,
# This model is pretty big for CI gpus
enable_chunked_prefill
=
True
,
)
)
output1
=
do_sample
(
llm
,
minicpmv_lora_files
,
lora_id
=
1
)
output1
=
do_sample
(
llm
,
minicpmv_lora_files
,
lora_id
=
1
)
for
i
in
range
(
len
(
EXPECTED_OUTPUT
)):
for
i
in
range
(
len
(
EXPECTED_OUTPUT
)):
...
...
tests/lora/test_minicpmv_tp.py
View file @
d5c5154f
...
@@ -69,6 +69,7 @@ def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded):
...
@@ -69,6 +69,7 @@ def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded):
tensor_parallel_size
=
2
,
tensor_parallel_size
=
2
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
fully_sharded_loras
=
fully_sharded
,
fully_sharded_loras
=
fully_sharded
,
enable_chunked_prefill
=
True
,
)
)
output_tp
=
do_sample
(
llm
,
minicpmv_lora_files
,
lora_id
=
1
)
output_tp
=
do_sample
(
llm
,
minicpmv_lora_files
,
lora_id
=
1
)
...
@@ -89,6 +90,7 @@ def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
...
@@ -89,6 +90,7 @@ def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
tensor_parallel_size
=
4
,
tensor_parallel_size
=
4
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
fully_sharded_loras
=
fully_sharded
,
fully_sharded_loras
=
fully_sharded
,
enable_chunked_prefill
=
True
,
)
)
output_tp
=
do_sample
(
llm
,
minicpmv_lora_files
,
lora_id
=
1
)
output_tp
=
do_sample
(
llm
,
minicpmv_lora_files
,
lora_id
=
1
)
for
i
in
range
(
len
(
EXPECTED_OUTPUT
)):
for
i
in
range
(
len
(
EXPECTED_OUTPUT
)):
...
...
tests/lora/test_mixtral.py
View file @
d5c5154f
...
@@ -47,6 +47,7 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
...
@@ -47,6 +47,7 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
max_loras
=
4
,
max_loras
=
4
,
distributed_executor_backend
=
"ray"
,
distributed_executor_backend
=
"ray"
,
tensor_parallel_size
=
tp_size
,
tensor_parallel_size
=
tp_size
,
enable_chunked_prefill
=
True
,
)
)
expected_lora_output
=
[
expected_lora_output
=
[
...
...
tests/lora/test_phi.py
View file @
d5c5154f
...
@@ -53,7 +53,8 @@ def test_phi2_lora(phi2_lora_files):
...
@@ -53,7 +53,8 @@ def test_phi2_lora(phi2_lora_files):
max_model_len
=
1024
,
max_model_len
=
1024
,
enable_lora
=
True
,
enable_lora
=
True
,
max_loras
=
2
,
max_loras
=
2
,
enforce_eager
=
True
)
enforce_eager
=
True
,
enable_chunked_prefill
=
True
)
expected_lora_output
=
[
expected_lora_output
=
[
"SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;"
,
# noqa: E501
"SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;"
,
# noqa: E501
...
...
tests/lora/test_quant_model.py
View file @
d5c5154f
...
@@ -84,7 +84,8 @@ def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
...
@@ -84,7 +84,8 @@ def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
tensor_parallel_size
=
tp_size
,
tensor_parallel_size
=
tp_size
,
gpu_memory_utilization
=
0.2
,
#avoid OOM
gpu_memory_utilization
=
0.2
,
#avoid OOM
quantization
=
model
.
quantization
,
quantization
=
model
.
quantization
,
trust_remote_code
=
True
)
trust_remote_code
=
True
,
enable_chunked_prefill
=
True
)
if
model
.
quantization
is
None
:
if
model
.
quantization
is
None
:
expected_no_lora_output
=
[
expected_no_lora_output
=
[
...
@@ -176,7 +177,8 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
...
@@ -176,7 +177,8 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
tensor_parallel_size
=
1
,
tensor_parallel_size
=
1
,
gpu_memory_utilization
=
0.2
,
#avoid OOM
gpu_memory_utilization
=
0.2
,
#avoid OOM
quantization
=
model
.
quantization
,
quantization
=
model
.
quantization
,
trust_remote_code
=
True
)
trust_remote_code
=
True
,
enable_chunked_prefill
=
True
)
output_tp1
=
do_sample
(
llm_tp1
,
tinyllama_lora_files
,
lora_id
=
1
)
output_tp1
=
do_sample
(
llm_tp1
,
tinyllama_lora_files
,
lora_id
=
1
)
del
llm_tp1
del
llm_tp1
...
@@ -189,7 +191,8 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
...
@@ -189,7 +191,8 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
max_loras
=
4
,
max_loras
=
4
,
tensor_parallel_size
=
2
,
tensor_parallel_size
=
2
,
gpu_memory_utilization
=
0.2
,
#avoid OOM
gpu_memory_utilization
=
0.2
,
#avoid OOM
quantization
=
model
.
quantization
)
quantization
=
model
.
quantization
,
enable_chunked_prefill
=
True
)
output_tp2
=
do_sample
(
llm_tp2
,
tinyllama_lora_files
,
lora_id
=
1
)
output_tp2
=
do_sample
(
llm_tp2
,
tinyllama_lora_files
,
lora_id
=
1
)
del
llm_tp2
del
llm_tp2
...
...
vllm/config.py
View file @
d5c5154f
...
@@ -1698,7 +1698,8 @@ class LoRAConfig:
...
@@ -1698,7 +1698,8 @@ class LoRAConfig:
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
# If the feature combo become valid
# If the feature combo become valid
if
scheduler_config
.
chunked_prefill_enabled
:
if
scheduler_config
.
chunked_prefill_enabled
:
raise
ValueError
(
"LoRA is not supported with chunked prefill yet."
)
logger
.
warning
(
"LoRA with chunked prefill is still experimental "
"and may be unstable."
)
@
dataclass
@
dataclass
...
...
vllm/core/scheduler.py
View file @
d5c5154f
...
@@ -166,9 +166,18 @@ class SchedulerOutputs:
...
@@ -166,9 +166,18 @@ class SchedulerOutputs:
and
not
self
.
blocks_to_swap_out
and
not
self
.
blocks_to_copy
)
and
not
self
.
blocks_to_swap_out
and
not
self
.
blocks_to_copy
)
def
_sort_by_lora_ids
(
self
):
def
_sort_by_lora_ids
(
self
):
self
.
scheduled_seq_groups
=
sorted
(
assert
0
<=
self
.
num_prefill_groups
<=
len
(
self
.
scheduled_seq_groups
)
self
.
scheduled_seq_groups
,
key
=
lambda
g
:
(
g
.
seq_group
.
lora_int_id
,
g
.
seq_group
.
request_id
))
def
key_fn
(
group
:
ScheduledSequenceGroup
):
key
=
(
group
.
seq_group
.
lora_int_id
,
group
.
seq_group
.
request_id
)
if
0
<
self
.
num_prefill_groups
<
len
(
self
.
scheduled_seq_groups
):
# Sort sequence groups so that all prefills come before all
# decodes as required by chunked prefill.
return
(
not
group
.
seq_group
.
is_prefill
(),
*
key
)
return
key
self
.
scheduled_seq_groups
=
sorted
(
self
.
scheduled_seq_groups
,
key
=
key_fn
)
@
property
@
property
def
lora_requests
(
self
)
->
Set
[
LoRARequest
]:
def
lora_requests
(
self
)
->
Set
[
LoRARequest
]:
...
...
vllm/worker/model_runner.py
View file @
d5c5154f
...
@@ -622,11 +622,13 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -622,11 +622,13 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data
.
lora_requests
.
add
(
seq_group_metadata
.
lora_request
)
inter_data
.
lora_requests
.
add
(
seq_group_metadata
.
lora_request
)
query_len
=
inter_data
.
query_lens
[
seq_idx
]
query_len
=
inter_data
.
query_lens
[
seq_idx
]
inter_data
.
lora_index_mapping
.
append
([
lora_id
]
*
query_len
)
inter_data
.
lora_index_mapping
.
append
([
lora_id
]
*
query_len
)
inter_data
.
lora_prompt_mapping
.
append
(
sampling_params
=
seq_group_metadata
.
sampling_params
[
lora_id
]
*
if
sampling_params
and
sampling_params
.
prompt_logprobs
is
not
None
:
(
query_len
if
seq_group_metadata
.
sampling_params
inter_data
.
lora_prompt_mapping
.
append
([
lora_id
]
*
query_len
)
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
is
not
None
elif
not
self
.
chunked_prefill_enabled
or
seq_group_metadata
.
do_sample
:
else
1
))
inter_data
.
lora_prompt_mapping
.
append
([
lora_id
])
else
:
inter_data
.
lora_prompt_mapping
.
append
([])
def
_compute_prompt_adapter_input
(
def
_compute_prompt_adapter_input
(
self
,
inter_data
:
InterDataForSeqGroup
,
self
,
inter_data
:
InterDataForSeqGroup
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment