Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fb96c1e9
Unverified
Commit
fb96c1e9
authored
Mar 15, 2024
by
Antoni Baum
Committed by
GitHub
Mar 15, 2024
Browse files
Asynchronous tokenization (#2879)
parent
8fa7357f
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
658 additions
and
84 deletions
+658
-84
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-1
tests/async_engine/test_api_server.py
tests/async_engine/test_api_server.py
+7
-9
tests/conftest.py
tests/conftest.py
+11
-0
tests/lora/test_tokenizer_group.py
tests/lora/test_tokenizer_group.py
+53
-0
tests/tokenization/__init__.py
tests/tokenization/__init__.py
+0
-0
tests/tokenization/test_cached_tokenizer.py
tests/tokenization/test_cached_tokenizer.py
+20
-0
tests/tokenization/test_detokenize.py
tests/tokenization/test_detokenize.py
+0
-0
tests/tokenization/test_tokenizer_group.py
tests/tokenization/test_tokenizer_group.py
+100
-0
vllm/config.py
vllm/config.py
+57
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+34
-9
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+11
-4
vllm/transformers_utils/tokenizer.py
vllm/transformers_utils/tokenizer.py
+38
-61
vllm/transformers_utils/tokenizer_group/__init__.py
vllm/transformers_utils/tokenizer_group/__init__.py
+32
-0
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
...ransformers_utils/tokenizer_group/base_tokenizer_group.py
+48
-0
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
...transformers_utils/tokenizer_group/ray_tokenizer_group.py
+166
-0
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
+80
-0
No files found.
.buildkite/test-pipeline.yaml
View file @
fb96c1e9
...
...
@@ -28,7 +28,7 @@ steps:
num_gpus
:
2
# only support 1 or 2 for now.
-
label
:
Engine Test
command
:
pytest -v -s engine test_sequence.py
command
:
pytest -v -s engine
tokenization
test_sequence.py
-
label
:
Entrypoints Test
command
:
pytest -v -s entrypoints
...
...
tests/async_engine/test_api_server.py
View file @
fb96c1e9
...
...
@@ -25,23 +25,21 @@ def _query_server_long(prompt: str) -> dict:
@
pytest
.
fixture
def
api_server
():
def
api_server
(
tokenizer_pool_size
:
int
):
script_path
=
Path
(
__file__
).
parent
.
joinpath
(
"api_server_async_engine.py"
).
absolute
()
uvicorn_process
=
subprocess
.
Popen
([
sys
.
executable
,
"-u"
,
str
(
script_path
),
"--model"
,
"facebook/opt-125m"
,
"--host"
,
"127.0.0.1"
,
sys
.
executable
,
"-u"
,
str
(
script_path
),
"--model"
,
"facebook/opt-125m"
,
"--host"
,
"127.0.0.1"
,
"--tokenizer-pool-size"
,
str
(
tokenizer_pool_size
)
])
yield
uvicorn_process
.
terminate
()
def
test_api_server
(
api_server
):
@
pytest
.
mark
.
parametrize
(
"tokenizer_pool_size"
,
[
0
,
2
])
def
test_api_server
(
api_server
,
tokenizer_pool_size
:
int
):
"""
Run the API server and test it.
...
...
tests/conftest.py
View file @
fb96c1e9
...
...
@@ -7,6 +7,7 @@ from transformers import AutoModelForCausalLM
from
vllm
import
LLM
,
SamplingParams
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.config
import
TokenizerPoolConfig
_TEST_DIR
=
os
.
path
.
dirname
(
__file__
)
_TEST_PROMPTS
=
[
os
.
path
.
join
(
_TEST_DIR
,
"prompts"
,
"example.txt"
)]
...
...
@@ -258,3 +259,13 @@ class VllmRunner:
@
pytest
.
fixture
def
vllm_runner
():
return
VllmRunner
def get_tokenizer_pool_config(tokenizer_group_type):
    """Build the tokenizer pool config matching a test parameter.

    Args:
        tokenizer_group_type: ``None`` to request no pool, or ``"ray"``
            to request a single-worker Ray pool.

    Returns:
        ``None`` when ``tokenizer_group_type`` is ``None``, otherwise a
        ``TokenizerPoolConfig`` with ``pool_size=1``, ``pool_type="ray"``
        and an empty ``extra_config``.

    Raises:
        ValueError: for any other value of ``tokenizer_group_type``.
    """
    if tokenizer_group_type is None:
        pool_config = None
    elif tokenizer_group_type == "ray":
        pool_config = TokenizerPoolConfig(pool_size=1,
                                          pool_type="ray",
                                          extra_config={})
    else:
        raise ValueError(
            f"Unknown tokenizer_group_type: {tokenizer_group_type}")
    return pool_config
tests/lora/test_tokenizer.py
→
tests/lora/test_tokenizer
_group
.py
View file @
fb96c1e9
import
pytest
from
transformers
import
AutoTokenizer
,
PreTrainedTokenizerBase
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.tokenizer
import
TokenizerGroup
,
get_lora_tokenizer
@
pytest
.
mark
.
asyncio
async
def
test_transformers_tokenizer
():
reference_tokenizer
=
AutoTokenizer
.
from_pretrained
(
"gpt2"
)
tokenizer
=
TokenizerGroup
(
tokenizer_id
=
"gpt2"
,
enable_lora
=
False
,
max_num_seqs
=
1
,
max_input_length
=
None
,
)
assert
reference_tokenizer
.
encode
(
"prompt"
)
==
tokenizer
.
encode
(
request_id
=
"request_id"
,
prompt
=
"prompt"
,
lora_request
=
None
)
assert
reference_tokenizer
.
encode
(
"prompt"
)
==
await
tokenizer
.
encode_async
(
request_id
=
"request_id"
,
prompt
=
"prompt"
,
lora_request
=
None
)
assert
isinstance
(
tokenizer
.
get_lora_tokenizer
(
None
),
PreTrainedTokenizerBase
)
assert
tokenizer
.
get_lora_tokenizer
(
None
)
==
await
tokenizer
.
get_lora_tokenizer_async
(
None
)
from
vllm.transformers_utils.tokenizer_group
import
get_tokenizer_group
from
vllm.transformers_utils.tokenizer
import
get_lora_tokenizer
from
..conftest
import
get_tokenizer_pool_config
@
pytest
.
mark
.
asyncio
async
def
test_transformers_tokenizer_lora
(
sql_lora_files
):
@
pytest
.
mark
.
parametrize
(
"tokenizer_group_type"
,
[
None
,
"ray"
])
async
def
test_tokenizer_group_lora
(
sql_lora_files
,
tokenizer_group_type
):
reference_tokenizer
=
AutoTokenizer
.
from_pretrained
(
sql_lora_files
)
tokenizer
=
TokenizerGroup
(
tokenizer_group
=
get_tokenizer_group
(
get_tokenizer_pool_config
(
tokenizer_group_type
),
tokenizer_id
=
"gpt2"
,
enable_lora
=
True
,
max_num_seqs
=
1
,
max_input_length
=
None
,
)
lora_request
=
LoRARequest
(
"1"
,
1
,
sql_lora_files
)
assert
reference_tokenizer
.
encode
(
"prompt"
)
==
tokenizer
.
encode
(
assert
reference_tokenizer
.
encode
(
"prompt"
)
==
tokenizer
_group
.
encode
(
request_id
=
"request_id"
,
prompt
=
"prompt"
,
lora_request
=
lora_request
)
assert
reference_tokenizer
.
encode
(
"prompt"
)
==
await
tokenizer
.
encode_async
(
request_id
=
"request_id"
,
prompt
=
"prompt"
,
lora_request
=
lora_request
)
assert
isinstance
(
tokenizer
.
get_lora_tokenizer
(
None
),
"prompt"
)
==
await
tokenizer_group
.
encode_async
(
request_id
=
"request_id"
,
prompt
=
"prompt"
,
lora_request
=
lora_request
)
assert
isinstance
(
tokenizer_group
.
get_lora_tokenizer
(
None
),
PreTrainedTokenizerBase
)
assert
tokenizer
.
get_lora_tokenizer
(
None
)
==
await
tokenizer
.
get_lora_tokenizer_async
(
None
)
assert
tokenizer
_group
.
get_lora_tokenizer
(
None
)
==
await
tokenizer
_group
.
get_lora_tokenizer_async
(
None
)
assert
isinstance
(
tokenizer
.
get_lora_tokenizer
(
lora_request
),
assert
isinstance
(
tokenizer
_group
.
get_lora_tokenizer
(
lora_request
),
PreTrainedTokenizerBase
)
assert
tokenizer
.
get_lora_tokenizer
(
lora_request
)
!=
tokenizer
.
get_lora_tokenizer
(
None
)
assert
tokenizer
.
get_lora_tokenizer
(
lora_request
)
==
await
tokenizer
.
get_lora_tokenizer_async
(
lora_request
)
assert
tokenizer_group
.
get_lora_tokenizer
(
lora_request
)
!=
tokenizer_group
.
get_lora_tokenizer
(
None
)
assert
tokenizer_group
.
get_lora_tokenizer
(
lora_request
)
==
await
tokenizer_group
.
get_lora_tokenizer_async
(
lora_request
)
def
test_get_lora_tokenizer
(
sql_lora_files
,
tmpdir
):
...
...
tests/tokenization/__init__.py
0 → 100644
View file @
fb96c1e9
tests/tokenization/test_cached_tokenizer.py
0 → 100644
View file @
fb96c1e9
from
copy
import
deepcopy
from
vllm.transformers_utils.tokenizer
import
get_cached_tokenizer
from
transformers
import
AutoTokenizer
def test_cached_tokenizer():
    """Check that a cached tokenizer agrees with the tokenizer it wraps.

    Two special tokens are added to the reference tokenizer first so that
    the special-token properties being compared are non-trivial. The
    cached tokenizer is built from a deep copy, leaving the reference
    object itself unpatched.
    """
    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
    reference_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<SEP>"]})

    tokenizer_copy = deepcopy(reference_tokenizer)
    cached_tokenizer = get_cached_tokenizer(tokenizer_copy)

    expected_ids = reference_tokenizer.encode("prompt")
    assert expected_ids == cached_tokenizer.encode("prompt")

    # The special-token properties are compared as sets, so ordering and
    # container type are deliberately ignored.
    for attr_name in ("all_special_ids", "all_special_tokens",
                      "all_special_tokens_extended"):
        reference_values = set(getattr(reference_tokenizer, attr_name))
        cached_values = set(getattr(cached_tokenizer, attr_name))
        assert reference_values == cached_values
tests/
engine
/test_detokenize.py
→
tests/
tokenization
/test_detokenize.py
View file @
fb96c1e9
File moved
tests/tokenization/test_tokenizer_group.py
0 → 100644
View file @
fb96c1e9
import
os
import
pytest
import
asyncio
from
unittest.mock
import
patch
from
transformers
import
AutoTokenizer
,
PreTrainedTokenizerBase
from
vllm.transformers_utils.tokenizer_group
import
get_tokenizer_group
from
vllm.transformers_utils.tokenizer_group.ray_tokenizer_group
import
(
RayTokenizerGroupPool
)
from
vllm.transformers_utils.tokenizer_group.tokenizer_group
import
(
TokenizerGroup
)
from
..conftest
import
get_tokenizer_pool_config
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
async def test_tokenizer_group(tokenizer_group_type):
    """Check a tokenizer group against a plain ``gpt2`` tokenizer.

    Runs once with no pool config (``None``) and once with the ``"ray"``
    pool config, covering both the sync and async entry points.
    """
    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    pool_config = get_tokenizer_pool_config(tokenizer_group_type)
    tokenizer_group = get_tokenizer_group(
        pool_config,
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )

    # Blocking encode.
    sync_ids = tokenizer_group.encode(request_id="request_id",
                                      prompt="prompt",
                                      lora_request=None)
    assert reference_tokenizer.encode("prompt") == sync_ids

    # Awaitable encode.
    async_ids = await tokenizer_group.encode_async(request_id="request_id",
                                                   prompt="prompt",
                                                   lora_request=None)
    assert reference_tokenizer.encode("prompt") == async_ids

    # Without a LoRA request both accessors must yield the same tokenizer.
    base_tokenizer = tokenizer_group.get_lora_tokenizer(None)
    assert isinstance(base_tokenizer, PreTrainedTokenizerBase)
    assert tokenizer_group.get_lora_tokenizer(
        None) == await tokenizer_group.get_lora_tokenizer_async(None)
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_pool(tokenizer_group_type):
    """Check that a pool serves more concurrent requests than it has workers.

    Five times ``pool_size`` encode requests are issued at once; every
    result must match the reference tokenizer, in request order.
    """
    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer_group_pool = get_tokenizer_group(
        get_tokenizer_pool_config(tokenizer_group_type),
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )

    # Send multiple requests to the tokenizer group pool
    # (more than the pool size)
    # and check that all requests are processed correctly.
    num_requests = tokenizer_group_pool.pool_size * 5
    pending = []
    for i in range(num_requests):
        pending.append(
            tokenizer_group_pool.encode_async(request_id=str(i),
                                              prompt=f"prompt {i}",
                                              lora_request=None))
    results = await asyncio.gather(*pending)

    expected_results = []
    for i in range(num_requests):
        expected_results.append(reference_tokenizer.encode(f"prompt {i}"))
    assert results == expected_results
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_env_var_propagation(
        tokenizer_group_type):
    """Test that env vars from caller process are propagated to
    tokenizer Ray actors.

    A worker class whose ``ping`` asserts on an environment variable is
    plugged into the pool. The ping must fail while the variable is unset
    in the caller, and succeed for a pool created while it is set.
    """
    env_var = "MY_ENV_VAR"

    class EnvVarCheckerTokenizerGroup(TokenizerGroup):

        def ping(self):
            # Evaluated inside the worker, so it observes the worker's
            # environment rather than the test process's.
            assert os.environ.get(env_var) == "1"
            return super().ping()

    class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool):
        _worker_cls = EnvVarCheckerTokenizerGroup

    def build_pool():
        # A fresh config is requested for every pool, as the environment
        # is captured at pool construction time.
        pool_config = get_tokenizer_pool_config(tokenizer_group_type)
        return EnvVarCheckerRayTokenizerGroupPool.from_config(
            pool_config,
            tokenizer_id="gpt2",
            enable_lora=False,
            max_num_seqs=1,
            max_input_length=None)

    pool_without_var = build_pool()
    with pytest.raises(AssertionError):
        pool_without_var.ping()

    with patch.dict(os.environ, {env_var: "1"}):
        pool_with_var = build_pool()
        pool_with_var.ping()
vllm/config.py
View file @
fb96c1e9
...
...
@@ -3,6 +3,7 @@ from dataclasses import dataclass
import
os
from
packaging.version
import
Version
import
json
import
torch
from
transformers
import
PretrainedConfig
...
...
@@ -389,6 +390,58 @@ class CacheConfig:
logger
.
warning
(
"Possibly too large swap space. "
+
msg
)
@dataclass
class TokenizerPoolConfig:
    """Settings describing a pool of tokenizer workers.

    Args:
        pool_size: How many tokenizer workers the pool holds.
        pool_type: Which pool implementation to use. Only ``"ray"`` is
            accepted.
        extra_config: Implementation-specific options; their meaning
            depends on ``pool_type``. Must be a dictionary.
    """
    pool_size: int
    pool_type: str
    extra_config: dict

    def __post_init__(self):
        """Validate the pool type and the type of ``extra_config``.

        Raises:
            ValueError: if ``pool_type`` is not supported or
                ``extra_config`` is not a ``dict``.
        """
        supported_pool_types = ("ray", )
        if self.pool_type not in supported_pool_types:
            raise ValueError(f"Unknown pool type: {self.pool_type}")
        if not isinstance(self.extra_config, dict):
            raise ValueError("extra_config must be a dictionary.")

    @classmethod
    def create_config(
        cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
        tokenizer_pool_extra_config: Optional[Union[str, dict]]
    ) -> Optional["TokenizerPoolConfig"]:
        """Build a TokenizerPoolConfig from loosely typed arguments.

        Args:
            tokenizer_pool_size: Number of tokenizer workers in the pool.
                A falsy value (such as 0) means "no pool".
            tokenizer_pool_type: Type of the pool.
            tokenizer_pool_extra_config: Additional config for the pool,
                either a dictionary, a JSON string to be parsed, or
                ``None`` (treated as an empty dictionary).

        Returns:
            The config, or ``None`` when ``tokenizer_pool_size`` is falsy.
        """
        # Guard clause: no pool requested, so the remaining arguments are
        # not inspected at all.
        if not tokenizer_pool_size:
            return None

        if isinstance(tokenizer_pool_extra_config, str):
            extra_config = json.loads(tokenizer_pool_extra_config)
        else:
            extra_config = tokenizer_pool_extra_config or {}
        return cls(tokenizer_pool_size, tokenizer_pool_type, extra_config)
class
ParallelConfig
:
"""Configuration for the distributed execution.
...
...
@@ -403,6 +456,8 @@ class ParallelConfig:
parallel and large models.
disable_custom_all_reduce: Disable the custom all-reduce kernel and
fall back to NCCL.
tokenizer_pool_config: Config for the tokenizer pool.
If None, will use synchronous tokenization.
ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
"""
...
...
@@ -414,6 +469,7 @@ class ParallelConfig:
worker_use_ray
:
bool
,
max_parallel_loading_workers
:
Optional
[
int
]
=
None
,
disable_custom_all_reduce
:
bool
=
False
,
tokenizer_pool_config
:
Optional
[
TokenizerPoolConfig
]
=
None
,
ray_workers_use_nsight
:
bool
=
False
,
placement_group
:
Optional
[
"PlacementGroup"
]
=
None
,
)
->
None
:
...
...
@@ -430,6 +486,7 @@ class ParallelConfig:
self
.
worker_use_ray
=
worker_use_ray
self
.
max_parallel_loading_workers
=
max_parallel_loading_workers
self
.
disable_custom_all_reduce
=
disable_custom_all_reduce
self
.
tokenizer_pool_config
=
tokenizer_pool_config
self
.
ray_workers_use_nsight
=
ray_workers_use_nsight
self
.
placement_group
=
placement_group
...
...
vllm/engine/arg_utils.py
View file @
fb96c1e9
...
...
@@ -4,7 +4,8 @@ from dataclasses import dataclass
from
typing
import
Optional
,
Tuple
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
LoRAConfig
)
ParallelConfig
,
SchedulerConfig
,
LoRAConfig
,
TokenizerPoolConfig
)
@
dataclass
...
...
@@ -40,6 +41,9 @@ class EngineArgs:
enforce_eager
:
bool
=
False
max_context_len_to_capture
:
int
=
8192
disable_custom_all_reduce
:
bool
=
False
tokenizer_pool_size
:
int
=
0
tokenizer_pool_type
:
str
=
"ray"
tokenizer_pool_extra_config
:
Optional
[
dict
]
=
None
enable_lora
:
bool
=
False
max_loras
:
int
=
1
max_lora_rank
:
int
=
16
...
...
@@ -249,6 +253,25 @@ class EngineArgs:
action
=
'store_true'
,
default
=
EngineArgs
.
disable_custom_all_reduce
,
help
=
'See ParallelConfig'
)
parser
.
add_argument
(
'--tokenizer-pool-size'
,
type
=
int
,
default
=
EngineArgs
.
tokenizer_pool_size
,
help
=
'Size of tokenizer pool to use for '
'asynchronous tokenization. If 0, will '
'use synchronous tokenization.'
)
parser
.
add_argument
(
'--tokenizer-pool-type'
,
type
=
str
,
default
=
EngineArgs
.
tokenizer_pool_type
,
help
=
'Type of tokenizer pool to use for '
'asynchronous tokenization. Ignored '
'if tokenizer_pool_size is 0.'
)
parser
.
add_argument
(
'--tokenizer-pool-extra-config'
,
type
=
str
,
default
=
EngineArgs
.
tokenizer_pool_extra_config
,
help
=
'Extra config for tokenizer pool. '
'This should be a JSON string that will be '
'parsed into a dictionary. Ignored if '
'tokenizer_pool_size is 0.'
)
# LoRA related configs
parser
.
add_argument
(
'--enable-lora'
,
action
=
'store_true'
,
...
...
@@ -312,14 +335,16 @@ class EngineArgs:
cache_config
=
CacheConfig
(
self
.
block_size
,
self
.
gpu_memory_utilization
,
self
.
swap_space
,
self
.
kv_cache_dtype
,
model_config
.
get_sliding_window
(),
self
.
enable_prefix_caching
)
parallel_config
=
ParallelConfig
(
self
.
pipeline_parallel_size
,
self
.
tensor_parallel_size
,
self
.
worker_use_ray
,
self
.
max_parallel_loading_workers
,
self
.
disable_custom_all_reduce
,
self
.
ray_workers_use_nsight
)
model_config
.
get_sliding_window
())
parallel_config
=
ParallelConfig
(
self
.
pipeline_parallel_size
,
self
.
tensor_parallel_size
,
self
.
worker_use_ray
,
self
.
max_parallel_loading_workers
,
self
.
disable_custom_all_reduce
,
TokenizerPoolConfig
.
create_config
(
self
.
tokenizer_pool_size
,
self
.
tokenizer_pool_type
,
self
.
tokenizer_pool_extra_config
,
),
self
.
ray_workers_use_nsight
)
scheduler_config
=
SchedulerConfig
(
self
.
max_num_batched_tokens
,
self
.
max_num_seqs
,
model_config
.
max_model_len
,
...
...
vllm/engine/llm_engine.py
View file @
fb96c1e9
...
...
@@ -17,8 +17,9 @@ from vllm.outputs import RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
Logprob
,
SamplerOutput
,
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.tokenizer
import
(
detokenize_incrementally
,
TokenizerGroup
)
from
vllm.transformers_utils.tokenizer
import
detokenize_incrementally
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
get_tokenizer_group
)
from
vllm.utils
import
Counter
logger
=
init_logger
(
__name__
)
...
...
@@ -102,6 +103,10 @@ class LLMEngine:
parallel_config
,
scheduler_config
,
device_config
,
lora_config
)
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self
.
tokenizer
.
ping
()
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
...
...
@@ -152,6 +157,7 @@ class LLMEngine:
def
_init_tokenizer
(
self
,
**
tokenizer_init_kwargs
):
init_kwargs
=
dict
(
tokenizer_id
=
self
.
model_config
.
tokenizer
,
enable_lora
=
bool
(
self
.
lora_config
),
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
,
max_input_length
=
None
,
...
...
@@ -159,8 +165,9 @@ class LLMEngine:
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
revision
=
self
.
model_config
.
tokenizer_revision
)
init_kwargs
.
update
(
tokenizer_init_kwargs
)
self
.
tokenizer
:
TokenizerGroup
=
TokenizerGroup
(
self
.
model_config
.
tokenizer
,
**
init_kwargs
)
self
.
tokenizer
:
BaseTokenizerGroup
=
get_tokenizer_group
(
self
.
parallel_config
.
tokenizer_pool_config
,
**
init_kwargs
)
def
_verify_args
(
self
)
->
None
:
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
...
...
vllm/transformers_utils/tokenizer.py
View file @
fb96c1e9
...
...
@@ -5,12 +5,48 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.utils
import
make_async
,
LRUCache
from
vllm.utils
import
make_async
from
vllm.transformers_utils.tokenizers
import
*
logger
=
init_logger
(
__name__
)
def get_cached_tokenizer(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Return ``tokenizer`` with its special-token properties frozen.

    The object is patched in place: its class is swapped for a
    dynamically created subclass whose ``all_special_ids``,
    ``all_special_tokens`` and ``all_special_tokens_extended`` properties
    return values captured once, at the time of this call.

    By default, transformers will recompute multiple tokenizer properties
    each time they are called, leading to a significant slowdown; serving
    the captured values avoids that.

    Note that ``all_special_ids`` and ``all_special_tokens`` are captured
    as sets, while ``all_special_tokens_extended`` is kept as returned.
    """
    original_cls = tokenizer.__class__

    # Read each property exactly once; the closures below hand these
    # snapshots back on every access.
    cached_special_ids = set(tokenizer.all_special_ids)
    cached_special_tokens_extended = tokenizer.all_special_tokens_extended
    cached_special_tokens = set(tokenizer.all_special_tokens)

    class CachedTokenizer(original_cls):

        @property
        def all_special_ids(self):
            return cached_special_ids

        @property
        def all_special_tokens(self):
            return cached_special_tokens

        @property
        def all_special_tokens_extended(self):
            return cached_special_tokens_extended

    CachedTokenizer.__name__ = f"Cached{original_cls.__name__}"

    # Re-class the existing instance rather than building a new one, so
    # all other state on the tokenizer is untouched.
    tokenizer.__class__ = CachedTokenizer
    return tokenizer
def
get_tokenizer
(
tokenizer_name
:
str
,
*
args
,
...
...
@@ -64,7 +100,7 @@ def get_tokenizer(
logger
.
warning
(
"Using a slow tokenizer. This might cause a significant "
"slowdown. Consider using a fast tokenizer instead."
)
return
tokenizer
return
get_cached_tokenizer
(
tokenizer
)
def
get_lora_tokenizer
(
lora_request
:
LoRARequest
,
*
args
,
...
...
@@ -88,65 +124,6 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
get_lora_tokenizer_async
=
make_async
(
get_lora_tokenizer
)
class
TokenizerGroup
:
"""A group of tokenizers that can be used for LoRA adapters."""
def
__init__
(
self
,
tokenizer_id
:
str
,
enable_lora
:
bool
,
max_num_seqs
:
int
,
max_input_length
:
Optional
[
int
],
**
tokenizer_config
):
self
.
tokenizer_id
=
tokenizer_id
self
.
tokenizer_config
=
tokenizer_config
self
.
enable_lora
=
enable_lora
self
.
max_input_length
=
max_input_length
self
.
tokenizer
=
get_tokenizer
(
self
.
tokenizer_id
,
**
tokenizer_config
)
if
enable_lora
:
self
.
lora_tokenizers
=
LRUCache
(
capacity
=
max_num_seqs
)
else
:
self
.
lora_tokenizers
=
None
def
encode
(
self
,
prompt
:
str
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
)
->
List
[
int
]:
tokenizer
=
self
.
get_lora_tokenizer
(
lora_request
)
return
tokenizer
.
encode
(
prompt
)
async
def
encode_async
(
self
,
prompt
:
str
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
)
->
List
[
int
]:
tokenizer
=
await
self
.
get_lora_tokenizer_async
(
lora_request
)
return
tokenizer
.
encode
(
prompt
)
def
get_lora_tokenizer
(
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
)
->
"PreTrainedTokenizer"
:
if
not
lora_request
or
not
self
.
enable_lora
:
return
self
.
tokenizer
if
lora_request
.
lora_int_id
not
in
self
.
lora_tokenizers
:
tokenizer
=
(
get_lora_tokenizer
(
lora_request
,
**
self
.
tokenizer_config
)
or
self
.
tokenizer
)
self
.
lora_tokenizers
.
put
(
lora_request
.
lora_int_id
,
tokenizer
)
return
tokenizer
else
:
return
self
.
lora_tokenizers
.
get
(
lora_request
.
lora_int_id
)
async
def
get_lora_tokenizer_async
(
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
)
->
"PreTrainedTokenizer"
:
if
not
lora_request
or
not
self
.
enable_lora
:
return
self
.
tokenizer
if
lora_request
.
lora_int_id
not
in
self
.
lora_tokenizers
:
tokenizer
=
(
await
get_lora_tokenizer_async
(
lora_request
,
**
self
.
tokenizer_config
)
or
self
.
tokenizer
)
self
.
lora_tokenizers
.
put
(
lora_request
.
lora_int_id
,
tokenizer
)
return
tokenizer
else
:
return
self
.
lora_tokenizers
.
get
(
lora_request
.
lora_int_id
)
def
_convert_tokens_to_string_with_added_encoders
(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
output_tokens
:
List
[
str
],
...
...
vllm/transformers_utils/tokenizer_group/__init__.py
0 → 100644
View file @
fb96c1e9
from
typing
import
Optional
from
vllm.config
import
TokenizerPoolConfig
from
vllm.transformers_utils.tokenizer_group.base_tokenizer_group
import
(
BaseTokenizerGroup
)
from
vllm.transformers_utils.tokenizer_group.tokenizer_group
import
(
TokenizerGroup
)
from
vllm.engine.ray_utils
import
ray
if
ray
:
from
vllm.transformers_utils.tokenizer_group.ray_tokenizer_group
import
(
RayTokenizerGroupPool
)
else
:
RayTokenizerGroupPool
=
None
def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
                        **init_kwargs) -> BaseTokenizerGroup:
    """Create the tokenizer group selected by ``tokenizer_pool_config``.

    Args:
        tokenizer_pool_config: ``None`` for a plain ``TokenizerGroup``,
            otherwise the pool configuration to build from.
        **init_kwargs: Forwarded unchanged to the chosen implementation.

    Returns:
        A ``TokenizerGroup`` when no pool is configured, or the result of
        ``RayTokenizerGroupPool.from_config`` for the ``"ray"`` pool type.

    Raises:
        ValueError: if the configured pool type is not ``"ray"``.
        ImportError: if the ``"ray"`` pool is requested but
            ``RayTokenizerGroupPool`` is ``None`` in this module.
    """
    if tokenizer_pool_config is None:
        return TokenizerGroup(**init_kwargs)

    pool_type = tokenizer_pool_config.pool_type
    if pool_type != "ray":
        raise ValueError(f"Unknown pool type: {pool_type}")

    if RayTokenizerGroupPool is None:
        raise ImportError(
            "RayTokenizerGroupPool is not available. Please install "
            "the ray package to use the Ray tokenizer group pool.")
    return RayTokenizerGroupPool.from_config(tokenizer_pool_config,
                                             **init_kwargs)
__all__
=
[
"get_tokenizer_group"
,
"BaseTokenizerGroup"
]
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
0 → 100644
View file @
fb96c1e9
from
abc
import
ABC
,
abstractmethod
from
typing
import
List
,
Optional
from
transformers
import
PreTrainedTokenizer
from
vllm.lora.request
import
LoRARequest
class BaseTokenizerGroup(ABC):
    """Abstract interface for a group of tokenizers usable with LoRA adapters.

    Every method is abstract, so concrete subclasses must implement all
    of them before they can be instantiated.
    """

    @abstractmethod
    def ping(self) -> bool:
        """Report whether the tokenizer group is alive."""

    @abstractmethod
    def get_max_input_len(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        """Return the maximum input length for the given LoRA request."""

    @abstractmethod
    def encode(self, prompt: str, request_id: Optional[str],
               lora_request: Optional[LoRARequest]) -> List[int]:
        """Encode ``prompt`` into token ids using the tokenizer group."""

    @abstractmethod
    async def encode_async(self, prompt: str, request_id: Optional[str],
                           lora_request: Optional[LoRARequest]) -> List[int]:
        """Awaitable counterpart of ``encode``."""

    @abstractmethod
    def get_lora_tokenizer(
            self,
            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
        """Return the tokenizer to use for the given LoRA request."""

    @abstractmethod
    async def get_lora_tokenizer_async(
            self,
            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
        """Awaitable counterpart of ``get_lora_tokenizer``."""
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
0 → 100644
View file @
fb96c1e9
import
asyncio
import
os
from
typing
import
List
,
Optional
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
TokenizerPoolConfig
from
vllm.lora.request
import
LoRARequest
from
vllm.engine.ray_utils
import
ray
from
vllm.transformers_utils.tokenizer_group.base_tokenizer_group
import
(
BaseTokenizerGroup
)
from
vllm.transformers_utils.tokenizer_group.tokenizer_group
import
(
TokenizerGroup
)
from
ray.util.scheduling_strategies
import
NodeAffinitySchedulingStrategy
class RayTokenizerGroupPool(BaseTokenizerGroup):
    """A Ray-based pool of TokenizerGroups for async tokenization.

    The pool owns ``num_actors`` Ray actors, each wrapping one instance of
    ``_worker_cls``, and hands encode requests to whichever actor is idle.
    A local (in-process) instance of ``_worker_cls`` is kept as well and
    serves the calls that return tokenizer objects or limits directly.
    """

    # Class to use for workers making up the pool.
    _worker_cls = TokenizerGroup

    @classmethod
    def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig,
                    **init_kwargs) -> "RayTokenizerGroupPool":
        """Create a pool from a ``TokenizerPoolConfig``.

        ``extra_config`` is used as the Ray actor options; when it is
        empty, ``{"num_cpus": 0}`` is used instead. A soft node-affinity
        scheduling strategy for the current node is added unless a
        ``scheduling_strategy`` is already present, and the caller's
        environment variables are merged into ``runtime_env``.

        The config object is not modified: the options are built on
        copies, so a config can be reused for several pools and each pool
        picks up the environment of the process at the time it is created.

        Args:
            tokenizer_pool_config: Pool size and Ray actor options.
            **init_kwargs: Remaining constructor arguments for the pool.

        Returns:
            The constructed pool.
        """
        # Copy before injecting defaults; mutating extra_config in place
        # would leak the scheduling strategy and a snapshot of os.environ
        # back into the caller's config.
        extra_config = tokenizer_pool_config.extra_config
        ray_actor_options = (dict(extra_config)
                             if extra_config else {"num_cpus": 0})
        ray_actor_options.setdefault(
            "scheduling_strategy",
            NodeAffinitySchedulingStrategy(
                node_id=ray.get_runtime_context().get_node_id(), soft=True))

        # Carry over the env vars to the actors.
        # This is necessary for API keys and such.
        # The runtime_env mapping is copied too, for the same reason as
        # above; the helper replaces its "env_vars" entry with a new dict.
        runtime_env = dict(ray_actor_options.get("runtime_env") or {})
        _carry_over_env_vars_to_runtime_env(runtime_env)
        ray_actor_options["runtime_env"] = runtime_env

        init_kwargs["num_actors"] = tokenizer_pool_config.pool_size
        init_kwargs["ray_actor_options"] = ray_actor_options

        return cls(**init_kwargs)

    def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int], num_actors: int,
                 ray_actor_options: dict, **tokenizer_config):
        """Create the local tokenizer group and the remote actors.

        Args:
            tokenizer_id: Identifier passed to the worker class.
            enable_lora: Passed to the worker class.
            max_num_seqs: Passed to the worker class.
            max_input_length: Passed to the worker class.
            num_actors: Number of Ray actors to start.
            ray_actor_options: Keyword options applied to the Ray actor
                class via ``.options(...)``.
            **tokenizer_config: Extra keyword arguments passed to the
                worker class, both locally and in every actor.
        """
        # Store a local copy of the TokenizerGroup for quick access
        # to underlying HF tokenizers.
        # It must be built with the same tokenizer_config as the actors,
        # otherwise the tokenizer returned locally could differ from the
        # one the actors encode with.
        self._local_tokenizer_group = self._worker_cls(
            tokenizer_id=tokenizer_id,
            enable_lora=enable_lora,
            max_num_seqs=max_num_seqs,
            max_input_length=max_input_length,
            **tokenizer_config,
        )

        ray_tokenizer_group_cls = ray.remote(
            self._worker_cls).options(**ray_actor_options)
        self.tokenizer_actors = [
            ray_tokenizer_group_cls.remote(tokenizer_id, enable_lora,
                                           max_num_seqs, max_input_length,
                                           **tokenizer_config)
            for _ in range(num_actors)
        ]
        # Created lazily by _ensure_queue_initialized on first encode.
        self._idle_actors: Optional[asyncio.Queue] = None

    @property
    def pool_size(self) -> int:
        """Number of actors in the pool."""
        return len(self.tokenizer_actors)

    def ping(self):
        """Ping every actor and wait for all replies.

        Returns the list of per-actor ``ping`` results, in actor order.
        """
        return ray.get(
            [actor.ping.remote() for actor in self.tokenizer_actors])

    def _ensure_queue_initialized(self):
        """Create the idle-actor queue on first use and fill it."""
        if self._idle_actors is None:
            self._idle_actors = asyncio.Queue()
            for actor in self.tokenizer_actors:
                self._idle_actors.put_nowait(actor)

    def encode(self,
               prompt: str,
               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Encode a prompt using the tokenizer group.

        We pick an idle actor and use it to encode the prompt.
        The actor is then put back in the queue for future use.
        This is blocking.

        Raises:
            RuntimeError: if no actor is idle at the time of the call.
        """
        self._ensure_queue_initialized()

        if self._idle_actors.empty():
            raise RuntimeError("No idle actors available.")
        actor = self._idle_actors.get_nowait()
        try:
            ret = ray.get(
                actor.encode.remote(request_id=request_id,
                                    prompt=prompt,
                                    lora_request=lora_request))
        finally:
            # Put the actor back in the queue.
            # This is done in a finally block to ensure that the actor is
            # always put back in the queue, even if an exception/cancellation
            # is raised.
            self._idle_actors.put_nowait(actor)
        return ret

    async def encode_async(
            self,
            prompt: str,
            request_id: Optional[str] = None,
            lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Encode a prompt using the tokenizer group.

        We pick an idle actor and use it to encode the prompt.
        If there are no idle actors, we wait until one becomes
        available.
        The actor is then put back in the queue for future use.
        This is non-blocking.
        """
        self._ensure_queue_initialized()

        actor = await self._idle_actors.get()
        try:
            ret = await actor.encode.remote(request_id=request_id,
                                            prompt=prompt,
                                            lora_request=lora_request)
        finally:
            # Put the actor back in the queue.
            # This is done in a finally block to ensure that the actor is
            # always put back in the queue, even if an exception/cancellation
            # is raised.
            self._idle_actors.put_nowait(actor)
        return ret

    def get_max_input_len(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        """Get the maximum input length for the LoRA request.

        Answered by the local tokenizer group, without contacting actors.
        """
        return self._local_tokenizer_group.get_max_input_len(lora_request)

    def get_lora_tokenizer(
            self,
            lora_request: Optional[LoRARequest] = None
    ) -> "PreTrainedTokenizer":
        """Return the local tokenizer group's tokenizer for the request."""
        return self._local_tokenizer_group.get_lora_tokenizer(lora_request)

    async def get_lora_tokenizer_async(
            self,
            lora_request: Optional[LoRARequest] = None
    ) -> "PreTrainedTokenizer":
        """Awaitable counterpart of ``get_lora_tokenizer``."""
        return await self._local_tokenizer_group.get_lora_tokenizer_async(
            lora_request)
def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None:
    """Merge the current process environment into ``runtime_env``.

    After the call, ``runtime_env["env_vars"]`` is a new dict holding
    every variable of this process, overlaid with whatever
    ``runtime_env["env_vars"]`` already contained. Entries that were
    already in ``runtime_env`` therefore win over the process
    environment.

    ``runtime_env`` is modified in place; nothing is returned.
    """
    explicit_vars = runtime_env.setdefault("env_vars", {})
    merged_vars = os.environ.copy()
    # Applied last so that explicitly configured values take precedence.
    merged_vars.update(explicit_vars)
    runtime_env["env_vars"] = merged_vars
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
0 → 100644
View file @
fb96c1e9
from
typing
import
List
,
Optional
from
transformers
import
PreTrainedTokenizer
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.tokenizer
import
(
get_lora_tokenizer
,
get_lora_tokenizer_async
)
from
vllm.transformers_utils.tokenizer_group.base_tokenizer_group
import
(
BaseTokenizerGroup
)
from
vllm.utils
import
LRUCache
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
class TokenizerGroup(BaseTokenizerGroup):
    """In-process tokenizer group: a base tokenizer plus per-LoRA tokenizers.

    When LoRA is enabled, tokenizers loaded for LoRA requests are kept in
    an ``LRUCache`` keyed by ``lora_int_id`` with capacity
    ``max_num_seqs``. When it is disabled, the base tokenizer is always
    used.
    """

    def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int], **tokenizer_config):
        """Load the base tokenizer and set up the LoRA tokenizer cache.

        Args:
            tokenizer_id: Passed to ``get_tokenizer`` to load the base
                tokenizer.
            enable_lora: Whether per-LoRA tokenizers should be used.
            max_num_seqs: Capacity of the LoRA tokenizer cache.
            max_input_length: Value reported by ``get_max_input_len``.
            **tokenizer_config: Extra keyword arguments forwarded to
                ``get_tokenizer`` and to the LoRA tokenizer loaders.
        """
        self.tokenizer_id = tokenizer_id
        self.tokenizer_config = tokenizer_config
        self.enable_lora = enable_lora
        self.max_input_length = max_input_length
        self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
        self.lora_tokenizers = (LRUCache(
            capacity=max_num_seqs) if enable_lora else None)

    def ping(self) -> bool:
        """Check if the tokenizer group is alive; always ``True`` here."""
        return True

    def get_max_input_len(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        """Return the configured maximum input length.

        ``lora_request`` is accepted for interface compatibility and is
        not consulted.
        """
        return self.max_input_length

    def encode(self,
               prompt: str,
               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Encode ``prompt`` with the tokenizer selected for ``lora_request``.

        ``request_id`` is accepted but not used.
        """
        selected = self.get_lora_tokenizer(lora_request)
        return selected.encode(prompt)

    async def encode_async(
            self,
            prompt: str,
            request_id: Optional[str] = None,
            lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Awaitable ``encode``; only the tokenizer lookup is awaited."""
        selected = await self.get_lora_tokenizer_async(lora_request)
        return selected.encode(prompt)

    def get_lora_tokenizer(
            self,
            lora_request: Optional[LoRARequest] = None
    ) -> "PreTrainedTokenizer":
        """Return the tokenizer for ``lora_request``.

        The base tokenizer is returned when there is no request or LoRA
        is disabled. Otherwise the cached tokenizer for the request's
        ``lora_int_id`` is returned; on a cache miss one is loaded with
        ``get_lora_tokenizer`` (falling back to the base tokenizer if
        that yields nothing) and stored in the cache.
        """
        if not (lora_request and self.enable_lora):
            return self.tokenizer

        lora_id = lora_request.lora_int_id
        if lora_id in self.lora_tokenizers:
            return self.lora_tokenizers.get(lora_id)

        loaded = get_lora_tokenizer(lora_request, **self.tokenizer_config)
        tokenizer = loaded or self.tokenizer
        self.lora_tokenizers.put(lora_id, tokenizer)
        return tokenizer

    async def get_lora_tokenizer_async(
            self,
            lora_request: Optional[LoRARequest] = None
    ) -> "PreTrainedTokenizer":
        """Awaitable ``get_lora_tokenizer``.

        Follows the same lookup rules, but loads a missing tokenizer via
        ``get_lora_tokenizer_async``.
        """
        if not (lora_request and self.enable_lora):
            return self.tokenizer

        lora_id = lora_request.lora_int_id
        if lora_id in self.lora_tokenizers:
            return self.lora_tokenizers.get(lora_id)

        loaded = await get_lora_tokenizer_async(lora_request,
                                                **self.tokenizer_config)
        tokenizer = loaded or self.tokenizer
        self.lora_tokenizers.put(lora_id, tokenizer)
        return tokenizer
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment