Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
78687504
Unverified
Commit
78687504
authored
Jun 19, 2024
by
zifeitong
Committed by
GitHub
Jun 19, 2024
Browse files
[Bugfix] AsyncLLMEngine hangs with asyncio.run (#5654)
parent
d571ca01
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
271 additions
and
47 deletions
+271
-47
tests/async_engine/test_async_llm_engine.py
tests/async_engine/test_async_llm_engine.py
+37
-1
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+1
-42
tests/utils.py
tests/utils.py
+41
-2
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+3
-2
vllm/engine/async_timeout.py
vllm/engine/async_timeout.py
+189
-0
No files found.
tests/async_engine/test_async_llm_engine.py
View file @
78687504
...
...
@@ -2,8 +2,12 @@ import asyncio
from
dataclasses
import
dataclass
import
pytest
import
torch
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm
import
SamplingParams
from
vllm.engine.async_llm_engine
import
AsyncEngineArgs
,
AsyncLLMEngine
from
..utils
import
wait_for_gpu_memory_to_clear
@
dataclass
...
...
@@ -94,3 +98,35 @@ async def test_new_requests_event():
assert
engine
.
get_model_config
()
is
not
None
assert
engine
.
get_tokenizer
()
is
not
None
assert
engine
.
get_decoding_config
()
is
not
None
def
test_asyncio_run
():
wait_for_gpu_memory_to_clear
(
devices
=
list
(
range
(
torch
.
cuda
.
device_count
())),
threshold_bytes
=
2
*
2
**
30
,
timeout_s
=
60
,
)
engine
=
AsyncLLMEngine
.
from_engine_args
(
AsyncEngineArgs
(
model
=
"facebook/opt-125m"
))
async
def
run
(
prompt
:
str
):
sampling_params
=
SamplingParams
(
temperature
=
0
,
max_tokens
=
32
,
)
async
for
output
in
engine
.
generate
(
prompt
,
sampling_params
,
request_id
=
prompt
):
final_output
=
output
return
final_output
async
def
generate
():
return
await
asyncio
.
gather
(
run
(
"test0"
),
run
(
"test1"
),
)
results
=
asyncio
.
run
(
generate
())
assert
len
(
results
)
==
2
tests/spec_decode/e2e/conftest.py
View file @
78687504
import
asyncio
import
time
from
itertools
import
cycle
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
...
...
@@ -7,12 +6,6 @@ import pytest
import
ray
import
torch
from
vllm.utils
import
is_hip
if
(
not
is_hip
()):
from
pynvml
import
(
nvmlDeviceGetHandleByIndex
,
nvmlDeviceGetMemoryInfo
,
nvmlInit
)
from
vllm
import
LLM
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
...
...
@@ -26,6 +19,7 @@ from vllm.usage.usage_lib import UsageContext
from
vllm.utils
import
Counter
,
random_uuid
from
...conftest
import
cleanup
from
...utils
import
wait_for_gpu_memory_to_clear
class
AsyncLLM
:
...
...
@@ -291,38 +285,3 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
print
(
f
'
{
i
=
}
{
baseline_token_ids
=
}
'
)
print
(
f
'
{
i
=
}
{
spec_token_ids
=
}
'
)
assert
baseline_token_ids
==
spec_token_ids
def
wait_for_gpu_memory_to_clear
(
devices
:
List
[
int
],
threshold_bytes
:
int
,
timeout_s
:
float
=
120
)
->
None
:
# Use nvml instead of pytorch to reduce measurement error from torch cuda
# context.
nvmlInit
()
start_time
=
time
.
time
()
while
True
:
output
:
Dict
[
int
,
str
]
=
{}
output_raw
:
Dict
[
int
,
float
]
=
{}
for
device
in
devices
:
dev_handle
=
nvmlDeviceGetHandleByIndex
(
device
)
mem_info
=
nvmlDeviceGetMemoryInfo
(
dev_handle
)
gb_used
=
mem_info
.
used
/
2
**
30
output_raw
[
device
]
=
gb_used
output
[
device
]
=
f
'
{
gb_used
:.
02
f
}
'
print
(
'gpu memory used (GB): '
,
end
=
''
)
for
k
,
v
in
output
.
items
():
print
(
f
'
{
k
}
=
{
v
}
; '
,
end
=
''
)
print
(
''
)
dur_s
=
time
.
time
()
-
start_time
if
all
(
v
<=
(
threshold_bytes
/
2
**
30
)
for
v
in
output_raw
.
values
()):
print
(
f
'Done waiting for free GPU memory on devices
{
devices
=
}
'
f
'(
{
threshold_bytes
/
2
**
30
=
}
)
{
dur_s
=
:.
02
f
}
'
)
break
if
dur_s
>=
timeout_s
:
raise
ValueError
(
f
'Memory of devices
{
devices
=
}
not free after '
f
'
{
dur_s
=
:.
02
f
}
(
{
threshold_bytes
/
2
**
30
=
}
)'
)
time
.
sleep
(
5
)
tests/utils.py
View file @
78687504
...
...
@@ -4,7 +4,7 @@ import sys
import
time
import
warnings
from
contextlib
import
contextmanager
from
typing
import
List
from
typing
import
Dict
,
List
import
openai
import
ray
...
...
@@ -13,7 +13,11 @@ import requests
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.utils
import
get_open_port
from
vllm.utils
import
get_open_port
,
is_hip
if
(
not
is_hip
()):
from
pynvml
import
(
nvmlDeviceGetHandleByIndex
,
nvmlDeviceGetMemoryInfo
,
nvmlInit
)
# Path to root of repository so that utilities can be imported by ray workers
VLLM_PATH
=
os
.
path
.
abspath
(
os
.
path
.
join
(
__file__
,
os
.
pardir
,
os
.
pardir
))
...
...
@@ -154,3 +158,38 @@ def error_on_warning():
warnings
.
simplefilter
(
"error"
)
yield
def
wait_for_gpu_memory_to_clear
(
devices
:
List
[
int
],
threshold_bytes
:
int
,
timeout_s
:
float
=
120
)
->
None
:
# Use nvml instead of pytorch to reduce measurement error from torch cuda
# context.
nvmlInit
()
start_time
=
time
.
time
()
while
True
:
output
:
Dict
[
int
,
str
]
=
{}
output_raw
:
Dict
[
int
,
float
]
=
{}
for
device
in
devices
:
dev_handle
=
nvmlDeviceGetHandleByIndex
(
device
)
mem_info
=
nvmlDeviceGetMemoryInfo
(
dev_handle
)
gb_used
=
mem_info
.
used
/
2
**
30
output_raw
[
device
]
=
gb_used
output
[
device
]
=
f
'
{
gb_used
:.
02
f
}
'
print
(
'gpu memory used (GB): '
,
end
=
''
)
for
k
,
v
in
output
.
items
():
print
(
f
'
{
k
}
=
{
v
}
; '
,
end
=
''
)
print
(
''
)
dur_s
=
time
.
time
()
-
start_time
if
all
(
v
<=
(
threshold_bytes
/
2
**
30
)
for
v
in
output_raw
.
values
()):
print
(
f
'Done waiting for free GPU memory on devices
{
devices
=
}
'
f
'(
{
threshold_bytes
/
2
**
30
=
}
)
{
dur_s
=
:.
02
f
}
'
)
break
if
dur_s
>=
timeout_s
:
raise
ValueError
(
f
'Memory of devices
{
devices
=
}
not free after '
f
'
{
dur_s
=
:.
02
f
}
(
{
threshold_bytes
/
2
**
30
=
}
)'
)
time
.
sleep
(
5
)
vllm/engine/async_llm_engine.py
View file @
78687504
...
...
@@ -10,6 +10,7 @@ import vllm.envs as envs
from
vllm.config
import
DecodingConfig
,
ModelConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_timeout
import
asyncio_timeout
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.executor.ray_utils
import
initialize_ray_cluster
,
ray
from
vllm.inputs
import
LLMInputs
,
PromptInputs
...
...
@@ -545,8 +546,8 @@ class AsyncLLMEngine:
# Abort if iteration takes too long due to unrecoverable errors
# (eg. NCCL timeouts).
try
:
h
as
_requests_in_progress
=
awa
it
asyncio
.
wait_for
(
self
.
engine_step
(),
ENGINE_ITERATION_TIMEOUT_S
)
as
ync
w
it
h
asyncio
_timeout
(
ENGINE_ITERATION_TIMEOUT_S
):
has_requests_in_progress
=
await
self
.
engine_step
(
)
except
asyncio
.
TimeoutError
as
exc
:
logger
.
error
(
"Engine iteration timed out. This should never happen!"
)
...
...
vllm/engine/async_timeout.py
0 → 100644
View file @
78687504
# Workaround for https://github.com/python/cpython/issues/86296
#
# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
# Licensed under the Apache License (Apache-2.0)
import
asyncio
import
enum
import
sys
import
warnings
from
types
import
TracebackType
from
typing
import
Any
,
Optional
,
Type
if
sys
.
version_info
[:
2
]
>=
(
3
,
11
):
from
asyncio
import
timeout
as
asyncio_timeout
else
:
def
asyncio_timeout
(
delay
:
Optional
[
float
])
->
"Timeout"
:
"""timeout context manager.
Useful in cases when you want to apply timeout logic around block
of code or in cases when asyncio.wait_for is not suitable. For example:
>>> async with timeout(0.001):
... async with aiohttp.get('https://github.com') as r:
... await r.text()
delay - value in seconds or None to disable timeout logic
"""
loop
=
asyncio
.
get_running_loop
()
deadline
=
loop
.
time
()
+
delay
if
delay
is
not
None
else
None
return
Timeout
(
deadline
,
loop
)
class
_State
(
enum
.
Enum
):
INIT
=
"INIT"
ENTER
=
"ENTER"
TIMEOUT
=
"TIMEOUT"
EXIT
=
"EXIT"
class
Timeout
:
# Internal class, please don't instantiate it directly
# Use timeout() and timeout_at() public factories instead.
#
# Implementation note: `async with timeout()` is preferred
# over `with timeout()`.
# While technically the Timeout class implementation
# doesn't need to be async at all,
# the `async with` statement explicitly points that
# the context manager should be used from async function context.
#
# This design allows to avoid many silly misusages.
#
# TimeoutError is raised immediately when scheduled
# if the deadline is passed.
# The purpose is to time out as soon as possible
# without waiting for the next await expression.
__slots__
=
(
"_deadline"
,
"_loop"
,
"_state"
,
"_timeout_handler"
)
def
__init__
(
self
,
deadline
:
Optional
[
float
],
loop
:
asyncio
.
AbstractEventLoop
)
->
None
:
self
.
_loop
=
loop
self
.
_state
=
_State
.
INIT
self
.
_timeout_handler
=
None
# type: Optional[asyncio.Handle]
if
deadline
is
None
:
self
.
_deadline
=
None
# type: Optional[float]
else
:
self
.
update
(
deadline
)
def
__enter__
(
self
)
->
"Timeout"
:
warnings
.
warn
(
"with timeout() is deprecated, use async with timeout()"
,
DeprecationWarning
,
stacklevel
=
2
,
)
self
.
_do_enter
()
return
self
def
__exit__
(
self
,
exc_type
:
Optional
[
Type
[
BaseException
]],
exc_val
:
Optional
[
BaseException
],
exc_tb
:
Optional
[
TracebackType
],
)
->
Optional
[
bool
]:
self
.
_do_exit
(
exc_type
)
return
None
async
def
__aenter__
(
self
)
->
"Timeout"
:
self
.
_do_enter
()
return
self
async
def
__aexit__
(
self
,
exc_type
:
Optional
[
Type
[
BaseException
]],
exc_val
:
Optional
[
BaseException
],
exc_tb
:
Optional
[
TracebackType
],
)
->
Optional
[
bool
]:
self
.
_do_exit
(
exc_type
)
return
None
@
property
def
expired
(
self
)
->
bool
:
"""Is timeout expired during execution?"""
return
self
.
_state
==
_State
.
TIMEOUT
@
property
def
deadline
(
self
)
->
Optional
[
float
]:
return
self
.
_deadline
def
reject
(
self
)
->
None
:
"""Reject scheduled timeout if any."""
# cancel is maybe better name but
# task.cancel() raises CancelledError in asyncio world.
if
self
.
_state
not
in
(
_State
.
INIT
,
_State
.
ENTER
):
raise
RuntimeError
(
f
"invalid state
{
self
.
_state
.
value
}
"
)
self
.
_reject
()
def
_reject
(
self
)
->
None
:
if
self
.
_timeout_handler
is
not
None
:
self
.
_timeout_handler
.
cancel
()
self
.
_timeout_handler
=
None
def
shift
(
self
,
delay
:
float
)
->
None
:
"""Advance timeout on delay seconds.
The delay can be negative.
Raise RuntimeError if shift is called when deadline is not scheduled
"""
deadline
=
self
.
_deadline
if
deadline
is
None
:
raise
RuntimeError
(
"cannot shift timeout if deadline is not scheduled"
)
self
.
update
(
deadline
+
delay
)
def
update
(
self
,
deadline
:
float
)
->
None
:
"""Set deadline to absolute value.
deadline argument points on the time in the same clock system
as loop.time().
If new deadline is in the past the timeout is raised immediately.
Please note: it is not POSIX time but a time with
undefined starting base, e.g. the time of the system power on.
"""
if
self
.
_state
==
_State
.
EXIT
:
raise
RuntimeError
(
"cannot reschedule after exit from context manager"
)
if
self
.
_state
==
_State
.
TIMEOUT
:
raise
RuntimeError
(
"cannot reschedule expired timeout"
)
if
self
.
_timeout_handler
is
not
None
:
self
.
_timeout_handler
.
cancel
()
self
.
_deadline
=
deadline
if
self
.
_state
!=
_State
.
INIT
:
self
.
_reschedule
()
def
_reschedule
(
self
)
->
None
:
assert
self
.
_state
==
_State
.
ENTER
deadline
=
self
.
_deadline
if
deadline
is
None
:
return
now
=
self
.
_loop
.
time
()
if
self
.
_timeout_handler
is
not
None
:
self
.
_timeout_handler
.
cancel
()
task
=
asyncio
.
current_task
()
if
deadline
<=
now
:
self
.
_timeout_handler
=
self
.
_loop
.
call_soon
(
self
.
_on_timeout
,
task
)
else
:
self
.
_timeout_handler
=
self
.
_loop
.
call_at
(
deadline
,
self
.
_on_timeout
,
task
)
def
_do_enter
(
self
)
->
None
:
if
self
.
_state
!=
_State
.
INIT
:
raise
RuntimeError
(
f
"invalid state
{
self
.
_state
.
value
}
"
)
self
.
_state
=
_State
.
ENTER
self
.
_reschedule
()
def
_do_exit
(
self
,
exc_type
:
Optional
[
Type
[
BaseException
]])
->
None
:
if
exc_type
is
asyncio
.
CancelledError
and
\
self
.
_state
==
_State
.
TIMEOUT
:
self
.
_timeout_handler
=
None
raise
asyncio
.
TimeoutError
# timeout has not expired
self
.
_state
=
_State
.
EXIT
self
.
_reject
()
return
None
def
_on_timeout
(
self
,
task
:
"Optional[asyncio.Task[Any]]"
)
->
None
:
if
task
:
task
.
cancel
()
self
.
_state
=
_State
.
TIMEOUT
# drop the reference early
self
.
_timeout_handler
=
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment