Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c75363fb
Unverified
Commit
c75363fb
authored
Aug 21, 2024
by
Nick Hill
Committed by
GitHub
Aug 21, 2024
Browse files
[BugFix] Avoid premature async generator exit and raise all exception variations (#7698)
parent
dd3fa0e4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
101 additions
and
21 deletions
+101
-21
tests/async_engine/test_async_llm_engine.py
tests/async_engine/test_async_llm_engine.py
+88
-13
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+13
-8
No files found.
tests/async_engine/test_async_llm_engine.py
View file @
c75363fb
import
asyncio
import
asyncio
import
os
import
os
from
asyncio
import
CancelledError
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
import
pytest
import
pytest
import
pytest_asyncio
import
torch
import
torch
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.config
import
ParallelConfig
from
vllm.config
import
ParallelConfig
from
vllm.engine.async_llm_engine
import
AsyncEngineArgs
,
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncEngineArgs
,
AsyncLLMEngine
from
vllm.outputs
import
RequestOutput
as
RealRequestOutput
from
..conftest
import
cleanup
from
..utils
import
wait_for_gpu_memory_to_clear
from
..utils
import
wait_for_gpu_memory_to_clear
...
@@ -118,15 +123,38 @@ async def test_new_requests_event():
...
@@ -118,15 +123,38 @@ async def test_new_requests_event():
os
.
environ
.
pop
(
"VLLM_ALLOW_ENGINE_USE_RAY"
)
os
.
environ
.
pop
(
"VLLM_ALLOW_ENGINE_USE_RAY"
)
def
test_asyncio_run
():
def
start_engine
():
wait_for_gpu_memory_to_clear
(
wait_for_gpu_memory_to_clear
(
devices
=
list
(
range
(
torch
.
cuda
.
device_count
())),
devices
=
list
(
range
(
torch
.
cuda
.
device_count
())),
threshold_bytes
=
2
*
2
**
30
,
threshold_bytes
=
2
*
2
**
30
,
timeout_s
=
60
,
timeout_s
=
60
,
)
)
engine
=
AsyncLLMEngine
.
from_engine_args
(
return
AsyncLLMEngine
.
from_engine_args
(
AsyncEngineArgs
(
model
=
"facebook/opt-125m"
))
AsyncEngineArgs
(
model
=
"facebook/opt-125m"
,
enforce_eager
=
True
))
@
pytest_asyncio
.
fixture
(
scope
=
"module"
)
async
def
async_engine
():
engine
=
await
asyncio
.
get_event_loop
().
run_in_executor
(
executor
=
None
,
func
=
start_engine
)
try
:
yield
engine
finally
:
engine
.
shutdown_background_loop
()
del
engine
await
asyncio
.
sleep
(
0.1
)
cleanup
()
@
pytest
.
fixture
()
def
should_do_global_cleanup_after_test
(
request
)
->
bool
:
# So we can share the async engine fixture between these tests
return
False
@
pytest
.
mark
.
asyncio
(
scope
=
"module"
)
async
def
test_asyncio_run
(
async_engine
):
async
def
run
(
prompt
:
str
):
async
def
run
(
prompt
:
str
):
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
...
@@ -134,17 +162,64 @@ def test_asyncio_run():
...
@@ -134,17 +162,64 @@ def test_asyncio_run():
max_tokens
=
32
,
max_tokens
=
32
,
)
)
async
for
output
in
engine
.
generate
(
prompt
,
async
for
output
in
async_
engine
.
generate
(
prompt
,
sampling_params
,
sampling_params
,
request_id
=
prompt
):
request_id
=
prompt
):
final_output
=
output
final_output
=
output
return
final_output
return
final_output
async
def
generate
():
results
=
await
asyncio
.
gather
(
return
await
asyncio
.
gather
(
run
(
"test0"
),
run
(
"test0"
),
run
(
"test1"
),
run
(
"test1"
),
)
)
results
=
asyncio
.
run
(
generate
())
assert
len
(
results
)
==
2
assert
len
(
results
)
==
2
@
pytest
.
mark
.
asyncio
(
scope
=
"module"
)
async
def
test_cancellation
(
async_engine
):
sampling_params
=
SamplingParams
(
temperature
=
0
,
min_tokens
=
10
,
max_tokens
=
10
,
)
i
=
0
with
pytest
.
raises
(
CancelledError
):
async
for
output
in
async_engine
.
generate
(
"test2"
,
sampling_params
,
request_id
=
"test2"
):
assert
not
output
.
finished
i
+=
1
if
i
==
5
:
await
async_engine
.
abort
(
"test2"
)
assert
i
==
5
@
pytest
.
mark
.
asyncio
(
scope
=
"module"
)
async
def
test_delayed_generator
(
async_engine
):
sampling_params
=
SamplingParams
(
temperature
=
0
,
min_tokens
=
10
,
max_tokens
=
10
,
)
stream
=
async_engine
.
generate
(
"test3"
,
sampling_params
,
request_id
=
"test3"
)
i
=
0
final_output
:
Optional
[
RealRequestOutput
]
=
None
async
for
output
in
stream
:
final_output
=
output
if
i
==
0
:
# wait for generation to complete before consuming
# the remaining messages
await
asyncio
.
sleep
(
1
)
if
i
<
9
:
assert
not
output
.
finished
i
+=
1
assert
i
==
10
assert
final_output
is
not
None
assert
len
(
final_output
.
outputs
[
0
].
token_ids
)
==
10
assert
final_output
.
finished
vllm/engine/async_llm_engine.py
View file @
c75363fb
...
@@ -2,8 +2,8 @@ import asyncio
...
@@ -2,8 +2,8 @@ import asyncio
import
time
import
time
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
AsyncGenerator
,
Callable
,
Dict
,
Iterable
,
List
,
Mapping
,
from
typing
import
(
Any
,
AsyncGenerator
,
Callable
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
import
torch
import
torch
from
typing_extensions
import
assert_never
from
typing_extensions
import
assert_never
...
@@ -85,9 +85,8 @@ class AsyncStream:
...
@@ -85,9 +85,8 @@ class AsyncStream:
def
put
(
self
,
item
:
Union
[
RequestOutput
,
EmbeddingRequestOutput
,
def
put
(
self
,
item
:
Union
[
RequestOutput
,
EmbeddingRequestOutput
,
Exception
])
->
None
:
Exception
])
->
None
:
if
self
.
_finished
:
if
not
self
.
_finished
:
return
self
.
_queue
.
put_nowait
(
item
)
self
.
_queue
.
put_nowait
(
item
)
def
finish
(
def
finish
(
self
,
self
,
...
@@ -96,7 +95,7 @@ class AsyncStream:
...
@@ -96,7 +95,7 @@ class AsyncStream:
if
not
self
.
_finished
:
if
not
self
.
_finished
:
self
.
_finished
=
True
self
.
_finished
=
True
self
.
_queue
.
put_nowait
(
self
.
_queue
.
put_nowait
(
exception
if
exception
is
not
None
else
STOP_ITERATION
)
exception
if
self
.
_is_raisable
(
exception
)
else
STOP_ITERATION
)
@
property
@
property
def
finished
(
self
)
->
bool
:
def
finished
(
self
)
->
bool
:
...
@@ -106,9 +105,9 @@ class AsyncStream:
...
@@ -106,9 +105,9 @@ class AsyncStream:
self
self
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]:
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]:
try
:
try
:
while
not
self
.
_finished
:
while
True
:
result
=
await
self
.
_queue
.
get
()
result
=
await
self
.
_queue
.
get
()
if
isinstance
(
result
,
Exception
):
if
self
.
_is_raisable
(
result
):
if
result
==
STOP_ITERATION
:
if
result
==
STOP_ITERATION
:
return
return
raise
result
raise
result
...
@@ -117,6 +116,12 @@ class AsyncStream:
...
@@ -117,6 +116,12 @@ class AsyncStream:
self
.
_cancel
(
self
.
request_id
)
self
.
_cancel
(
self
.
request_id
)
raise
asyncio
.
CancelledError
from
None
raise
asyncio
.
CancelledError
from
None
@
staticmethod
def
_is_raisable
(
value
:
Any
):
return
isinstance
(
value
,
BaseException
)
or
\
(
isinstance
(
value
,
type
)
and
\
issubclass
(
value
,
BaseException
))
class
RequestTracker
:
class
RequestTracker
:
"""Synchronous abstraction for tracking requests."""
"""Synchronous abstraction for tracking requests."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment