Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7bd82002
"docs/contributing/overview.md" did not exist on "dd6a3a02cb3bf2a7bc6cb84c85dcd57c6eaf2bf9"
Unverified
Commit
7bd82002
authored
Jul 19, 2024
by
Antoni Baum
Committed by
GitHub
Jul 20, 2024
Browse files
[Core] Allow specifying custom Executor (#6557)
parent
2e265642
Changes
22
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
285 additions
and
85 deletions
+285
-85
tests/conftest.py
tests/conftest.py
+4
-0
tests/engine/test_custom_executor.py
tests/engine/test_custom_executor.py
+91
-0
tests/tokenization/test_tokenizer_group.py
tests/tokenization/test_tokenizer_group.py
+17
-4
vllm/config.py
vllm/config.py
+28
-11
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+15
-3
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+34
-19
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+27
-13
vllm/executor/cpu_executor.py
vllm/executor/cpu_executor.py
+2
-0
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+2
-0
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+2
-0
vllm/executor/multiproc_gpu_executor.py
vllm/executor/multiproc_gpu_executor.py
+2
-0
vllm/executor/neuron_executor.py
vllm/executor/neuron_executor.py
+2
-0
vllm/executor/openvino_executor.py
vllm/executor/openvino_executor.py
+2
-0
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+21
-18
vllm/executor/ray_xpu_executor.py
vllm/executor/ray_xpu_executor.py
+13
-11
vllm/executor/tpu_executor.py
vllm/executor/tpu_executor.py
+2
-0
vllm/executor/xpu_executor.py
vllm/executor/xpu_executor.py
+2
-0
vllm/transformers_utils/tokenizer_group/__init__.py
vllm/transformers_utils/tokenizer_group/__init__.py
+9
-5
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
...ransformers_utils/tokenizer_group/base_tokenizer_group.py
+7
-0
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
...transformers_utils/tokenizer_group/ray_tokenizer_group.py
+3
-1
No files found.
tests/conftest.py
View file @
7bd82002
...
@@ -564,6 +564,10 @@ def get_tokenizer_pool_config(tokenizer_group_type):
...
@@ -564,6 +564,10 @@ def get_tokenizer_pool_config(tokenizer_group_type):
return
TokenizerPoolConfig
(
pool_size
=
1
,
return
TokenizerPoolConfig
(
pool_size
=
1
,
pool_type
=
"ray"
,
pool_type
=
"ray"
,
extra_config
=
{})
extra_config
=
{})
if
isinstance
(
tokenizer_group_type
,
type
):
return
TokenizerPoolConfig
(
pool_size
=
1
,
pool_type
=
tokenizer_group_type
,
extra_config
=
{})
raise
ValueError
(
f
"Unknown tokenizer_group_type:
{
tokenizer_group_type
}
"
)
raise
ValueError
(
f
"Unknown tokenizer_group_type:
{
tokenizer_group_type
}
"
)
...
...
tests/engine/test_custom_executor.py
0 → 100644
View file @
7bd82002
import
asyncio
import
os
import
pytest
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
EngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.executor.gpu_executor
import
GPUExecutor
,
GPUExecutorAsync
from
vllm.sampling_params
import
SamplingParams
class
Mock
:
...
class
CustomGPUExecutor
(
GPUExecutor
):
def
execute_model
(
self
,
*
args
,
**
kwargs
):
# Drop marker to show that this was ran
with
open
(
".marker"
,
"w"
):
...
return
super
().
execute_model
(
*
args
,
**
kwargs
)
class
CustomGPUExecutorAsync
(
GPUExecutorAsync
):
async
def
execute_model_async
(
self
,
*
args
,
**
kwargs
):
with
open
(
".marker"
,
"w"
):
...
return
await
super
().
execute_model_async
(
*
args
,
**
kwargs
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"facebook/opt-125m"
])
def
test_custom_executor_type_checking
(
model
):
with
pytest
.
raises
(
ValueError
):
engine_args
=
EngineArgs
(
model
=
model
,
distributed_executor_backend
=
Mock
)
LLMEngine
.
from_engine_args
(
engine_args
)
with
pytest
.
raises
(
ValueError
):
engine_args
=
AsyncEngineArgs
(
model
=
model
,
distributed_executor_backend
=
Mock
)
AsyncLLMEngine
.
from_engine_args
(
engine_args
)
with
pytest
.
raises
(
TypeError
):
engine_args
=
AsyncEngineArgs
(
model
=
model
,
distributed_executor_backend
=
CustomGPUExecutor
)
AsyncLLMEngine
.
from_engine_args
(
engine_args
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"facebook/opt-125m"
])
def
test_custom_executor
(
model
,
tmpdir
):
cwd
=
os
.
path
.
abspath
(
"."
)
os
.
chdir
(
tmpdir
)
try
:
assert
not
os
.
path
.
exists
(
".marker"
)
engine_args
=
EngineArgs
(
model
=
model
,
distributed_executor_backend
=
CustomGPUExecutor
)
engine
=
LLMEngine
.
from_engine_args
(
engine_args
)
sampling_params
=
SamplingParams
(
max_tokens
=
1
)
engine
.
add_request
(
"0"
,
"foo"
,
sampling_params
)
engine
.
step
()
assert
os
.
path
.
exists
(
".marker"
)
finally
:
os
.
chdir
(
cwd
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"facebook/opt-125m"
])
def
test_custom_executor_async
(
model
,
tmpdir
):
cwd
=
os
.
path
.
abspath
(
"."
)
os
.
chdir
(
tmpdir
)
try
:
assert
not
os
.
path
.
exists
(
".marker"
)
engine_args
=
AsyncEngineArgs
(
model
=
model
,
distributed_executor_backend
=
CustomGPUExecutorAsync
)
engine
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
)
sampling_params
=
SamplingParams
(
max_tokens
=
1
)
async
def
t
():
stream
=
await
engine
.
add_request
(
"0"
,
"foo"
,
sampling_params
)
async
for
x
in
stream
:
...
asyncio
.
run
(
t
())
assert
os
.
path
.
exists
(
".marker"
)
finally
:
os
.
chdir
(
cwd
)
tests/tokenization/test_tokenizer_group.py
View file @
7bd82002
...
@@ -7,17 +7,28 @@ from unittest.mock import patch
...
@@ -7,17 +7,28 @@ from unittest.mock import patch
import
pytest
import
pytest
from
transformers
import
AutoTokenizer
,
PreTrainedTokenizerBase
from
transformers
import
AutoTokenizer
,
PreTrainedTokenizerBase
from
vllm.transformers_utils.tokenizer_group
import
get_tokenizer_group
from
vllm.transformers_utils.tokenizer_group
import
(
TokenizerGroup
,
get_tokenizer_group
)
from
vllm.transformers_utils.tokenizer_group.ray_tokenizer_group
import
(
from
vllm.transformers_utils.tokenizer_group.ray_tokenizer_group
import
(
RayTokenizerGroupPool
)
RayTokenizerGroupPool
)
from
vllm.transformers_utils.tokenizer_group.tokenizer_group
import
(
TokenizerGroup
)
from
..conftest
import
get_tokenizer_pool_config
from
..conftest
import
get_tokenizer_pool_config
class
CustomTokenizerGroup
(
TokenizerGroup
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
_i
=
0
def
encode
(
self
,
*
args
,
**
kwargs
):
self
.
_i
+=
1
return
super
().
encode
(
*
args
,
**
kwargs
)
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"tokenizer_group_type"
,
[
None
,
"ray"
])
@
pytest
.
mark
.
parametrize
(
"tokenizer_group_type"
,
[
None
,
"ray"
,
CustomTokenizerGroup
])
async
def
test_tokenizer_group
(
tokenizer_group_type
):
async
def
test_tokenizer_group
(
tokenizer_group_type
):
reference_tokenizer
=
AutoTokenizer
.
from_pretrained
(
"gpt2"
)
reference_tokenizer
=
AutoTokenizer
.
from_pretrained
(
"gpt2"
)
tokenizer_group
=
get_tokenizer_group
(
tokenizer_group
=
get_tokenizer_group
(
...
@@ -36,6 +47,8 @@ async def test_tokenizer_group(tokenizer_group_type):
...
@@ -36,6 +47,8 @@ async def test_tokenizer_group(tokenizer_group_type):
PreTrainedTokenizerBase
)
PreTrainedTokenizerBase
)
assert
tokenizer_group
.
get_lora_tokenizer
(
assert
tokenizer_group
.
get_lora_tokenizer
(
None
)
==
await
tokenizer_group
.
get_lora_tokenizer_async
(
None
)
None
)
==
await
tokenizer_group
.
get_lora_tokenizer_async
(
None
)
if
tokenizer_group_type
is
CustomTokenizerGroup
:
assert
tokenizer_group
.
_i
>
0
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
...
...
vllm/config.py
View file @
7bd82002
import
enum
import
enum
import
json
import
json
from
dataclasses
import
dataclass
,
field
,
fields
from
dataclasses
import
dataclass
,
field
,
fields
from
typing
import
TYPE_CHECKING
,
ClassVar
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
ClassVar
,
List
,
Optional
,
Tuple
,
Type
,
Union
import
torch
import
torch
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
...
@@ -18,7 +18,10 @@ from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
...
@@ -18,7 +18,10 @@ from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
from
ray.util.placement_group
import
PlacementGroup
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.model_executor.model_loader.loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.loader
import
BaseModelLoader
from
vllm.transformers_utils.tokenizer_group.base_tokenizer_group
import
(
BaseTokenizerGroup
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -527,11 +530,12 @@ class TokenizerPoolConfig:
...
@@ -527,11 +530,12 @@ class TokenizerPoolConfig:
pool type.
pool type.
"""
"""
pool_size
:
int
pool_size
:
int
pool_type
:
str
pool_type
:
Union
[
str
,
Type
[
"BaseTokenizerGroup"
]]
extra_config
:
dict
extra_config
:
dict
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
pool_type
not
in
(
"ray"
,
):
if
self
.
pool_type
not
in
(
"ray"
,
)
and
not
isinstance
(
self
.
pool_type
,
type
):
raise
ValueError
(
f
"Unknown pool type:
{
self
.
pool_type
}
"
)
raise
ValueError
(
f
"Unknown pool type:
{
self
.
pool_type
}
"
)
if
not
isinstance
(
self
.
extra_config
,
dict
):
if
not
isinstance
(
self
.
extra_config
,
dict
):
raise
ValueError
(
"extra_config must be a dictionary."
)
raise
ValueError
(
"extra_config must be a dictionary."
)
...
@@ -661,7 +665,8 @@ class ParallelConfig:
...
@@ -661,7 +665,8 @@ class ParallelConfig:
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
]
=
None
,
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
]
=
None
,
ray_workers_use_nsight
:
bool
=
False
,
ray_workers_use_nsight
:
bool
=
False
,
placement_group
:
Optional
[
"PlacementGroup"
]
=
None
,
placement_group
:
Optional
[
"PlacementGroup"
]
=
None
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
distributed_executor_backend
:
Optional
[
Union
[
str
,
Type
[
"ExecutorBase"
]]]
=
None
,
)
->
None
:
)
->
None
:
self
.
pipeline_parallel_size
=
pipeline_parallel_size
self
.
pipeline_parallel_size
=
pipeline_parallel_size
self
.
tensor_parallel_size
=
tensor_parallel_size
self
.
tensor_parallel_size
=
tensor_parallel_size
...
@@ -676,7 +681,7 @@ class ParallelConfig:
...
@@ -676,7 +681,7 @@ class ParallelConfig:
if
worker_use_ray
:
if
worker_use_ray
:
if
self
.
distributed_executor_backend
is
None
:
if
self
.
distributed_executor_backend
is
None
:
self
.
distributed_executor_backend
=
"ray"
self
.
distributed_executor_backend
=
"ray"
elif
self
.
distributed_executor_backend
!=
"
ray
"
:
elif
not
self
.
use_
ray
:
raise
ValueError
(
f
"worker-use-ray can't be used with "
raise
ValueError
(
f
"worker-use-ray can't be used with "
f
"distributed executor backend "
f
"distributed executor backend "
f
"'
{
self
.
distributed_executor_backend
}
'."
)
f
"'
{
self
.
distributed_executor_backend
}
'."
)
...
@@ -711,12 +716,25 @@ class ParallelConfig:
...
@@ -711,12 +716,25 @@ class ParallelConfig:
self
.
_verify_args
()
self
.
_verify_args
()
self
.
rank
=
0
self
.
rank
=
0
@
property
def
use_ray
(
self
)
->
bool
:
return
self
.
distributed_executor_backend
==
"ray"
or
(
isinstance
(
self
.
distributed_executor_backend
,
type
)
and
self
.
distributed_executor_backend
.
uses_ray
)
def
_verify_args
(
self
)
->
None
:
def
_verify_args
(
self
)
->
None
:
if
self
.
distributed_executor_backend
not
in
(
"ray"
,
"mp"
,
None
):
# Lazy import to avoid circular import
from
vllm.executor.executor_base
import
ExecutorBase
if
self
.
distributed_executor_backend
not
in
(
"ray"
,
"mp"
,
None
)
and
not
(
isinstance
(
self
.
distributed_executor_backend
,
type
)
and
issubclass
(
self
.
distributed_executor_backend
,
ExecutorBase
)):
raise
ValueError
(
raise
ValueError
(
"Unrecognized distributed executor backend. Supported values "
"Unrecognized distributed executor backend "
"are 'ray' or 'mp'."
)
f
"
{
self
.
distributed_executor_backend
}
. Supported "
if
self
.
distributed_executor_backend
==
"ray"
:
"values are 'ray', 'mp' or custom ExecutorBase subclass."
)
if
self
.
use_ray
:
from
vllm.executor
import
ray_utils
from
vllm.executor
import
ray_utils
ray_utils
.
assert_ray_available
()
ray_utils
.
assert_ray_available
()
if
is_hip
():
if
is_hip
():
...
@@ -724,8 +742,7 @@ class ParallelConfig:
...
@@ -724,8 +742,7 @@ class ParallelConfig:
logger
.
info
(
logger
.
info
(
"Disabled the custom all-reduce kernel because it is not "
"Disabled the custom all-reduce kernel because it is not "
"supported on AMD GPUs."
)
"supported on AMD GPUs."
)
if
self
.
ray_workers_use_nsight
and
(
if
self
.
ray_workers_use_nsight
and
not
self
.
use_ray
:
not
self
.
distributed_executor_backend
==
"ray"
):
raise
ValueError
(
"Unable to use nsight profiling unless workers "
raise
ValueError
(
"Unable to use nsight profiling unless workers "
"run with Ray."
)
"run with Ray."
)
...
...
vllm/engine/arg_utils.py
View file @
7bd82002
...
@@ -2,16 +2,21 @@ import argparse
...
@@ -2,16 +2,21 @@ import argparse
import
dataclasses
import
dataclasses
import
json
import
json
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Type
,
Union
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ObservabilityConfig
,
ParallelConfig
,
MultiModalConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
,
TokenizerPoolConfig
)
SpeculativeConfig
,
TokenizerPoolConfig
)
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
if
TYPE_CHECKING
:
from
vllm.transformers_utils.tokenizer_group.base_tokenizer_group
import
(
BaseTokenizerGroup
)
def
nullable_str
(
val
:
str
):
def
nullable_str
(
val
:
str
):
if
not
val
or
val
==
"None"
:
if
not
val
or
val
==
"None"
:
...
@@ -36,7 +41,11 @@ class EngineArgs:
...
@@ -36,7 +41,11 @@ class EngineArgs:
seed
:
int
=
0
seed
:
int
=
0
max_model_len
:
Optional
[
int
]
=
None
max_model_len
:
Optional
[
int
]
=
None
worker_use_ray
:
bool
=
False
worker_use_ray
:
bool
=
False
distributed_executor_backend
:
Optional
[
str
]
=
None
# Note: Specifying a custom executor backend by passing a class
# is intended for expert use only. The API may change without
# notice.
distributed_executor_backend
:
Optional
[
Union
[
str
,
Type
[
ExecutorBase
]]]
=
None
pipeline_parallel_size
:
int
=
1
pipeline_parallel_size
:
int
=
1
tensor_parallel_size
:
int
=
1
tensor_parallel_size
:
int
=
1
max_parallel_loading_workers
:
Optional
[
int
]
=
None
max_parallel_loading_workers
:
Optional
[
int
]
=
None
...
@@ -62,7 +71,10 @@ class EngineArgs:
...
@@ -62,7 +71,10 @@ class EngineArgs:
max_seq_len_to_capture
:
int
=
8192
max_seq_len_to_capture
:
int
=
8192
disable_custom_all_reduce
:
bool
=
False
disable_custom_all_reduce
:
bool
=
False
tokenizer_pool_size
:
int
=
0
tokenizer_pool_size
:
int
=
0
tokenizer_pool_type
:
str
=
"ray"
# Note: Specifying a tokenizer pool by passing a class
# is intended for expert use only. The API may change without
# notice.
tokenizer_pool_type
:
Union
[
str
,
Type
[
"BaseTokenizerGroup"
]]
=
"ray"
tokenizer_pool_extra_config
:
Optional
[
dict
]
=
None
tokenizer_pool_extra_config
:
Optional
[
dict
]
=
None
enable_lora
:
bool
=
False
enable_lora
:
bool
=
False
max_loras
:
int
=
1
max_loras
:
int
=
1
...
...
vllm/engine/async_llm_engine.py
View file @
7bd82002
...
@@ -7,12 +7,13 @@ from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
...
@@ -7,12 +7,13 @@ from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
DecodingConfig
,
ModelConfig
from
vllm.config
import
DecodingConfig
,
EngineConfig
,
ModelConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_timeout
import
asyncio_timeout
from
vllm.engine.async_timeout
import
asyncio_timeout
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.metrics
import
StatLoggerBase
from
vllm.engine.metrics
import
StatLoggerBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.ray_utils
import
initialize_ray_cluster
,
ray
from
vllm.executor.ray_utils
import
initialize_ray_cluster
,
ray
from
vllm.inputs
import
LLMInputs
,
PromptInputs
from
vllm.inputs
import
LLMInputs
,
PromptInputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -385,25 +386,19 @@ class AsyncLLMEngine:
...
@@ -385,25 +386,19 @@ class AsyncLLMEngine:
self
.
_request_tracker
:
RequestTracker
self
.
_request_tracker
:
RequestTracker
@
classmethod
@
classmethod
def
from_engine_args
(
def
_get_executor_cls
(
cls
,
cls
,
engine_config
:
EngineConfig
)
->
Type
[
ExecutorAsyncBase
]:
engine_args
:
AsyncEngineArgs
,
start_engine_loop
:
bool
=
True
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
)
->
"AsyncLLMEngine"
:
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_config
=
engine_args
.
create_engine_config
()
if
engine_args
.
engine_use_ray
:
from
vllm.executor
import
ray_utils
ray_utils
.
assert_ray_available
()
distributed_executor_backend
=
(
distributed_executor_backend
=
(
engine_config
.
parallel_config
.
distributed_executor_backend
)
engine_config
.
parallel_config
.
distributed_executor_backend
)
if
isinstance
(
distributed_executor_backend
,
type
):
if
engine_config
.
device_config
.
device_type
==
"neuron"
:
if
not
issubclass
(
distributed_executor_backend
,
ExecutorAsyncBase
):
raise
TypeError
(
"distributed_executor_backend must be a subclass of "
f
"ExecutorAsyncBase. Got
{
distributed_executor_backend
}
."
)
if
distributed_executor_backend
.
uses_ray
:
# type: ignore
initialize_ray_cluster
(
engine_config
.
parallel_config
)
executor_class
=
distributed_executor_backend
elif
engine_config
.
device_config
.
device_type
==
"neuron"
:
from
vllm.executor.neuron_executor
import
NeuronExecutorAsync
from
vllm.executor.neuron_executor
import
NeuronExecutorAsync
executor_class
=
NeuronExecutorAsync
executor_class
=
NeuronExecutorAsync
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
...
@@ -442,9 +437,29 @@ class AsyncLLMEngine:
...
@@ -442,9 +437,29 @@ class AsyncLLMEngine:
else
:
else
:
from
vllm.executor.gpu_executor
import
GPUExecutorAsync
from
vllm.executor.gpu_executor
import
GPUExecutorAsync
executor_class
=
GPUExecutorAsync
executor_class
=
GPUExecutorAsync
return
executor_class
@
classmethod
def
from_engine_args
(
cls
,
engine_args
:
AsyncEngineArgs
,
start_engine_loop
:
bool
=
True
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
)
->
"AsyncLLMEngine"
:
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_config
=
engine_args
.
create_engine_config
()
if
engine_args
.
engine_use_ray
:
from
vllm.executor
import
ray_utils
ray_utils
.
assert_ray_available
()
executor_class
=
cls
.
_get_executor_cls
(
engine_config
)
# Create the async LLM engine.
# Create the async LLM engine.
engine
=
cls
(
engine
=
cls
(
distributed_executor_backend
==
"
ray
"
,
executor_class
.
uses_
ray
,
engine_args
.
engine_use_ray
,
engine_args
.
engine_use_ray
,
**
engine_config
.
to_dict
(),
**
engine_config
.
to_dict
(),
executor_class
=
executor_class
,
executor_class
=
executor_class
,
...
...
vllm/engine/llm_engine.py
View file @
7bd82002
...
@@ -7,9 +7,9 @@ from typing import Set, Type, TypeVar, Union
...
@@ -7,9 +7,9 @@ from typing import Set, Type, TypeVar, Union
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
LoadConfig
,
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
LoRA
Config
,
Model
Config
,
Multi
Mod
a
lConfig
,
Engine
Config
,
Load
Config
,
LoRAConfig
,
Mod
e
lConfig
,
ObservabilityConfig
,
ParallelConfig
,
MultiModalConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
)
SpeculativeConfig
)
from
vllm.core.scheduler
import
(
ScheduledSequenceGroup
,
Scheduler
,
from
vllm.core.scheduler
import
(
ScheduledSequenceGroup
,
Scheduler
,
...
@@ -376,19 +376,20 @@ class LLMEngine:
...
@@ -376,19 +376,20 @@ class LLMEngine:
self
.
model_executor
.
initialize_cache
(
num_gpu_blocks
,
num_cpu_blocks
)
self
.
model_executor
.
initialize_cache
(
num_gpu_blocks
,
num_cpu_blocks
)
@
classmethod
@
classmethod
def
from_engine_args
(
def
_get_executor_cls
(
cls
,
cls
,
engine_config
:
EngineConfig
)
->
Type
[
ExecutorBase
]:
engine_args
:
EngineArgs
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
)
->
"LLMEngine"
:
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_config
=
engine_args
.
create_engine_config
()
distributed_executor_backend
=
(
distributed_executor_backend
=
(
engine_config
.
parallel_config
.
distributed_executor_backend
)
engine_config
.
parallel_config
.
distributed_executor_backend
)
# Initialize the cluster and specify the executor class.
# Initialize the cluster and specify the executor class.
if
engine_config
.
device_config
.
device_type
==
"neuron"
:
if
isinstance
(
distributed_executor_backend
,
type
):
if
not
issubclass
(
distributed_executor_backend
,
ExecutorBase
):
raise
TypeError
(
"distributed_executor_backend must be a subclass of "
f
"ExecutorBase. Got
{
distributed_executor_backend
}
."
)
if
distributed_executor_backend
.
uses_ray
:
# type: ignore
initialize_ray_cluster
(
engine_config
.
parallel_config
)
executor_class
=
distributed_executor_backend
elif
engine_config
.
device_config
.
device_type
==
"neuron"
:
from
vllm.executor.neuron_executor
import
NeuronExecutor
from
vllm.executor.neuron_executor
import
NeuronExecutor
executor_class
=
NeuronExecutor
executor_class
=
NeuronExecutor
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
...
@@ -422,6 +423,19 @@ class LLMEngine:
...
@@ -422,6 +423,19 @@ class LLMEngine:
else
:
else
:
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.executor.gpu_executor
import
GPUExecutor
executor_class
=
GPUExecutor
executor_class
=
GPUExecutor
return
executor_class
@
classmethod
def
from_engine_args
(
cls
,
engine_args
:
EngineArgs
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
)
->
"LLMEngine"
:
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_config
=
engine_args
.
create_engine_config
()
executor_class
=
cls
.
_get_executor_cls
(
engine_config
)
# Create the LLM engine.
# Create the LLM engine.
engine
=
cls
(
engine
=
cls
(
**
engine_config
.
to_dict
(),
**
engine_config
.
to_dict
(),
...
...
vllm/executor/cpu_executor.py
View file @
7bd82002
...
@@ -17,6 +17,8 @@ logger = init_logger(__name__)
...
@@ -17,6 +17,8 @@ logger = init_logger(__name__)
class
CPUExecutor
(
ExecutorBase
):
class
CPUExecutor
(
ExecutorBase
):
uses_ray
:
bool
=
False
def
_init_executor
(
self
)
->
None
:
def
_init_executor
(
self
)
->
None
:
assert
self
.
device_config
.
device_type
==
"cpu"
assert
self
.
device_config
.
device_type
==
"cpu"
assert
self
.
lora_config
is
None
,
"cpu backend doesn't support LoRA"
assert
self
.
lora_config
is
None
,
"cpu backend doesn't support LoRA"
...
...
vllm/executor/executor_base.py
View file @
7bd82002
...
@@ -18,6 +18,8 @@ class ExecutorBase(ABC):
...
@@ -18,6 +18,8 @@ class ExecutorBase(ABC):
that can execute the model on multiple devices.
that can execute the model on multiple devices.
"""
"""
uses_ray
:
bool
# whether the executor uses Ray for orchestration.
def
__init__
(
def
__init__
(
self
,
self
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
...
...
vllm/executor/gpu_executor.py
View file @
7bd82002
...
@@ -23,6 +23,8 @@ def create_worker(worker_module_name, worker_class_name, **kwargs):
...
@@ -23,6 +23,8 @@ def create_worker(worker_module_name, worker_class_name, **kwargs):
class
GPUExecutor
(
ExecutorBase
):
class
GPUExecutor
(
ExecutorBase
):
uses_ray
:
bool
=
False
def
_init_executor
(
self
)
->
None
:
def
_init_executor
(
self
)
->
None
:
"""Initialize the worker and load the model.
"""Initialize the worker and load the model.
"""
"""
...
...
vllm/executor/multiproc_gpu_executor.py
View file @
7bd82002
...
@@ -25,6 +25,8 @@ logger = init_logger(__name__)
...
@@ -25,6 +25,8 @@ logger = init_logger(__name__)
class
MultiprocessingGPUExecutor
(
DistributedGPUExecutor
):
class
MultiprocessingGPUExecutor
(
DistributedGPUExecutor
):
"""Python multiprocessing-based multi-GPU executor"""
"""Python multiprocessing-based multi-GPU executor"""
uses_ray
:
bool
=
False
def
_init_executor
(
self
)
->
None
:
def
_init_executor
(
self
)
->
None
:
# Create the parallel GPU workers.
# Create the parallel GPU workers.
world_size
=
self
.
parallel_config
.
world_size
world_size
=
self
.
parallel_config
.
world_size
...
...
vllm/executor/neuron_executor.py
View file @
7bd82002
...
@@ -11,6 +11,8 @@ logger = init_logger(__name__)
...
@@ -11,6 +11,8 @@ logger = init_logger(__name__)
class
NeuronExecutor
(
ExecutorBase
):
class
NeuronExecutor
(
ExecutorBase
):
uses_ray
:
bool
=
False
def
_init_executor
(
self
)
->
None
:
def
_init_executor
(
self
)
->
None
:
assert
(
self
.
lora_config
is
assert
(
self
.
lora_config
is
None
),
"LoRA is not supported for Neuron backend."
None
),
"LoRA is not supported for Neuron backend."
...
...
vllm/executor/openvino_executor.py
View file @
7bd82002
...
@@ -18,6 +18,8 @@ logger = init_logger(__name__)
...
@@ -18,6 +18,8 @@ logger = init_logger(__name__)
class
OpenVINOExecutor
(
ExecutorBase
):
class
OpenVINOExecutor
(
ExecutorBase
):
uses_ray
:
bool
=
False
def
_init_executor
(
self
)
->
None
:
def
_init_executor
(
self
)
->
None
:
assert
self
.
device_config
.
device_type
==
"openvino"
assert
self
.
device_config
.
device_type
==
"openvino"
assert
self
.
lora_config
is
None
,
"OpenVINO backend doesn't support LoRA"
assert
self
.
lora_config
is
None
,
"OpenVINO backend doesn't support LoRA"
...
...
vllm/executor/ray_gpu_executor.py
View file @
7bd82002
...
@@ -26,6 +26,8 @@ logger = init_logger(__name__)
...
@@ -26,6 +26,8 @@ logger = init_logger(__name__)
class
RayGPUExecutor
(
DistributedGPUExecutor
):
class
RayGPUExecutor
(
DistributedGPUExecutor
):
uses_ray
:
bool
=
True
def
_init_executor
(
self
)
->
None
:
def
_init_executor
(
self
)
->
None
:
# If the env var is set, it uses the Ray's compiled DAG API
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# which optimizes the control plane overhead.
...
@@ -47,7 +49,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -47,7 +49,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
"VLLM_USE_RAY_SPMD_WORKER=1 requires "
"VLLM_USE_RAY_SPMD_WORKER=1 requires "
"VLLM_USE_RAY_COMPILED_DAG=1"
)
"VLLM_USE_RAY_COMPILED_DAG=1"
)
assert
self
.
parallel_config
.
distributed_executor_backend
==
"
ray
"
assert
self
.
uses_
ray
placement_group
=
self
.
parallel_config
.
placement_group
placement_group
=
self
.
parallel_config
.
placement_group
# Disable Ray usage stats collection.
# Disable Ray usage stats collection.
...
@@ -75,6 +77,20 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -75,6 +77,20 @@ class RayGPUExecutor(DistributedGPUExecutor):
return
ray_remote_kwargs
return
ray_remote_kwargs
def
_get_worker_wrapper_args
(
self
)
->
Dict
[
str
,
Any
]:
if
self
.
speculative_config
is
not
None
:
worker_module_name
=
"vllm.spec_decode.spec_decode_worker"
worker_class_name
=
"create_spec_worker"
else
:
worker_module_name
=
"vllm.worker.worker"
worker_class_name
=
"Worker"
return
dict
(
worker_module_name
=
worker_module_name
,
worker_class_name
=
worker_class_name
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
def
_init_workers_ray
(
self
,
placement_group
:
"PlacementGroup"
,
def
_init_workers_ray
(
self
,
placement_group
:
"PlacementGroup"
,
**
ray_remote_kwargs
):
**
ray_remote_kwargs
):
if
(
self
.
parallel_config
.
tensor_parallel_size
==
1
if
(
self
.
parallel_config
.
tensor_parallel_size
==
1
...
@@ -97,6 +113,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -97,6 +113,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the workers.
# Create the workers.
driver_ip
=
get_ip
()
driver_ip
=
get_ip
()
worker_wrapper_kwargs
=
self
.
_get_worker_wrapper_args
()
for
bundle_id
,
bundle
in
enumerate
(
placement_group
.
bundle_specs
):
for
bundle_id
,
bundle
in
enumerate
(
placement_group
.
bundle_specs
):
if
not
bundle
.
get
(
"GPU"
,
0
):
if
not
bundle
.
get
(
"GPU"
,
0
):
continue
continue
...
@@ -106,23 +123,12 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -106,23 +123,12 @@ class RayGPUExecutor(DistributedGPUExecutor):
placement_group_bundle_index
=
bundle_id
,
placement_group_bundle_index
=
bundle_id
,
)
)
if
self
.
speculative_config
is
not
None
:
worker_module_name
=
"vllm.spec_decode.spec_decode_worker"
worker_class_name
=
"create_spec_worker"
else
:
worker_module_name
=
"vllm.worker.worker"
worker_class_name
=
"Worker"
worker
=
ray
.
remote
(
worker
=
ray
.
remote
(
num_cpus
=
0
,
num_cpus
=
0
,
num_gpus
=
num_gpus
,
num_gpus
=
num_gpus
,
scheduling_strategy
=
scheduling_strategy
,
scheduling_strategy
=
scheduling_strategy
,
**
ray_remote_kwargs
,
**
ray_remote_kwargs
,
)(
RayWorkerWrapper
).
remote
(
)(
RayWorkerWrapper
).
remote
(
**
worker_wrapper_kwargs
)
worker_module_name
=
worker_module_name
,
worker_class_name
=
worker_class_name
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
if
self
.
use_ray_spmd_worker
:
if
self
.
use_ray_spmd_worker
:
self
.
workers
.
append
(
worker
)
self
.
workers
.
append
(
worker
)
...
@@ -133,10 +139,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -133,10 +139,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# as the resource holder for the driver process.
# as the resource holder for the driver process.
self
.
driver_dummy_worker
=
worker
self
.
driver_dummy_worker
=
worker
self
.
driver_worker
=
RayWorkerWrapper
(
self
.
driver_worker
=
RayWorkerWrapper
(
worker_module_name
=
worker_module_name
,
**
worker_wrapper_kwargs
)
worker_class_name
=
worker_class_name
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
else
:
else
:
# Else, added to the list of workers.
# Else, added to the list of workers.
self
.
workers
.
append
(
worker
)
self
.
workers
.
append
(
worker
)
...
@@ -378,7 +381,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -378,7 +381,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
f
"required, but found
{
current_version
}
"
)
f
"required, but found
{
current_version
}
"
)
from
ray.dag
import
InputNode
,
MultiOutputNode
from
ray.dag
import
InputNode
,
MultiOutputNode
assert
self
.
parallel_config
.
distributed_executor_backend
==
"
ray
"
assert
self
.
parallel_config
.
use_
ray
# Right now, compiled DAG requires at least 1 arg. We send
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
# a dummy value for now. It will be fixed soon.
...
...
vllm/executor/ray_xpu_executor.py
View file @
7bd82002
...
@@ -35,6 +35,8 @@ USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
...
@@ -35,6 +35,8 @@ USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class
RayXPUExecutor
(
DistributedGPUExecutor
):
class
RayXPUExecutor
(
DistributedGPUExecutor
):
uses_ray
:
bool
=
True
def
__init__
(
def
__init__
(
self
,
self
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
...
@@ -107,6 +109,13 @@ class RayXPUExecutor(DistributedGPUExecutor):
...
@@ -107,6 +109,13 @@ class RayXPUExecutor(DistributedGPUExecutor):
return
num_gpu_blocks
,
num_cpu_blocks
return
num_gpu_blocks
,
num_cpu_blocks
def
_get_worker_wrapper_args
(
self
)
->
Dict
[
str
,
Any
]:
return
dict
(
worker_module_name
=
"vllm.worker.xpu_worker"
,
worker_class_name
=
"XPUWorker"
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
def
_init_workers_ray
(
self
,
placement_group
:
"PlacementGroup"
,
def
_init_workers_ray
(
self
,
placement_group
:
"PlacementGroup"
,
**
ray_remote_kwargs
):
**
ray_remote_kwargs
):
if
self
.
parallel_config
.
tensor_parallel_size
==
1
:
if
self
.
parallel_config
.
tensor_parallel_size
==
1
:
...
@@ -124,6 +133,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
...
@@ -124,6 +133,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
# Create the workers.
# Create the workers.
driver_ip
=
get_ip
()
driver_ip
=
get_ip
()
worker_wrapper_kwargs
=
self
.
_get_worker_wrapper_args
()
for
bundle_id
,
bundle
in
enumerate
(
placement_group
.
bundle_specs
):
for
bundle_id
,
bundle
in
enumerate
(
placement_group
.
bundle_specs
):
if
not
bundle
.
get
(
"GPU"
,
0
):
if
not
bundle
.
get
(
"GPU"
,
0
):
continue
continue
...
@@ -137,22 +147,14 @@ class RayXPUExecutor(DistributedGPUExecutor):
...
@@ -137,22 +147,14 @@ class RayXPUExecutor(DistributedGPUExecutor):
num_gpus
=
num_gpus
,
num_gpus
=
num_gpus
,
scheduling_strategy
=
scheduling_strategy
,
scheduling_strategy
=
scheduling_strategy
,
**
ray_remote_kwargs
,
**
ray_remote_kwargs
,
)(
RayWorkerWrapper
).
remote
(
)(
RayWorkerWrapper
).
remote
(
**
worker_wrapper_kwargs
)
worker_module_name
=
"vllm.worker.xpu_worker"
,
worker_class_name
=
"XPUWorker"
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
worker_ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
worker_ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
if
worker_ip
==
driver_ip
and
self
.
driver_dummy_worker
is
None
:
if
worker_ip
==
driver_ip
and
self
.
driver_dummy_worker
is
None
:
# If the worker is on the same node as the driver, we use it
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
# as the resource holder for the driver process.
self
.
driver_dummy_worker
=
worker
self
.
driver_dummy_worker
=
worker
self
.
driver_worker
=
RayWorkerWrapper
(
self
.
driver_worker
=
RayWorkerWrapper
(
**
worker_wrapper_kwargs
)
worker_module_name
=
"vllm.worker.xpu_worker"
,
worker_class_name
=
"XPUWorker"
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
else
:
else
:
# Else, added to the list of workers.
# Else, added to the list of workers.
self
.
workers
.
append
(
worker
)
self
.
workers
.
append
(
worker
)
...
@@ -337,7 +339,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
...
@@ -337,7 +339,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
f
"required, but found
{
current_version
}
"
)
f
"required, but found
{
current_version
}
"
)
from
ray.dag
import
InputNode
,
MultiOutputNode
from
ray.dag
import
InputNode
,
MultiOutputNode
assert
self
.
parallel_config
.
distributed_executor_backend
==
"
ray
"
assert
self
.
parallel_config
.
use_
ray
# Right now, compiled DAG requires at least 1 arg. We send
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
# a dummy value for now. It will be fixed soon.
...
...
vllm/executor/tpu_executor.py
View file @
7bd82002
...
@@ -14,6 +14,8 @@ logger = init_logger(__name__)
...
@@ -14,6 +14,8 @@ logger = init_logger(__name__)
class
TPUExecutor
(
ExecutorBase
):
class
TPUExecutor
(
ExecutorBase
):
uses_ray
:
bool
=
False
def
_init_executor
(
self
)
->
None
:
def
_init_executor
(
self
)
->
None
:
assert
not
self
.
scheduler_config
.
chunked_prefill_enabled
,
(
assert
not
self
.
scheduler_config
.
chunked_prefill_enabled
,
(
"Chunked prefill is not yet supported for TPU backend"
)
"Chunked prefill is not yet supported for TPU backend"
)
...
...
vllm/executor/xpu_executor.py
View file @
7bd82002
...
@@ -18,6 +18,8 @@ logger = init_logger(__name__)
...
@@ -18,6 +18,8 @@ logger = init_logger(__name__)
class
XPUExecutor
(
GPUExecutor
):
class
XPUExecutor
(
GPUExecutor
):
uses_ray
:
bool
=
False
def
__init__
(
def
__init__
(
self
,
self
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
...
...
vllm/transformers_utils/tokenizer_group/__init__.py
View file @
7bd82002
from
typing
import
Optional
from
typing
import
Optional
,
Type
from
vllm.config
import
TokenizerPoolConfig
from
vllm.config
import
TokenizerPoolConfig
from
vllm.executor.ray_utils
import
ray
from
vllm.executor.ray_utils
import
ray
...
@@ -16,18 +16,22 @@ else:
...
@@ -16,18 +16,22 @@ else:
def
get_tokenizer_group
(
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
],
def
get_tokenizer_group
(
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
],
**
init_kwargs
)
->
BaseTokenizerGroup
:
**
init_kwargs
)
->
BaseTokenizerGroup
:
tokenizer_cls
:
Type
[
BaseTokenizerGroup
]
if
tokenizer_pool_config
is
None
:
if
tokenizer_pool_config
is
None
:
return
TokenizerGroup
(
**
init_kwargs
)
tokenizer_cls
=
TokenizerGroup
if
tokenizer_pool_config
.
pool_type
==
"ray"
:
elif
isinstance
(
tokenizer_pool_config
.
pool_type
,
type
)
and
issubclass
(
tokenizer_pool_config
.
pool_type
,
BaseTokenizerGroup
):
tokenizer_cls
=
tokenizer_pool_config
.
pool_type
elif
tokenizer_pool_config
.
pool_type
==
"ray"
:
if
RayTokenizerGroupPool
is
None
:
if
RayTokenizerGroupPool
is
None
:
raise
ImportError
(
raise
ImportError
(
"RayTokenizerGroupPool is not available. Please install "
"RayTokenizerGroupPool is not available. Please install "
"the ray package to use the Ray tokenizer group pool."
)
"the ray package to use the Ray tokenizer group pool."
)
return
RayTokenizerGroupPool
.
from_config
(
tokenizer_pool_config
,
tokenizer_cls
=
RayTokenizerGroupPool
**
init_kwargs
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Unknown pool type:
{
tokenizer_pool_config
.
pool_type
}
"
)
f
"Unknown pool type:
{
tokenizer_pool_config
.
pool_type
}
"
)
return
tokenizer_cls
.
from_config
(
tokenizer_pool_config
,
**
init_kwargs
)
__all__
=
[
"get_tokenizer_group"
,
"BaseTokenizerGroup"
]
__all__
=
[
"get_tokenizer_group"
,
"BaseTokenizerGroup"
]
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
View file @
7bd82002
...
@@ -3,12 +3,19 @@ from typing import List, Optional
...
@@ -3,12 +3,19 @@ from typing import List, Optional
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
TokenizerPoolConfig
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
class
BaseTokenizerGroup
(
ABC
):
class
BaseTokenizerGroup
(
ABC
):
"""A group of tokenizers that can be used for LoRA adapters."""
"""A group of tokenizers that can be used for LoRA adapters."""
@
classmethod
@
abstractmethod
def
from_config
(
cls
,
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
],
**
init_kwargs
)
->
"BaseTokenizerGroup"
:
pass
@
abstractmethod
@
abstractmethod
def
ping
(
self
)
->
bool
:
def
ping
(
self
)
->
bool
:
"""Check if the tokenizer group is alive."""
"""Check if the tokenizer group is alive."""
...
...
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
View file @
7bd82002
...
@@ -29,8 +29,10 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
...
@@ -29,8 +29,10 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
_worker_cls
=
TokenizerGroup
_worker_cls
=
TokenizerGroup
@
classmethod
@
classmethod
def
from_config
(
cls
,
tokenizer_pool_config
:
TokenizerPoolConfig
,
def
from_config
(
cls
,
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
]
,
**
init_kwargs
)
->
"RayTokenizerGroupPool"
:
**
init_kwargs
)
->
"RayTokenizerGroupPool"
:
if
not
tokenizer_pool_config
:
raise
ValueError
(
"tokenizer_pool_config must not be None."
)
ray_actor_options
=
(
tokenizer_pool_config
.
extra_config
or
{
ray_actor_options
=
(
tokenizer_pool_config
.
extra_config
or
{
"num_cpus"
:
0
"num_cpus"
:
0
})
})
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment