Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4bb53e2d
Unverified
Commit
4bb53e2d
authored
May 01, 2024
by
leiwen83
Committed by
GitHub
Apr 30, 2024
Browse files
[BugFix] fix num_lookahead_slots missing in async executor (#4165)
Co-authored-by:
Lei Wen
<
wenlei03@qiyi.com
>
parent
26f2fb51
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
163 additions
and
19 deletions
+163
-19
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+123
-2
tests/spec_decode/e2e/test_compatibility.py
tests/spec_decode/e2e/test_compatibility.py
+11
-4
tests/spec_decode/e2e/test_correctness.py
tests/spec_decode/e2e/test_correctness.py
+16
-9
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+4
-2
vllm/executor/cpu_executor.py
vllm/executor/cpu_executor.py
+3
-1
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+1
-0
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+3
-1
vllm/executor/neuron_executor.py
vllm/executor/neuron_executor.py
+1
-0
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+1
-0
No files found.
tests/spec_decode/e2e/conftest.py
View file @
4bb53e2d
from
typing
import
List
,
Tuple
import
asyncio
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
pytest
import
ray
from
tests.conftest
import
cleanup
from
vllm
import
LLM
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
MultiModalData
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Counter
,
random_uuid
class AsyncLLM:
    """Synchronous, test-only facade over ``AsyncLLMEngine``.

    Note: the current ``LLM`` class in vLLM does not support async mode, so
    an async-backed variant is implemented here for test purposes. It may
    move to vllm/entrypoints/llm.py in the future.

    This class is borrowed directly from vllm/entrypoints/llm.py, with the
    changes needed to make it work in async mode.
    """

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        enforce_eager: bool = False,
        max_context_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ) -> None:
        """Record the engine arguments; the engine itself is built lazily.

        All named parameters are forwarded unchanged to ``AsyncEngineArgs``.
        Any extra ``kwargs`` are forwarded as well, with
        ``disable_log_stats`` defaulting to ``True`` when the caller does not
        set it. ``engine_use_ray`` is always forced to ``True``.
        """
        # Stats logging is off unless the caller explicitly asks for it.
        kwargs.setdefault("disable_log_stats", True)
        self.engine_args = AsyncEngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            engine_use_ray=True,
            disable_custom_all_reduce=disable_custom_all_reduce,
            **kwargs,
        )
        self.request_counter = Counter()

    def generate(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:
        """Generate one final output per prompt, one request at a time.

        A fresh ``AsyncLLMEngine`` is created on every call, and
        ``ray.shutdown()`` is always invoked before returning or raising.

        Args:
            prompts: A single prompt or a list of prompts. Required.
            sampling_params: One ``SamplingParams`` shared by all prompts, or
                a list with exactly one entry per prompt. Defaults to
                ``SamplingParams()``.
            prompt_token_ids: Accepted for signature parity with ``LLM``;
                currently not forwarded to the engine.
            use_tqdm: Accepted for signature parity; currently unused.
            lora_request: Accepted for signature parity; currently not
                forwarded to the engine.
            multi_modal_data: Accepted for signature parity; currently not
                forwarded to the engine.

        Returns:
            The last ``RequestOutput`` streamed by the engine for each
            prompt, in prompt order.

        Raises:
            ValueError: If ``prompts`` is ``None``, or if ``sampling_params``
                is a list whose length differs from the number of prompts.
        """
        # Validate before touching the engine so that bad arguments never
        # start (and then leak) a Ray-backed engine.
        if prompts is None:
            raise ValueError("prompts must be provided.")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]
        num_requests = len(prompts)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()
        if isinstance(sampling_params, list):
            if len(sampling_params) != num_requests:
                raise ValueError("The lengths of prompts and "
                                 "sampling_params must be the same.")
            # One entry per prompt: each request gets its own params.
            per_prompt_params = sampling_params
        else:
            per_prompt_params = [sampling_params] * num_requests

        outputs: List[RequestOutput] = []
        try:
            # Built inside the try so a failure during engine start-up still
            # reaches ray.shutdown() below.
            llm_engine = AsyncLLMEngine.from_engine_args(
                self.engine_args, usage_context=UsageContext.LLM_CLASS)

            async def get_output(prompt,
                                 sampling_param) -> Optional[RequestOutput]:
                """Drain the engine's stream and keep only the last item."""
                request_id = random_uuid()
                results_generator = llm_engine.generate(
                    prompt, sampling_param, request_id)
                final_output = None
                async for request_output in results_generator:
                    final_output = request_output
                return final_output

            # Requests are run sequentially, each in its own event loop.
            for prompt, sampling_param in zip(prompts, per_prompt_params):
                outputs.append(
                    asyncio.run(get_output(prompt, sampling_param)))
        finally:
            # Always tear Ray down, success or failure.
            ray.shutdown()
        return outputs
@
pytest
.
fixture
...
...
@@ -36,8 +153,12 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
def
generator_inner
():
print
(
f
'Creating
{
baseline_or_test
=
}
LLM for
{
test_name
=
}
.
{
kwargs
=
}
'
)
llm
=
LLM
(
**
kwargs
)
use_async
=
False
if
"use_async"
in
kwargs
:
use_async
=
kwargs
.
pop
(
"use_async"
)
llm
=
AsyncLLM
(
**
kwargs
)
if
use_async
else
LLM
(
**
kwargs
)
set_random_seed
(
seed
)
yield
llm
...
...
tests/spec_decode/e2e/test_compatibility.py
View file @
4bb53e2d
...
...
@@ -42,10 +42,17 @@ def test_spec_decode_xfail_ray(test_llm_generator):
temperature
=
temperature
,
)
with
pytest
.
raises
(
AssertionError
,
try
:
with
pytest
.
raises
(
AssertionError
,
match
=
"Speculative decoding not yet supported for "
):
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
finally
:
# We need to free up Ray resources here
# so that later tests can use the GPU we allocated.
import
ray
ray
.
shutdown
()
@
pytest
.
mark
.
parametrize
(
...
...
tests/spec_decode/e2e/test_correctness.py
View file @
4bb53e2d
...
...
@@ -40,7 +40,8 @@ from .conftest import get_output_from_llm_generator
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
[
{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model"
:
"JackFram/llama-68m"
,
...
...
@@ -49,8 +50,14 @@ from .conftest import get_output_from_llm_generator
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
"use_v2_block_manager"
:
True
,
# Whether to use the AsyncLLM engine.
"use_async"
:
async_mode
,
}
# Try both async and sync engine execution
for
async_mode
in
[
True
,
False
]
])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
...
...
vllm/engine/async_llm_engine.py
View file @
4bb53e2d
...
...
@@ -211,9 +211,11 @@ class _AsyncLLMEngine(LLMEngine):
if
not
scheduler_outputs
.
is_empty
():
# Execute the model.
output
=
await
self
.
model_executor
.
execute_model_async
(
seq_group_metadata_list
,
scheduler_outputs
.
blocks_to_swap_in
,
seq_group_metadata_list
,
scheduler_outputs
.
blocks_to_swap_in
,
scheduler_outputs
.
blocks_to_swap_out
,
scheduler_outputs
.
blocks_to_copy
)
scheduler_outputs
.
blocks_to_copy
,
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
)
else
:
output
=
[]
...
...
vllm/executor/cpu_executor.py
View file @
4bb53e2d
...
...
@@ -109,12 +109,14 @@ class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
,
)
->
List
[
SamplerOutput
]:
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
)
blocks_to_copy
=
blocks_to_copy
,
num_lookahead_slots
=
num_lookahead_slots
)
return
output
async
def
check_health_async
(
self
)
->
None
:
...
...
vllm/executor/executor_base.py
View file @
4bb53e2d
...
...
@@ -112,6 +112,7 @@ class ExecutorAsyncBase(ExecutorBase):
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
,
)
->
List
[
SamplerOutput
]:
"""Executes one model step on the given sequences."""
raise
NotImplementedError
...
...
vllm/executor/gpu_executor.py
View file @
4bb53e2d
...
...
@@ -163,10 +163,12 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
,
)
->
List
[
SamplerOutput
]:
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
)
blocks_to_copy
=
blocks_to_copy
,
num_lookahead_slots
=
num_lookahead_slots
)
return
output
vllm/executor/neuron_executor.py
View file @
4bb53e2d
...
...
@@ -84,6 +84,7 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
,
)
->
List
[
SamplerOutput
]:
output
=
await
make_async
(
self
.
driver_worker
.
execute_model
)(
seq_group_metadata_list
=
seq_group_metadata_list
,
)
...
...
vllm/executor/ray_gpu_executor.py
View file @
4bb53e2d
...
...
@@ -196,6 +196,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
"blocks_to_swap_in"
:
blocks_to_swap_in
,
"blocks_to_swap_out"
:
blocks_to_swap_out
,
"blocks_to_copy"
:
blocks_to_copy
,
"num_lookahead_slots"
:
num_lookahead_slots
,
},
use_ray_compiled_dag
=
USE_RAY_COMPILED_DAG
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment