Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
67882dbb
Unverified
Commit
67882dbb
authored
Jun 25, 2024
by
Antoni Baum
Committed by
GitHub
Jun 25, 2024
Browse files
[Core] Add fault tolerance for `RayTokenizerGroupPool` (#5748)
parent
7b993143
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
195 additions
and
24 deletions
+195
-24
tests/tokenization/test_tokenizer_group.py
tests/tokenization/test_tokenizer_group.py
+99
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+2
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+2
-0
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
...ransformers_utils/tokenizer_group/base_tokenizer_group.py
+4
-0
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
...transformers_utils/tokenizer_group/ray_tokenizer_group.py
+88
-24
No files found.
tests/tokenization/test_tokenizer_group.py
View file @
67882dbb
import
asyncio
import
asyncio
import
os
import
os
import
sys
from
typing
import
List
,
Optional
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
pytest
import
pytest
...
@@ -100,3 +102,100 @@ async def test_tokenizer_group_ray_pool_env_var_propagation(
...
@@ -100,3 +102,100 @@ async def test_tokenizer_group_ray_pool_env_var_propagation(
max_num_seqs
=
1
,
max_num_seqs
=
1
,
max_input_length
=
None
)
max_input_length
=
None
)
tokenizer_pool
.
ping
()
tokenizer_pool
.
ping
()
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"tokenizer_group_type"
,
[
"ray"
])
async
def
test_tokenizer_group_ray_pool_fault_tolerance
(
tokenizer_group_type
):
"""Test that Ray tokenizer pool group can recover from failures and
if that's not possible, mark itself as unhealthy."""
class
FailingTokenizerGroup
(
TokenizerGroup
):
def
__init__
(
self
,
*
args
,
fail_at
:
Optional
[
List
[
int
]]
=
None
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
i
=
0
self
.
fail_at
=
fail_at
or
[]
def
encode
(
self
,
*
args
,
**
kwargs
):
self
.
i
+=
1
if
self
.
i
in
self
.
fail_at
:
sys
.
exit
(
1
)
return
super
().
encode
(
*
args
,
**
kwargs
)
class
FailingRayTokenizerGroupPool
(
RayTokenizerGroupPool
):
_worker_cls
=
FailingTokenizerGroup
# Fail at first iteration
fail_at
=
[
1
]
tokenizer_pool_config
=
get_tokenizer_pool_config
(
tokenizer_group_type
)
tokenizer_group_pool
=
FailingRayTokenizerGroupPool
.
from_config
(
tokenizer_pool_config
,
tokenizer_id
=
"gpt2"
,
enable_lora
=
False
,
max_num_seqs
=
1
,
max_input_length
=
None
,
fail_at
=
fail_at
)
tokenizer_actors
=
tokenizer_group_pool
.
tokenizer_actors
.
copy
()
# Modify fail at to not fail at all (will be re-read when actor is
# re-initialized).
fail_at
[
0
]
=
1000
# We should recover successfully.
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
prompt
=
"prompt"
,
lora_request
=
None
)
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
prompt
=
"prompt"
,
lora_request
=
None
)
# Check that we have a new actor
assert
len
(
tokenizer_group_pool
.
tokenizer_actors
)
==
len
(
tokenizer_actors
)
assert
tokenizer_group_pool
.
tokenizer_actors
!=
tokenizer_actors
# Fail at first iteration
fail_at
=
[
1
]
tokenizer_group_pool
=
FailingRayTokenizerGroupPool
.
from_config
(
tokenizer_pool_config
,
tokenizer_id
=
"gpt2"
,
enable_lora
=
False
,
max_num_seqs
=
1
,
max_input_length
=
None
,
fail_at
=
fail_at
)
# We should fail after re-initialization.
with
pytest
.
raises
(
RuntimeError
):
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
prompt
=
"prompt"
,
lora_request
=
None
)
# check_health should raise the same thing
with
pytest
.
raises
(
RuntimeError
):
tokenizer_group_pool
.
check_health
()
# Ensure that non-ActorDiedErrors are still propagated correctly and do not
# cause a re-initialization.
fail_at
=
[]
tokenizer_group_pool
=
FailingRayTokenizerGroupPool
.
from_config
(
tokenizer_pool_config
,
tokenizer_id
=
"gpt2"
,
enable_lora
=
False
,
max_num_seqs
=
1
,
max_input_length
=
2
,
fail_at
=
fail_at
)
tokenizer_actors
=
tokenizer_group_pool
.
tokenizer_actors
.
copy
()
# Prompt too long error
with
pytest
.
raises
(
ValueError
):
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
prompt
=
"prompt"
*
100
,
lora_request
=
None
)
await
tokenizer_group_pool
.
encode_async
(
request_id
=
"1"
,
prompt
=
"prompt"
,
lora_request
=
None
)
# Actors should stay the same.
assert
tokenizer_group_pool
.
tokenizer_actors
==
tokenizer_actors
vllm/engine/async_llm_engine.py
View file @
67882dbb
...
@@ -310,6 +310,8 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -310,6 +310,8 @@ class _AsyncLLMEngine(LLMEngine):
)
)
async
def
check_health_async
(
self
)
->
None
:
async
def
check_health_async
(
self
)
->
None
:
if
self
.
tokenizer
:
self
.
tokenizer
.
check_health
()
self
.
model_executor
.
check_health
()
self
.
model_executor
.
check_health
()
...
...
vllm/engine/llm_engine.py
View file @
67882dbb
...
@@ -1013,6 +1013,8 @@ class LLMEngine:
...
@@ -1013,6 +1013,8 @@ class LLMEngine:
return
self
.
model_executor
.
pin_lora
(
lora_id
)
return
self
.
model_executor
.
pin_lora
(
lora_id
)
def
check_health
(
self
)
->
None
:
def
check_health
(
self
)
->
None
:
if
self
.
tokenizer
:
self
.
tokenizer
.
check_health
()
self
.
model_executor
.
check_health
()
self
.
model_executor
.
check_health
()
def
is_tracing_enabled
(
self
)
->
bool
:
def
is_tracing_enabled
(
self
)
->
bool
:
...
...
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
View file @
67882dbb
...
@@ -53,3 +53,7 @@ class BaseTokenizerGroup(ABC):
...
@@ -53,3 +53,7 @@ class BaseTokenizerGroup(ABC):
)
->
"PreTrainedTokenizer"
:
)
->
"PreTrainedTokenizer"
:
"""Get a tokenizer for a LoRA request."""
"""Get a tokenizer for a LoRA request."""
pass
pass
def
check_health
(
self
):
"""Raise exception if the tokenizer group is unhealthy."""
return
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
View file @
67882dbb
...
@@ -2,17 +2,21 @@ import asyncio
...
@@ -2,17 +2,21 @@ import asyncio
import
os
import
os
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
from
ray.exceptions
import
ActorDiedError
from
ray.util.scheduling_strategies
import
NodeAffinitySchedulingStrategy
from
ray.util.scheduling_strategies
import
NodeAffinitySchedulingStrategy
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
TokenizerPoolConfig
from
vllm.config
import
TokenizerPoolConfig
from
vllm.executor.ray_utils
import
ray
from
vllm.executor.ray_utils
import
ray
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.tokenizer_group.base_tokenizer_group
import
(
from
vllm.transformers_utils.tokenizer_group.base_tokenizer_group
import
(
BaseTokenizerGroup
)
BaseTokenizerGroup
)
from
vllm.transformers_utils.tokenizer_group.tokenizer_group
import
(
from
vllm.transformers_utils.tokenizer_group.tokenizer_group
import
(
TokenizerGroup
)
TokenizerGroup
)
logger
=
init_logger
(
__name__
)
class
RayTokenizerGroupPool
(
BaseTokenizerGroup
):
class
RayTokenizerGroupPool
(
BaseTokenizerGroup
):
"""A Ray-based pool of TokenizerGroups for async tokenization."""
"""A Ray-based pool of TokenizerGroups for async tokenization."""
...
@@ -46,24 +50,28 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
...
@@ -46,24 +50,28 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
ray_actor_options
:
dict
,
**
tokenizer_config
):
ray_actor_options
:
dict
,
**
tokenizer_config
):
# Store a local copy of the TokenizerGroup for quick access
# Store a local copy of the TokenizerGroup for quick access
# to underlying HF tokenizers.
# to underlying HF tokenizers.
self
.
_tokenizer_config
=
{
"tokenizer_id"
:
tokenizer_id
,
"enable_lora"
:
enable_lora
,
"max_num_seqs"
:
max_num_seqs
,
"max_input_length"
:
max_input_length
,
**
tokenizer_config
}
self
.
_local_tokenizer_group
=
self
.
_worker_cls
(
self
.
_local_tokenizer_group
=
self
.
_worker_cls
(
tokenizer_id
=
tokenizer_id
,
**
self
.
_tokenizer_config
,
)
enable_lora
=
enable_lora
,
max_num_seqs
=
max_num_seqs
,
self
.
_ray_tokenizer_group_cls
=
ray
.
remote
(
max_input_length
=
max_input_length
,
**
tokenizer_config
,
)
ray_tokenizer_group_cls
=
ray
.
remote
(
self
.
_worker_cls
).
options
(
**
ray_actor_options
)
self
.
_worker_cls
).
options
(
**
ray_actor_options
)
self
.
tokenizer_actors
=
[
self
.
tokenizer_actors
=
[
self
.
_init_actor
()
for
_
in
range
(
num_actors
)]
ray_tokenizer_group_cls
.
remote
(
tokenizer_id
,
enable_lora
,
max_num_seqs
,
max_input_length
,
**
tokenizer_config
)
for
_
in
range
(
num_actors
)
]
self
.
_idle_actors
:
Optional
[
asyncio
.
Queue
]
=
None
self
.
_idle_actors
:
Optional
[
asyncio
.
Queue
]
=
None
# If set, actor is unhealthy. Will reraise on the next
# check_health call.
self
.
_exception
:
Optional
[
ActorDiedError
]
=
None
def
_init_actor
(
self
)
->
ray
.
ObjectRef
:
return
self
.
_ray_tokenizer_group_cls
.
remote
(
**
self
.
_tokenizer_config
)
@
property
@
property
def
pool_size
(
self
)
->
int
:
def
pool_size
(
self
)
->
int
:
return
len
(
self
.
tokenizer_actors
)
return
len
(
self
.
tokenizer_actors
)
...
@@ -78,6 +86,22 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
...
@@ -78,6 +86,22 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
for
actor
in
self
.
tokenizer_actors
:
for
actor
in
self
.
tokenizer_actors
:
self
.
_idle_actors
.
put_nowait
(
actor
)
self
.
_idle_actors
.
put_nowait
(
actor
)
def
_finalize_encode
(
self
,
actor
:
ray
.
ObjectRef
,
original_actor
:
ray
.
ObjectRef
,
actor_is_alive
:
bool
):
assert
self
.
_idle_actors
is
not
None
# Cleanup the dead actor.
if
not
actor_is_alive
or
original_actor
is
not
actor
:
self
.
tokenizer_actors
.
remove
(
original_actor
)
if
actor_is_alive
:
# Put the actor back in the queue.
# This is done in a finally block to ensure that the actor is
# always put back in the queue, even if an exception/cancellation
# is raised.
self
.
_idle_actors
.
put_nowait
(
actor
)
# Add back the new actor.
if
original_actor
is
not
actor
:
self
.
tokenizer_actors
.
append
(
actor
)
def
encode
(
self
,
def
encode
(
self
,
prompt
:
str
,
prompt
:
str
,
request_id
:
Optional
[
str
]
=
None
,
request_id
:
Optional
[
str
]
=
None
,
...
@@ -88,23 +112,41 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
...
@@ -88,23 +112,41 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
The actor is then put back in the queue for future use.
The actor is then put back in the queue for future use.
This is blocking.
This is blocking.
"""
"""
self
.
check_health
()
self
.
_ensure_queue_initialized
()
self
.
_ensure_queue_initialized
()
assert
self
.
_idle_actors
is
not
None
assert
self
.
_idle_actors
is
not
None
if
self
.
_idle_actors
.
empty
():
if
self
.
_idle_actors
.
empty
():
raise
RuntimeError
(
"No idle actors available."
)
raise
RuntimeError
(
"No idle actors available."
)
actor
=
self
.
_idle_actors
.
get_nowait
()
actor
=
self
.
_idle_actors
.
get_nowait
()
actor_is_alive
=
True
original_actor
=
actor
try
:
try
:
ret
=
ray
.
get
(
ret
=
ray
.
get
(
actor
.
encode
.
remote
(
request_id
=
request_id
,
actor
.
encode
.
remote
(
request_id
=
request_id
,
prompt
=
prompt
,
prompt
=
prompt
,
lora_request
=
lora_request
))
lora_request
=
lora_request
))
except
ActorDiedError
as
e
:
# If the actor is dead, we first try to reinitialize it.
logger
.
warning
(
"%s died with ActorDiedError, reinitializing."
,
actor
,
exc_info
=
e
)
actor
=
self
.
_init_actor
()
try
:
ret
=
ray
.
get
(
actor
.
encode
.
remote
(
request_id
=
request_id
,
prompt
=
prompt
,
lora_request
=
lora_request
))
except
ActorDiedError
as
e
:
logger
.
error
(
"%s died for second time in a row, marking "
"RayTokenizerGroupPool as unhealthy."
,
actor
)
actor_is_alive
=
False
if
not
self
.
_exception
:
self
.
_exception
=
e
self
.
check_health
()
finally
:
finally
:
# Put the actor back in the queue.
self
.
_finalize_encode
(
actor
,
original_actor
,
actor_is_alive
)
# This is done in a finally block to ensure that the actor is
# always put back in the queue, even if an exception/cancellation
# is raised.
self
.
_idle_actors
.
put_nowait
(
actor
)
return
ret
return
ret
async
def
encode_async
(
async
def
encode_async
(
...
@@ -120,20 +162,37 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
...
@@ -120,20 +162,37 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
The actor is then put back in the queue for future use.
The actor is then put back in the queue for future use.
This is non-blocking.
This is non-blocking.
"""
"""
self
.
check_health
()
self
.
_ensure_queue_initialized
()
self
.
_ensure_queue_initialized
()
assert
self
.
_idle_actors
is
not
None
assert
self
.
_idle_actors
is
not
None
actor
=
await
self
.
_idle_actors
.
get
()
actor
=
await
self
.
_idle_actors
.
get
()
actor_is_alive
=
True
original_actor
=
actor
try
:
try
:
ret
=
await
actor
.
encode
.
remote
(
request_id
=
request_id
,
ret
=
await
actor
.
encode
.
remote
(
request_id
=
request_id
,
prompt
=
prompt
,
prompt
=
prompt
,
lora_request
=
lora_request
)
lora_request
=
lora_request
)
except
ActorDiedError
as
e
:
# If the actor is dead, we first try to reinitialize it.
logger
.
warning
(
"%s died with ActorDiedError, reinitializing."
,
actor
,
exc_info
=
e
)
actor
=
self
.
_init_actor
()
try
:
ret
=
await
actor
.
encode
.
remote
(
request_id
=
request_id
,
prompt
=
prompt
,
lora_request
=
lora_request
)
except
ActorDiedError
as
e
:
logger
.
error
(
"%s died for second time in a row, marking "
"RayTokenizerGroupPool as unhealthy."
,
actor
)
actor_is_alive
=
False
if
not
self
.
_exception
:
self
.
_exception
=
e
self
.
check_health
()
finally
:
finally
:
# Put the actor back in the queue.
self
.
_finalize_encode
(
actor
,
original_actor
,
actor_is_alive
)
# This is done in a finally block to ensure that the actor is
# always put back in the queue, even if an exception/cancellation
# is raised.
self
.
_idle_actors
.
put_nowait
(
actor
)
return
ret
return
ret
def
get_max_input_len
(
self
,
def
get_max_input_len
(
self
,
...
@@ -155,6 +214,11 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
...
@@ -155,6 +214,11 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
return
await
self
.
_local_tokenizer_group
.
get_lora_tokenizer_async
(
return
await
self
.
_local_tokenizer_group
.
get_lora_tokenizer_async
(
lora_request
)
lora_request
)
def
check_health
(
self
):
if
self
.
_exception
:
raise
RuntimeError
(
"TokenizerGroupPool is unhealthy."
)
from
self
.
_exception
def
_carry_over_env_vars_to_runtime_env
(
runtime_env
:
dict
)
->
None
:
def
_carry_over_env_vars_to_runtime_env
(
runtime_env
:
dict
)
->
None
:
"""Copy over all current process environment variables to the runtime_env.
"""Copy over all current process environment variables to the runtime_env.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment