norm / vllm · commit 4338cc47 (unverified)

[Tokenizer] Add an option to specify tokenizer (#284)

Authored Jun 28, 2023 by Woosuk Kwon; committed by GitHub on Jun 28, 2023
Parent: bdd6b4c8

Showing 10 changed files with 61 additions and 60 deletions (+61, -60)
Changed files:

    benchmarks/benchmark_latency.py            +2   -0
    benchmarks/benchmark_serving.py            +2   -9
    benchmarks/benchmark_throughput.py         +15  -20
    vllm/config.py                             +3   -0
    vllm/engine/arg_utils.py                   +6   -1
    vllm/engine/llm_engine.py                  +4   -2
    vllm/entrypoints/llm.py                    +3   -0
    vllm/entrypoints/openai/api_server.py      +1   -1
    vllm/transformers_utils/__init__.py        +0   -0
    vllm/transformers_utils/tokenizer.py       +25  -27
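Taken together, the change threads a tokenizer option from the CLI flags and the LLM/EngineArgs entry points down to ModelConfig and LLMEngine, and moves the tokenizer helpers into the new vllm.transformers_utils package. When the option is omitted, it falls back to the model name, so existing callers keep their previous behavior. A minimal usage sketch (the model and tokenizer ids are illustrative, not part of this commit):

    from vllm import LLM

    # Pair a LLaMA checkpoint with the pre-converted fast tokenizer that the
    # new vllm/transformers_utils/tokenizer.py recommends.
    llm = LLM(
        model="huggyllama/llama-7b",                      # example model id
        tokenizer="hf-internal-testing/llama-tokenizer",  # example tokenizer id
    )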
benchmarks/benchmark_latency.py

@@ -17,6 +17,7 @@ def main(args: argparse.Namespace):
     # the engine will automatically process the request in multiple batches.
     llm = LLM(
         model=args.model,
+        tokenizer=args.tokenizer,
         tensor_parallel_size=args.tensor_parallel_size,
         max_num_seqs=args.batch_size,
         max_num_batched_tokens=args.batch_size * args.input_len,
@@ -63,6 +64,7 @@ if __name__ == '__main__':
         description='Benchmark the latency of processing a single batch of '
                     'requests till completion.')
     parser.add_argument('--model', type=str, default='facebook/opt-125m')
+    parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
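The latency benchmark simply forwards the new flag into LLM; leaving --tokenizer unset keeps the old behavior of loading the tokenizer that matches --model. A hypothetical invocation (ids are placeholders):

    python benchmarks/benchmark_latency.py --model huggyllama/llama-7b --tokenizer hf-internal-testing/llama-tokenizer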
benchmarks/benchmark_serving.py

@@ -24,20 +24,13 @@ from typing import AsyncGenerator, List, Tuple

 import aiohttp
 import numpy as np
-from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase
+from vllm.transformers_utils.tokenizer import get_tokenizer

 # (prompt len, output len, latency)
 REQUEST_LATENCY: List[Tuple[int, int, float]] = []


-def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
-    config = AutoConfig.from_pretrained(model_name)
-    if config.model_type == "llama":
-        # A workaround for potential protobuf errors.
-        model_name = "hf-internal-testing/llama-tokenizer"
-    return AutoTokenizer.from_pretrained(model_name)
-
-
 def sample_requests(
     dataset_path: str,
     num_requests: int,
benchmarks/benchmark_throughput.py

@@ -6,23 +6,11 @@ import time
 from typing import List, Tuple

 import torch
-from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM,
-                          PreTrainedTokenizerBase)
+from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
 from tqdm import tqdm

 from vllm import LLM, SamplingParams
+from vllm.transformers_utils.tokenizer import get_tokenizer


-def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
-    config = AutoConfig.from_pretrained(model_name)
-    if config.model_type == "llama":
-        # A workaround for potential protobuf errors.
-        model_name = "hf-internal-testing/llama-tokenizer"
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        # To enable padding in the HF backend.
-        tokenizer.pad_token = tokenizer.eos_token
-        return tokenizer
-    return AutoTokenizer.from_pretrained(model_name)
-
-
 def sample_requests(
@@ -74,6 +62,7 @@ def sample_requests(
 def run_vllm(
     requests: List[Tuple[str, int, int]],
     model: str,
+    tokenizer: str,
     tensor_parallel_size: int,
     seed: int,
     n: int,
@@ -81,6 +70,7 @@ def run_vllm(
 ) -> float:
     llm = LLM(
         model=model,
+        tokenizer=tokenizer,
         tensor_parallel_size=tensor_parallel_size,
         seed=seed,
     )
@@ -118,9 +108,10 @@ def run_hf(
     max_batch_size: int,
 ) -> float:
     assert not use_beam_search
-    tokenizer = get_tokenizer(model)
-    llm = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
+    llm = AutoModelForCausalLM.from_pretrained(
+        model, torch_dtype=torch.float16)
+    if llm.config.model_type == "llama":
+        # To enable padding in the HF backend.
+        tokenizer.pad_token = tokenizer.eos_token
     llm = llm.cuda()

     pbar = tqdm(total=len(requests))
@@ -170,13 +161,13 @@ def main(args: argparse.Namespace):
     random.seed(args.seed)

     # Sample the requests.
-    tokenizer = get_tokenizer(args.model)
+    tokenizer = get_tokenizer(args.tokenizer)
     requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

     if args.backend == "vllm":
         elapsed_time = run_vllm(
-            requests, args.model, args.tensor_parallel_size, args.seed, args.n,
-            args.use_beam_search)
+            requests, args.model, args.tokenizer, args.tensor_parallel_size,
+            args.seed, args.n, args.use_beam_search)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -198,6 +189,7 @@ if __name__ == "__main__":
     parser.add_argument("--dataset", type=str, required=True,
                         help="Path to the dataset.")
     parser.add_argument("--model", type=str, default="facebook/opt-125m")
+    parser.add_argument("--tokenizer", type=str, default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n", type=int, default=1,
                         help="Number of generated sequences per prompt.")
@@ -208,11 +200,14 @@ if __name__ == "__main__":
     parser.add_argument("--hf-max-batch-size", type=int, default=None,
                         help="Maximum batch size for HF backend.")
     args = parser.parse_args()

     if args.backend == "vllm":
         if args.hf_max_batch_size is not None:
             raise ValueError("HF max batch size is only for HF backend.")
     elif args.backend == "hf":
         if args.hf_max_batch_size is None:
             raise ValueError("HF max batch size is required for HF backend.")
+    if args.tokenizer is None:
+        args.tokenizer = args.model

     main(args)
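The serving and throughput benchmarks now share the engine's tokenizer helper instead of carrying local copies. A quick sketch of what that helper hands back (assuming vLLM at this commit is installed; the model id is just an example):

    from vllm.transformers_utils.tokenizer import get_tokenizer

    # Returns a standard HF tokenizer, using the fast variant when available.
    tokenizer = get_tokenizer("facebook/opt-125m")
    prompt_len = len(tokenizer("Hello, world!").input_ids)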
vllm/config.py

@@ -16,6 +16,7 @@ class ModelConfig:

     Args:
         model: Name or path of the huggingface model to use.
+        tokenizer: Name or path of the huggingface tokenizer to use.
         download_dir: Directory to download and load the weights, default to the
             default cache directory of huggingface.
         use_np_weights: Save a numpy copy of model weights for faster loading.
@@ -30,6 +31,7 @@ class ModelConfig:
     def __init__(
         self,
         model: str,
+        tokenizer: Optional[str],
         download_dir: Optional[str],
         use_np_weights: bool,
         use_dummy_weights: bool,
@@ -37,6 +39,7 @@ class ModelConfig:
         seed: int,
     ) -> None:
         self.model = model
+        self.tokenizer = tokenizer
         self.download_dir = download_dir
         self.use_np_weights = use_np_weights
         self.use_dummy_weights = use_dummy_weights
vllm/engine/arg_utils.py

@@ -11,6 +11,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
 class EngineArgs:
     """Arguments for vLLM engine."""
     model: str
+    tokenizer: Optional[str] = None
     download_dir: Optional[str] = None
     use_np_weights: bool = False
     use_dummy_weights: bool = False
@@ -27,6 +28,8 @@ class EngineArgs:
     disable_log_stats: bool = False

     def __post_init__(self):
+        if self.tokenizer is None:
+            self.tokenizer = self.model
         self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)

     @staticmethod
@@ -37,6 +40,8 @@ class EngineArgs:
         # Model arguments
         parser.add_argument('--model', type=str, default='facebook/opt-125m',
                             help='name or path of the huggingface model to use')
+        parser.add_argument('--tokenizer', type=str, default=EngineArgs.tokenizer,
+                            help='name or path of the huggingface tokenizer to use')
         parser.add_argument('--download-dir', type=str,
                             default=EngineArgs.download_dir,
                             help='directory to download and load the weights, '
@@ -104,7 +109,7 @@ class EngineArgs:
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
         # Initialize the configs.
         model_config = ModelConfig(
-            self.model, self.download_dir, self.use_np_weights,
+            self.model, self.tokenizer, self.download_dir, self.use_np_weights,
             self.use_dummy_weights, self.dtype, self.seed)
         cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
                                    self.swap_space)
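Because the fallback lives in __post_init__, it applies both to the new --tokenizer flag and to programmatic construction. A small sketch of the behavior (assuming this commit):

    from vllm.engine.arg_utils import EngineArgs

    args = EngineArgs(model="facebook/opt-125m")
    assert args.tokenizer == "facebook/opt-125m"  # filled in by __post_init__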
vllm/engine/llm_engine.py

@@ -6,11 +6,12 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
 from vllm.core.scheduler import Scheduler
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.ray_utils import DeviceID, initialize_cluster, ray
-from vllm.engine.tokenizer_utils import detokenize_incrementally, get_tokenizer
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
+from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
+                                               get_tokenizer)
 from vllm.utils import Counter
 from vllm.worker.worker import Worker
@@ -59,6 +60,7 @@ class LLMEngine:
         logger.info(
             "Initializing an LLM engine with config: "
             f"model={model_config.model!r}, "
+            f"tokenizer={model_config.tokenizer!r}, "
             f"dtype={model_config.dtype}, "
             f"use_dummy_weights={model_config.use_dummy_weights}, "
             f"download_dir={model_config.download_dir!r}, "
@@ -75,7 +77,7 @@ class LLMEngine:
         self.log_stats = log_stats
         self._verify_args()

-        self.tokenizer = get_tokenizer(model_config.model)
+        self.tokenizer = get_tokenizer(model_config.tokenizer)
         self.seq_counter = Counter()

         # Create the parallel GPU workers.
vllm/entrypoints/llm.py

@@ -25,6 +25,7 @@ class LLM:

     Args:
         model: The name or path of a HuggingFace Transformers model.
+        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
         tensor_parallel_size: The number of GPUs to use for distributed
             execution with tensor parallelism.
         dtype: The data type for the model weights and activations. Currently,
@@ -38,6 +39,7 @@ class LLM:
     def __init__(
         self,
         model: str,
+        tokenizer: Optional[str] = None,
         tensor_parallel_size: int = 1,
         dtype: str = "auto",
         seed: int = 0,
@@ -47,6 +49,7 @@ class LLM:
         kwargs["disable_log_stats"] = True
         engine_args = EngineArgs(
             model=model,
+            tokenizer=tokenizer,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
             seed=seed,
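For end users the keyword is optional; omitting it preserves the old behavior because EngineArgs falls back to the model name. A short end-to-end sketch (assuming this commit; output text will vary):

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # tokenizer defaults to the model name
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)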
vllm/entrypoints/openai/api_server.py

@@ -15,7 +15,6 @@ import uvicorn
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.engine.tokenizer_utils import get_tokenizer
 from vllm.entrypoints.openai.protocol import (
     CompletionRequest, CompletionResponse, CompletionResponseChoice,
     CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
@@ -23,6 +22,7 @@ from vllm.entrypoints.openai.protocol import (
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
+from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import random_uuid

 TIMEOUT_KEEP_ALIVE = 5  # seconds
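Only the import path changes here, but since the server builds its command line from AsyncEngineArgs (which extends EngineArgs), the same --tokenizer flag should be available when launching it, for example: python -m vllm.entrypoints.openai.api_server --model huggyllama/llama-7b --tokenizer hf-internal-testing/llama-tokenizer (a hypothetical invocation; ids are placeholders).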
vllm/transformers_utils/__init__.py (new empty file, 0 → 100644)
vllm/engine/tokenizer_utils.py → vllm/transformers_utils/tokenizer.py

 from typing import List, Tuple, Union

-from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
-                          PreTrainedTokenizerFast)
+from transformers import (AutoTokenizer, PreTrainedTokenizer,
+                          PreTrainedTokenizerFast)

 from vllm.logger import init_logger

 logger = init_logger(__name__)

-_MODEL_TYPES_WITH_SLOW_TOKENIZER = []
+# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
+_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


 def get_tokenizer(
-    model_name: str,
+    tokenizer_name: str,
     *args,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     """Gets a tokenizer for the given model name via Huggingface."""
-    config = AutoConfig.from_pretrained(model_name)
-    if "open_llama" in model_name:
-        kwargs["use_fast"] = False
-        logger.info(
-            "OpenLLaMA models do not support the fast tokenizer. "
-            "Using the slow tokenizer instead.")
-    elif config.model_type == "llama" and kwargs.get("use_fast", True):
-        # LLaMA fast tokenizer causes protobuf errors in some environments.
-        # However, we found that the below LLaMA fast tokenizer works well in
-        # most environments.
-        model_name = "hf-internal-testing/llama-tokenizer"
-        logger.info(
-            f"Using the LLaMA fast tokenizer in '{model_name}' to avoid "
-            "potential protobuf errors.")
-    elif config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
-        if kwargs.get("use_fast", False) == True:
-            raise ValueError(
-                f"Cannot use the fast tokenizer for {config.model_type} due to "
-                "bugs in the fast tokenizer.")
-        logger.info(
-            f"Using the slow tokenizer for {config.model_type} due to bugs in "
-            "the fast tokenizer. This could potentially lead to performance "
-            "degradation.")
-        kwargs["use_fast"] = False
-    return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
+    if "llama" in tokenizer_name.lower() and kwargs.get("use_fast", True):
+        logger.info(
+            "For some LLaMA-based models, initializing the fast tokenizer may "
+            "take a long time. To eliminate the initialization time, consider "
+            f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
+            "tokenizer.")
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, *args,
+                                                  **kwargs)
+    except TypeError as e:
+        # The LLaMA tokenizer causes a protobuf error in some environments.
+        err_msg = (
+            "Failed to load the tokenizer. If you are using a LLaMA-based "
+            f"model, use '{_FAST_LLAMA_TOKENIZER}' instead of the original "
+            "tokenizer.")
+        raise RuntimeError(err_msg) from e
+    if not isinstance(tokenizer, PreTrainedTokenizerFast):
+        logger.warning(
+            "Using a slow tokenizer. This might cause a significant "
+            "slowdown. Consider using a fast tokenizer instead.")
+    return tokenizer


 def detokenize_incrementally(
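The rewritten helper no longer inspects the model config: it keys off the tokenizer name itself, warns whenever a slow tokenizer is returned, and wraps the protobuf-related TypeError in a clearer RuntimeError. A quick illustration (assuming this commit; the ids are the default model and the fast LLaMA tokenizer referenced in the file):

    from vllm.transformers_utils.tokenizer import get_tokenizer

    # Forcing use_fast=False returns a slow tokenizer and logs the new warning.
    slow = get_tokenizer("facebook/opt-125m", use_fast=False)

    # The pre-converted fast LLaMA tokenizer loads without protobuf issues.
    fast = get_tokenizer("hf-internal-testing/llama-tokenizer")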