Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
67882dbb
Unverified
Commit
67882dbb
authored
Jun 25, 2024
by
Antoni Baum
Committed by
GitHub
Jun 25, 2024
Browse files
[Core] Add fault tolerance for `RayTokenizerGroupPool` (#5748)
parent
7b993143
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
195 additions
and
24 deletions
+195
-24
tests/tokenization/test_tokenizer_group.py
tests/tokenization/test_tokenizer_group.py
+99
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+2
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+2
-0
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
...ransformers_utils/tokenizer_group/base_tokenizer_group.py
+4
-0
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
...transformers_utils/tokenizer_group/ray_tokenizer_group.py
+88
-24
No files found.
tests/tokenization/test_tokenizer_group.py
View file @
67882dbb
import
asyncio
import
os
import
sys
from
typing
import
List
,
Optional
from
unittest.mock
import
patch
import
pytest
...
...
@@ -100,3 +102,100 @@ async def test_tokenizer_group_ray_pool_env_var_propagation(
max_num_seqs
=
1
,
max_input_length
=
None
)
tokenizer_pool
.
ping
()
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
    """Test that Ray tokenizer pool group can recover from failures and
    if that's not possible, mark itself as unhealthy.

    Three scenarios are exercised in order:
      1. the worker fails once and the pool is expected to recover,
      2. the worker keeps failing and the pool is expected to raise
         RuntimeError from both encode_async and check_health,
      3. a ValueError from encoding is expected to propagate while the
         actor list stays unchanged.
    """

    class FailingTokenizerGroup(TokenizerGroup):
        # TokenizerGroup variant that exits the process on chosen encode calls.

        def __init__(self,
                     *args,
                     fail_at: Optional[List[int]] = None,
                     **kwargs):
            super().__init__(*args, **kwargs)
            # 1-based count of encode() calls made on this instance.
            self.i = 0
            # Call numbers on which encode() should exit instead of encoding.
            self.fail_at = fail_at or []

        def encode(self, *args, **kwargs):
            self.i += 1
            if self.i in self.fail_at:
                # NOTE(review): presumably this terminates the Ray actor
                # process so the pool sees an ActorDiedError - confirm.
                sys.exit(1)
            return super().encode(*args, **kwargs)

    class FailingRayTokenizerGroupPool(RayTokenizerGroupPool):
        # Make the pool build its workers from the failing class above.
        _worker_cls = FailingTokenizerGroup

    # Fail at first iteration
    fail_at = [1]
    tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
        tokenizer_pool_config,
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
        fail_at=fail_at)
    # Snapshot of the actor list, compared against after the recovery.
    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()

    # Modify fail at to not fail at all (will be re-read when actor is
    # re-initialized).
    fail_at[0] = 1000

    # We should recover successfully.
    await tokenizer_group_pool.encode_async(request_id="1",
                                            prompt="prompt",
                                            lora_request=None)
    await tokenizer_group_pool.encode_async(request_id="1",
                                            prompt="prompt",
                                            lora_request=None)

    # Check that we have a new actor
    assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
    assert tokenizer_group_pool.tokenizer_actors != tokenizer_actors

    # Fail at first iteration
    # This time fail_at is left untouched after construction, so a
    # re-created worker receives the same failing configuration.
    fail_at = [1]
    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
        tokenizer_pool_config,
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
        fail_at=fail_at)

    # We should fail after re-initialization.
    with pytest.raises(RuntimeError):
        await tokenizer_group_pool.encode_async(request_id="1",
                                                prompt="prompt",
                                                lora_request=None)

    # check_health should raise the same thing
    with pytest.raises(RuntimeError):
        tokenizer_group_pool.check_health()

    # Ensure that non-ActorDiedErrors are still propagated correctly and do not
    # cause a re-initialization.
    fail_at = []
    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
        tokenizer_pool_config,
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=2,
        fail_at=fail_at)
    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()

    # Prompt too long error
    with pytest.raises(ValueError):
        await tokenizer_group_pool.encode_async(request_id="1",
                                                prompt="prompt" * 100,
                                                lora_request=None)
    # A short prompt must still encode after the ValueError above.
    await tokenizer_group_pool.encode_async(request_id="1",
                                            prompt="prompt",
                                            lora_request=None)
    # Actors should stay the same.
    assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors
vllm/engine/async_llm_engine.py
View file @
67882dbb
...
...
@@ -310,6 +310,8 @@ class _AsyncLLMEngine(LLMEngine):
)
async def check_health_async(self) -> None:
    """Run the tokenizer health check (when a tokenizer is set), then the
    model executor health check.

    Any exception raised by either check propagates to the caller.
    """
    tokenizer = self.tokenizer
    if tokenizer:
        tokenizer.check_health()

    executor = self.model_executor
    executor.check_health()
...
...
vllm/engine/llm_engine.py
View file @
67882dbb
...
...
@@ -1013,6 +1013,8 @@ class LLMEngine:
return
self
.
model_executor
.
pin_lora
(
lora_id
)
def check_health(self) -> None:
    """Run the tokenizer health check (when a tokenizer is set), then the
    model executor health check.

    Any exception raised by either check propagates to the caller.
    """
    tokenizer = self.tokenizer
    if tokenizer:
        tokenizer.check_health()

    executor = self.model_executor
    executor.check_health()
def
is_tracing_enabled
(
self
)
->
bool
:
...
...
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
View file @
67882dbb
...
...
@@ -53,3 +53,7 @@ class BaseTokenizerGroup(ABC):
)
->
"PreTrainedTokenizer"
:
"""Get a tokenizer for a LoRA request."""
pass
def check_health(self):
    """Raise an exception if the tokenizer group is unhealthy.

    This default implementation performs no check and returns None.
    """
    return None
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
View file @
67882dbb
...
...
@@ -2,17 +2,21 @@ import asyncio
import
os
from
typing
import
List
,
Optional
from
ray.exceptions
import
ActorDiedError
from
ray.util.scheduling_strategies
import
NodeAffinitySchedulingStrategy
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
TokenizerPoolConfig
from
vllm.executor.ray_utils
import
ray
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.tokenizer_group.base_tokenizer_group
import
(
BaseTokenizerGroup
)
from
vllm.transformers_utils.tokenizer_group.tokenizer_group
import
(
TokenizerGroup
)
logger
=
init_logger
(
__name__
)
class
RayTokenizerGroupPool
(
BaseTokenizerGroup
):
"""A Ray-based pool of TokenizerGroups for async tokenization."""
...
...
@@ -46,24 +50,28 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
ray_actor_options
:
dict
,
**
tokenizer_config
):
# Store a local copy of the TokenizerGroup for quick access
# to underlying HF tokenizers.
self
.
_tokenizer_config
=
{
"tokenizer_id"
:
tokenizer_id
,
"enable_lora"
:
enable_lora
,
"max_num_seqs"
:
max_num_seqs
,
"max_input_length"
:
max_input_length
,
**
tokenizer_config
}
self
.
_local_tokenizer_group
=
self
.
_worker_cls
(
tokenizer_id
=
tokenizer_id
,
enable_lora
=
enable_lora
,
max_num_seqs
=
max_num_seqs
,
max_input_length
=
max_input_length
,
**
tokenizer_config
,
)
ray_tokenizer_group_cls
=
ray
.
remote
(
**
self
.
_tokenizer_config
,
)
self
.
_ray_tokenizer_group_cls
=
ray
.
remote
(
self
.
_worker_cls
).
options
(
**
ray_actor_options
)
self
.
tokenizer_actors
=
[
ray_tokenizer_group_cls
.
remote
(
tokenizer_id
,
enable_lora
,
max_num_seqs
,
max_input_length
,
**
tokenizer_config
)
for
_
in
range
(
num_actors
)
]
self
.
tokenizer_actors
=
[
self
.
_init_actor
()
for
_
in
range
(
num_actors
)]
self
.
_idle_actors
:
Optional
[
asyncio
.
Queue
]
=
None
# If set, actor is unhealthy. Will reraise on the next
# check_health call.
self
.
_exception
:
Optional
[
ActorDiedError
]
=
None
def _init_actor(self) -> ray.ObjectRef:
    """Create one tokenizer actor from the stored configuration.

    Calls ``.remote(...)`` on the stored Ray actor class, passing the
    saved tokenizer config as keyword arguments, and returns the result.
    """
    actor_cls = self._ray_tokenizer_group_cls
    config = self._tokenizer_config
    return actor_cls.remote(**config)
@property
def pool_size(self) -> int:
    """Number of entries currently in ``tokenizer_actors``."""
    actors = self.tokenizer_actors
    return len(actors)
...
...
@@ -78,6 +86,22 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
for
actor
in
self
.
tokenizer_actors
:
self
.
_idle_actors
.
put_nowait
(
actor
)
def _finalize_encode(self, actor: ray.ObjectRef,
                     original_actor: ray.ObjectRef, actor_is_alive: bool):
    """Update pool bookkeeping after an encode attempt.

    Args:
        actor: The actor that handled (or last attempted) the request.
        original_actor: The actor initially taken from the idle queue.
        actor_is_alive: False when ``actor`` must not be reused.

    ``original_actor`` is dropped from ``tokenizer_actors`` when it was
    replaced or when the final actor is dead. A live ``actor`` is put
    back on the idle queue and, if it is a replacement, added to
    ``tokenizer_actors``.
    """
    assert self._idle_actors is not None

    was_replaced = original_actor is not actor

    # Cleanup the dead actor.
    if was_replaced or not actor_is_alive:
        self.tokenizer_actors.remove(original_actor)

    # A dead actor is neither re-queued nor tracked.
    if not actor_is_alive:
        return

    # Return the live actor to the idle queue. The callers invoke this
    # helper from a finally block, so this happens even when an
    # exception/cancellation is raised.
    self._idle_actors.put_nowait(actor)

    # Add back the new actor.
    if was_replaced:
        self.tokenizer_actors.append(actor)
def
encode
(
self
,
prompt
:
str
,
request_id
:
Optional
[
str
]
=
None
,
...
...
@@ -88,23 +112,41 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
The actor is then put back in the queue for future use.
This is blocking.
"""
self
.
check_health
()
self
.
_ensure_queue_initialized
()
assert
self
.
_idle_actors
is
not
None
if
self
.
_idle_actors
.
empty
():
raise
RuntimeError
(
"No idle actors available."
)
actor
=
self
.
_idle_actors
.
get_nowait
()
actor_is_alive
=
True
original_actor
=
actor
try
:
ret
=
ray
.
get
(
actor
.
encode
.
remote
(
request_id
=
request_id
,
prompt
=
prompt
,
lora_request
=
lora_request
))
except
ActorDiedError
as
e
:
# If the actor is dead, we first try to reinitialize it.
logger
.
warning
(
"%s died with ActorDiedError, reinitializing."
,
actor
,
exc_info
=
e
)
actor
=
self
.
_init_actor
()
try
:
ret
=
ray
.
get
(
actor
.
encode
.
remote
(
request_id
=
request_id
,
prompt
=
prompt
,
lora_request
=
lora_request
))
except
ActorDiedError
as
e
:
logger
.
error
(
"%s died for second time in a row, marking "
"RayTokenizerGroupPool as unhealthy."
,
actor
)
actor_is_alive
=
False
if
not
self
.
_exception
:
self
.
_exception
=
e
self
.
check_health
()
finally
:
# Put the actor back in the queue.
# This is done in a finally block to ensure that the actor is
# always put back in the queue, even if an exception/cancellation
# is raised.
self
.
_idle_actors
.
put_nowait
(
actor
)
self
.
_finalize_encode
(
actor
,
original_actor
,
actor_is_alive
)
return
ret
async
def
encode_async
(
...
...
@@ -120,20 +162,37 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
The actor is then put back in the queue for future use.
This is non-blocking.
"""
self
.
check_health
()
self
.
_ensure_queue_initialized
()
assert
self
.
_idle_actors
is
not
None
actor
=
await
self
.
_idle_actors
.
get
()
actor_is_alive
=
True
original_actor
=
actor
try
:
ret
=
await
actor
.
encode
.
remote
(
request_id
=
request_id
,
prompt
=
prompt
,
lora_request
=
lora_request
)
except
ActorDiedError
as
e
:
# If the actor is dead, we first try to reinitialize it.
logger
.
warning
(
"%s died with ActorDiedError, reinitializing."
,
actor
,
exc_info
=
e
)
actor
=
self
.
_init_actor
()
try
:
ret
=
await
actor
.
encode
.
remote
(
request_id
=
request_id
,
prompt
=
prompt
,
lora_request
=
lora_request
)
except
ActorDiedError
as
e
:
logger
.
error
(
"%s died for second time in a row, marking "
"RayTokenizerGroupPool as unhealthy."
,
actor
)
actor_is_alive
=
False
if
not
self
.
_exception
:
self
.
_exception
=
e
self
.
check_health
()
finally
:
# Put the actor back in the queue.
# This is done in a finally block to ensure that the actor is
# always put back in the queue, even if an exception/cancellation
# is raised.
self
.
_idle_actors
.
put_nowait
(
actor
)
self
.
_finalize_encode
(
actor
,
original_actor
,
actor_is_alive
)
return
ret
def
get_max_input_len
(
self
,
...
...
@@ -155,6 +214,11 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
return
await
self
.
_local_tokenizer_group
.
get_lora_tokenizer_async
(
lora_request
)
def check_health(self):
    """Raise RuntimeError if a failure has been recorded on this pool.

    Raises:
        RuntimeError: when ``self._exception`` is set; the stored
            exception is chained as the cause.
    """
    failure = self._exception
    if not failure:
        return
    raise RuntimeError("TokenizerGroupPool is unhealthy.") from failure
def
_carry_over_env_vars_to_runtime_env
(
runtime_env
:
dict
)
->
None
:
"""Copy over all current process environment variables to the runtime_env.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment