Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
a945fcc2
Unverified
Commit
a945fcc2
authored
Jul 07, 2023
by
codethazine
Committed by
GitHub
Jul 07, 2023
Browse files
Add trust-remote-code flag to handle remote tokenizers (#364)
parent
be54f8e5
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
39 additions
and
6 deletions
+39
-6
vllm/config.py
vllm/config.py
+4
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+8
-3
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+4
-1
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+4
-0
vllm/transformers_utils/tokenizer.py
vllm/transformers_utils/tokenizer.py
+19
-2
No files found.
vllm/config.py
View file @
a945fcc2
...
@@ -20,6 +20,8 @@ class ModelConfig:
...
@@ -20,6 +20,8 @@ class ModelConfig:
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer.
available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
download_dir: Directory to download and load the weights, default to the
download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface.
default cache directory of huggingface.
use_np_weights: Save a numpy copy of model weights for faster loading.
use_np_weights: Save a numpy copy of model weights for faster loading.
...
@@ -36,6 +38,7 @@ class ModelConfig:
...
@@ -36,6 +38,7 @@ class ModelConfig:
model
:
str
,
model
:
str
,
tokenizer
:
str
,
tokenizer
:
str
,
tokenizer_mode
:
str
,
tokenizer_mode
:
str
,
trust_remote_code
:
bool
,
download_dir
:
Optional
[
str
],
download_dir
:
Optional
[
str
],
use_np_weights
:
bool
,
use_np_weights
:
bool
,
use_dummy_weights
:
bool
,
use_dummy_weights
:
bool
,
...
@@ -45,6 +48,7 @@ class ModelConfig:
...
@@ -45,6 +48,7 @@ class ModelConfig:
self
.
model
=
model
self
.
model
=
model
self
.
tokenizer
=
tokenizer
self
.
tokenizer
=
tokenizer
self
.
tokenizer_mode
=
tokenizer_mode
self
.
tokenizer_mode
=
tokenizer_mode
self
.
trust_remote_code
=
trust_remote_code
self
.
download_dir
=
download_dir
self
.
download_dir
=
download_dir
self
.
use_np_weights
=
use_np_weights
self
.
use_np_weights
=
use_np_weights
self
.
use_dummy_weights
=
use_dummy_weights
self
.
use_dummy_weights
=
use_dummy_weights
...
...
vllm/engine/arg_utils.py
View file @
a945fcc2
...
@@ -13,6 +13,7 @@ class EngineArgs:
...
@@ -13,6 +13,7 @@ class EngineArgs:
model
:
str
model
:
str
tokenizer
:
Optional
[
str
]
=
None
tokenizer
:
Optional
[
str
]
=
None
tokenizer_mode
:
str
=
'auto'
tokenizer_mode
:
str
=
'auto'
trust_remote_code
:
bool
=
False
download_dir
:
Optional
[
str
]
=
None
download_dir
:
Optional
[
str
]
=
None
use_np_weights
:
bool
=
False
use_np_weights
:
bool
=
False
use_dummy_weights
:
bool
=
False
use_dummy_weights
:
bool
=
False
...
@@ -55,6 +56,9 @@ class EngineArgs:
...
@@ -55,6 +56,9 @@ class EngineArgs:
help
=
'tokenizer mode. "auto" will use the fast '
help
=
'tokenizer mode. "auto" will use the fast '
'tokenizer if available, and "slow" will '
'tokenizer if available, and "slow" will '
'always use the slow tokenizer.'
)
'always use the slow tokenizer.'
)
parser
.
add_argument
(
'--trust-remote-code'
,
action
=
'store_true'
,
help
=
'trust remote code from huggingface'
)
parser
.
add_argument
(
'--download-dir'
,
parser
.
add_argument
(
'--download-dir'
,
type
=
str
,
type
=
str
,
default
=
EngineArgs
.
download_dir
,
default
=
EngineArgs
.
download_dir
,
...
@@ -141,9 +145,10 @@ class EngineArgs:
...
@@ -141,9 +145,10 @@ class EngineArgs:
)
->
Tuple
[
ModelConfig
,
CacheConfig
,
ParallelConfig
,
SchedulerConfig
]:
)
->
Tuple
[
ModelConfig
,
CacheConfig
,
ParallelConfig
,
SchedulerConfig
]:
# Initialize the configs.
# Initialize the configs.
model_config
=
ModelConfig
(
self
.
model
,
self
.
tokenizer
,
model_config
=
ModelConfig
(
self
.
model
,
self
.
tokenizer
,
self
.
tokenizer_mode
,
self
.
download_dir
,
self
.
tokenizer_mode
,
self
.
trust_remote_code
,
self
.
use_np_weights
,
self
.
use_dummy_weights
,
self
.
download_dir
,
self
.
use_np_weights
,
self
.
dtype
,
self
.
seed
)
self
.
use_dummy_weights
,
self
.
dtype
,
self
.
seed
)
cache_config
=
CacheConfig
(
self
.
block_size
,
cache_config
=
CacheConfig
(
self
.
block_size
,
self
.
gpu_memory_utilization
,
self
.
gpu_memory_utilization
,
self
.
swap_space
)
self
.
swap_space
)
...
...
vllm/engine/llm_engine.py
View file @
a945fcc2
...
@@ -62,6 +62,7 @@ class LLMEngine:
...
@@ -62,6 +62,7 @@ class LLMEngine:
f
"model=
{
model_config
.
model
!
r
}
, "
f
"model=
{
model_config
.
model
!
r
}
, "
f
"tokenizer=
{
model_config
.
tokenizer
!
r
}
, "
f
"tokenizer=
{
model_config
.
tokenizer
!
r
}
, "
f
"tokenizer_mode=
{
model_config
.
tokenizer_mode
}
, "
f
"tokenizer_mode=
{
model_config
.
tokenizer_mode
}
, "
f
"trust_remote_code=
{
model_config
.
trust_remote_code
}
, "
f
"dtype=
{
model_config
.
dtype
}
, "
f
"dtype=
{
model_config
.
dtype
}
, "
f
"use_dummy_weights=
{
model_config
.
use_dummy_weights
}
, "
f
"use_dummy_weights=
{
model_config
.
use_dummy_weights
}
, "
f
"download_dir=
{
model_config
.
download_dir
!
r
}
, "
f
"download_dir=
{
model_config
.
download_dir
!
r
}
, "
...
@@ -78,7 +79,9 @@ class LLMEngine:
...
@@ -78,7 +79,9 @@ class LLMEngine:
self
.
_verify_args
()
self
.
_verify_args
()
self
.
tokenizer
=
get_tokenizer
(
self
.
tokenizer
=
get_tokenizer
(
model_config
.
tokenizer
,
tokenizer_mode
=
model_config
.
tokenizer_mode
)
model_config
.
tokenizer
,
tokenizer_mode
=
model_config
.
tokenizer_mode
,
trust_remote_code
=
model_config
.
trust_remote_code
)
self
.
seq_counter
=
Counter
()
self
.
seq_counter
=
Counter
()
# Create the parallel GPU workers.
# Create the parallel GPU workers.
...
...
vllm/entrypoints/llm.py
View file @
a945fcc2
...
@@ -28,6 +28,8 @@ class LLM:
...
@@ -28,6 +28,8 @@ class LLM:
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
if available, and "slow" will always use the slow tokenizer.
if available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
tensor_parallel_size: The number of GPUs to use for distributed
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
dtype: The data type for the model weights and activations. Currently,
...
@@ -43,6 +45,7 @@ class LLM:
...
@@ -43,6 +45,7 @@ class LLM:
model
:
str
,
model
:
str
,
tokenizer
:
Optional
[
str
]
=
None
,
tokenizer
:
Optional
[
str
]
=
None
,
tokenizer_mode
:
str
=
"auto"
,
tokenizer_mode
:
str
=
"auto"
,
trust_remote_code
:
bool
=
False
,
tensor_parallel_size
:
int
=
1
,
tensor_parallel_size
:
int
=
1
,
dtype
:
str
=
"auto"
,
dtype
:
str
=
"auto"
,
seed
:
int
=
0
,
seed
:
int
=
0
,
...
@@ -54,6 +57,7 @@ class LLM:
...
@@ -54,6 +57,7 @@ class LLM:
model
=
model
,
model
=
model
,
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
tokenizer_mode
=
tokenizer_mode
,
tokenizer_mode
=
tokenizer_mode
,
trust_remote_code
=
trust_remote_code
,
tensor_parallel_size
=
tensor_parallel_size
,
tensor_parallel_size
=
tensor_parallel_size
,
dtype
=
dtype
,
dtype
=
dtype
,
seed
=
seed
,
seed
=
seed
,
...
...
vllm/transformers_utils/tokenizer.py
View file @
a945fcc2
...
@@ -15,6 +15,7 @@ def get_tokenizer(
...
@@ -15,6 +15,7 @@ def get_tokenizer(
tokenizer_name
:
str
,
tokenizer_name
:
str
,
*
args
,
*
args
,
tokenizer_mode
:
str
=
"auto"
,
tokenizer_mode
:
str
=
"auto"
,
trust_remote_code
:
bool
=
False
,
**
kwargs
,
**
kwargs
,
)
->
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]:
)
->
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]:
"""Gets a tokenizer for the given model name via Huggingface."""
"""Gets a tokenizer for the given model name via Huggingface."""
...
@@ -31,8 +32,11 @@ def get_tokenizer(
...
@@ -31,8 +32,11 @@ def get_tokenizer(
f
"using '
{
_FAST_LLAMA_TOKENIZER
}
' instead of the original "
f
"using '
{
_FAST_LLAMA_TOKENIZER
}
' instead of the original "
"tokenizer."
)
"tokenizer."
)
try
:
try
:
tokenizer
=
AutoTokenizer
.
from_pretrained
(
tokenizer_name
,
*
args
,
tokenizer
=
AutoTokenizer
.
from_pretrained
(
**
kwargs
)
tokenizer_name
,
trust_remote_code
=
trust_remote_code
,
*
args
,
**
kwargs
)
except
TypeError
as
e
:
except
TypeError
as
e
:
# The LLaMA tokenizer causes a protobuf error in some environments.
# The LLaMA tokenizer causes a protobuf error in some environments.
err_msg
=
(
err_msg
=
(
...
@@ -40,6 +44,19 @@ def get_tokenizer(
...
@@ -40,6 +44,19 @@ def get_tokenizer(
f
"model, use '
{
_FAST_LLAMA_TOKENIZER
}
' instead of the original "
f
"model, use '
{
_FAST_LLAMA_TOKENIZER
}
' instead of the original "
"tokenizer."
)
"tokenizer."
)
raise
RuntimeError
(
err_msg
)
from
e
raise
RuntimeError
(
err_msg
)
from
e
except
ValueError
as
e
:
# If the error pertains to the tokenizer class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
if
(
e
is
not
None
and
(
"does not exist or is not currently imported."
in
str
(
e
)
or
"requires you to execute the tokenizer file"
in
str
(
e
))):
err_msg
=
(
"Failed to load the tokenizer. If the tokenizer is a custom "
"tokenizer not yet available in the HuggingFace transformers "
"library, consider using the --trust-remote-code flag."
)
raise
RuntimeError
(
err_msg
)
from
e
else
:
raise
e
if
not
isinstance
(
tokenizer
,
PreTrainedTokenizerFast
):
if
not
isinstance
(
tokenizer
,
PreTrainedTokenizerFast
):
logger
.
warning
(
logger
.
warning
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment