Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c75363fb
Unverified
Commit
c75363fb
authored
Aug 21, 2024
by
Nick Hill
Committed by
GitHub
Aug 21, 2024
Browse files
[BugFix] Avoid premature async generator exit and raise all exception variations (#7698)
parent
dd3fa0e4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
101 additions
and
21 deletions
+101
-21
tests/async_engine/test_async_llm_engine.py
tests/async_engine/test_async_llm_engine.py
+88
-13
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+13
-8
No files found.
tests/async_engine/test_async_llm_engine.py
View file @
c75363fb
import
asyncio
import
os
from
asyncio
import
CancelledError
from
dataclasses
import
dataclass
from
typing
import
Optional
import
pytest
import
pytest_asyncio
import
torch
from
vllm
import
SamplingParams
from
vllm.config
import
ParallelConfig
from
vllm.engine.async_llm_engine
import
AsyncEngineArgs
,
AsyncLLMEngine
from
vllm.outputs
import
RequestOutput
as
RealRequestOutput
from
..conftest
import
cleanup
from
..utils
import
wait_for_gpu_memory_to_clear
...
...
@@ -118,15 +123,38 @@ async def test_new_requests_event():
os
.
environ
.
pop
(
"VLLM_ALLOW_ENGINE_USE_RAY"
)
def
test_asyncio_run
():
def start_engine():
    """Create the ``AsyncLLMEngine`` shared by the tests in this module.

    Calls ``wait_for_gpu_memory_to_clear`` for every CUDA device reported by
    ``torch.cuda.device_count()`` with a 2 GiB threshold and a 60 second
    timeout, then builds one engine for ``facebook/opt-125m`` with
    ``enforce_eager=True``.

    Returns:
        The engine produced by ``AsyncLLMEngine.from_engine_args``.
    """
    wait_for_gpu_memory_to_clear(
        devices=list(range(torch.cuda.device_count())),
        threshold_bytes=2 * 2**30,
        timeout_s=60,
    )

    # Construct exactly one engine. Building an extra engine into an unused
    # local would load the model a second time and keep it alive for nothing.
    engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                                  enforce_eager=True)
    return AsyncLLMEngine.from_engine_args(engine_args)
@pytest_asyncio.fixture(scope="module")
async def async_engine():
    """Module-scoped fixture that yields the engine built by ``start_engine``.

    ``start_engine`` is run through the event loop's ``run_in_executor`` with
    ``executor=None`` rather than being called inline. On teardown the
    engine's background loop is shut down, the local reference is dropped,
    the coroutine sleeps for 0.1 s and ``cleanup()`` is called.
    """
    loop = asyncio.get_event_loop()
    llm_engine = await loop.run_in_executor(executor=None, func=start_engine)
    try:
        yield llm_engine
    finally:
        llm_engine.shutdown_background_loop()
        # Drop this fixture's reference before the cleanup pass below.
        del llm_engine
        await asyncio.sleep(0.1)
        cleanup()
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Fixture override that always returns ``False``.

    Returning ``False`` lets the module-scoped async engine fixture be shared
    between the tests in this module.
    """
    # So we can share the async engine fixture between these tests
    return False
@
pytest
.
mark
.
asyncio
(
scope
=
"module"
)
async
def
test_asyncio_run
(
async_engine
):
async
def
run
(
prompt
:
str
):
sampling_params
=
SamplingParams
(
...
...
@@ -134,17 +162,64 @@ def test_asyncio_run():
max_tokens
=
32
,
)
async
for
output
in
engine
.
generate
(
prompt
,
async
for
output
in
async_
engine
.
generate
(
prompt
,
sampling_params
,
request_id
=
prompt
):
final_output
=
output
return
final_output
async
def
generate
():
return
await
asyncio
.
gather
(
results
=
await
asyncio
.
gather
(
run
(
"test0"
),
run
(
"test1"
),
)
results
=
asyncio
.
run
(
generate
())
assert
len
(
results
)
==
2
@pytest.mark.asyncio(scope="module")
async def test_cancellation(async_engine):
    """Aborting a request mid-stream surfaces ``CancelledError`` to the consumer.

    The request asks for exactly 10 tokens (``min_tokens == max_tokens``).
    After the fifth output the request is aborted; the iteration must then
    raise ``CancelledError`` and no further outputs may be counted.
    """
    params = SamplingParams(
        temperature=0,
        min_tokens=10,
        max_tokens=10,
    )

    received = 0
    with pytest.raises(CancelledError):
        outputs = async_engine.generate("test2", params, request_id="test2")
        async for output in outputs:
            assert not output.finished
            received += 1
            if received == 5:
                await async_engine.abort("test2")

    # Exactly five outputs were seen before the abort took effect.
    assert received == 5
@pytest.mark.asyncio(scope="module")
async def test_delayed_generator(async_engine):
    """A slow consumer still receives every output of a request.

    The request asks for exactly 10 tokens. After the first output the
    consumer sleeps for one second before reading the rest, and the test
    checks that all 10 outputs arrive and only the last is marked finished.
    """
    params = SamplingParams(
        temperature=0,
        min_tokens=10,
        max_tokens=10,
    )
    stream = async_engine.generate("test3", params, request_id="test3")

    consumed = 0
    last_output: Optional[RealRequestOutput] = None
    async for output in stream:
        last_output = output
        if consumed == 0:
            # wait for generation to complete before consuming
            # the remaining messages
            await asyncio.sleep(1)
        if consumed < 9:
            # Only the tenth output may carry the finished flag.
            assert not output.finished
        consumed += 1

    assert consumed == 10
    assert last_output is not None
    assert len(last_output.outputs[0].token_ids) == 10
    assert last_output.finished
vllm/engine/async_llm_engine.py
View file @
c75363fb
...
...
@@ -2,8 +2,8 @@ import asyncio
import
time
from
dataclasses
import
dataclass
from
functools
import
partial
from
typing
import
(
AsyncGenerator
,
Callable
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
from
typing
import
(
Any
,
AsyncGenerator
,
Callable
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
import
torch
from
typing_extensions
import
assert_never
...
...
@@ -85,8 +85,7 @@ class AsyncStream:
def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                          Exception]) -> None:
    """Enqueue ``item`` on the stream's queue unless it is already finished.

    Args:
        item: An output object or an exception to hand to the consumer.
    """
    if self._finished:
        # Items offered after the stream finished are dropped.
        return
    self._queue.put_nowait(item)
def
finish
(
...
...
@@ -96,7 +95,7 @@ class AsyncStream:
if
not
self
.
_finished
:
self
.
_finished
=
True
self
.
_queue
.
put_nowait
(
exception
if
exception
is
not
None
else
STOP_ITERATION
)
exception
if
self
.
_is_raisable
(
exception
)
else
STOP_ITERATION
)
@
property
def
finished
(
self
)
->
bool
:
...
...
@@ -106,9 +105,9 @@ class AsyncStream:
self
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]:
try
:
while
not
self
.
_finished
:
while
True
:
result
=
await
self
.
_queue
.
get
()
if
isinstance
(
result
,
Exception
):
if
self
.
_is_raisable
(
result
):
if
result
==
STOP_ITERATION
:
return
raise
result
...
...
@@ -117,6 +116,12 @@ class AsyncStream:
self
.
_cancel
(
self
.
request_id
)
raise
asyncio
.
CancelledError
from
None
@staticmethod
def _is_raisable(value: Any):
    """Return whether ``value`` is an exception instance or exception class.

    Args:
        value: Any object.

    Returns:
        ``True`` when ``value`` is an instance of ``BaseException`` or a
        class deriving from ``BaseException``; ``False`` otherwise.
    """
    if isinstance(value, BaseException):
        return True
    # issubclass() requires a class, so confirm that before calling it.
    return isinstance(value, type) and issubclass(value, BaseException)
class
RequestTracker
:
"""Synchronous abstraction for tracking requests."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment