Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6aa8f9a4
Unverified
Commit
6aa8f9a4
authored
Jun 01, 2025
by
Cyrus Leung
Committed by
GitHub
Jun 01, 2025
Browse files
[Core] Rework dtype resolution (#18751)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
1bc86a3d
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
314 additions
and
119 deletions
+314
-119
tests/basic_correctness/test_basic_correctness.py
tests/basic_correctness/test_basic_correctness.py
+1
-4
tests/conftest.py
tests/conftest.py
+6
-1
tests/models/language/pooling/mteb_utils.py
tests/models/language/pooling/mteb_utils.py
+4
-7
tests/models/language/pooling/test_classification.py
tests/models/language/pooling/test_classification.py
+1
-1
tests/models/language/pooling/test_embedding.py
tests/models/language/pooling/test_embedding.py
+1
-5
tests/models/multimodal/generation/test_whisper.py
tests/models/multimodal/generation/test_whisper.py
+1
-0
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+1
-1
tests/samplers/test_no_bad_words.py
tests/samplers/test_no_bad_words.py
+1
-1
tests/test_utils.py
tests/test_utils.py
+80
-22
vllm/config.py
vllm/config.py
+135
-67
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+1
-1
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+33
-7
vllm/utils.py
vllm/utils.py
+49
-2
No files found.
tests/basic_correctness/test_basic_correctness.py
View file @
6aa8f9a4
...
...
@@ -60,7 +60,6 @@ def _fix_prompt_embed_outputs(
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"enable_prompt_embeds"
,
[
True
,
False
])
...
...
@@ -69,7 +68,6 @@ def test_models(
hf_runner
,
model
:
str
,
backend
:
str
,
dtype
:
str
,
max_tokens
:
int
,
enforce_eager
:
bool
,
enable_prompt_embeds
:
bool
,
...
...
@@ -97,7 +95,7 @@ def test_models(
str
(
i
)
for
i
in
range
(
1024
))
+
" are:"
example_prompts
=
[
prompt
]
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
with
hf_runner
(
model
)
as
hf_model
:
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
if
enable_prompt_embeds
:
with
torch
.
no_grad
():
...
...
@@ -106,7 +104,6 @@ def test_models(
with
VllmRunner
(
model
,
max_model_len
=
8192
,
dtype
=
dtype
,
enforce_eager
=
enforce_eager
,
enable_prompt_embeds
=
enable_prompt_embeds
,
gpu_memory_utilization
=
0.7
)
as
vllm_model
:
...
...
tests/conftest.py
View file @
6aa8f9a4
...
...
@@ -324,7 +324,12 @@ class HfRunner:
trust_remote_code
=
trust_remote_code
,
)
self
.
device
=
self
.
get_default_device
()
self
.
dtype
=
torch_dtype
=
_get_and_verify_dtype
(
self
.
config
,
dtype
)
self
.
dtype
=
torch_dtype
=
_get_and_verify_dtype
(
self
.
model_name
,
self
.
config
,
dtype
=
dtype
,
is_pooling_model
=
is_sentence_transformer
or
is_cross_encoder
,
)
model_kwargs
=
model_kwargs
if
model_kwargs
is
not
None
else
{}
model_kwargs
.
setdefault
(
"torch_dtype"
,
torch_dtype
)
...
...
tests/models/language/pooling/mteb_utils.py
View file @
6aa8f9a4
...
...
@@ -102,21 +102,18 @@ def mteb_test_embed_models(hf_runner,
vllm_main_score
=
run_mteb_embed_task
(
VllmMtebEncoder
(
vllm_model
),
MTEB_EMBED_TASKS
)
vllm_dtype
=
vllm_model
.
model
.
llm_engine
.
model_config
.
dtype
model_dtype
=
getattr
(
vllm_model
.
model
.
llm_engine
.
model_config
.
hf_config
,
"torch_dtype"
,
vllm_dtype
)
with
set_default_torch_dtype
(
model
_dtype
)
and
hf_runner
(
with
set_default_torch_dtype
(
vllm
_dtype
)
and
hf_runner
(
model_info
.
name
,
is_sentence_transformer
=
True
,
dtype
=
model
_dtype
)
as
hf_model
:
dtype
=
vllm
_dtype
)
as
hf_model
:
if
hf_model_callback
is
not
None
:
hf_model_callback
(
hf_model
)
st_main_score
=
run_mteb_embed_task
(
hf_model
,
MTEB_EMBED_TASKS
)
print
(
"VLLM:"
,
vllm_dtype
,
vllm_main_score
)
print
(
"SentenceTransformer:"
,
model_dtype
,
st_main_score
)
print
(
"VLLM:"
,
vllm_main_score
)
print
(
"SentenceTransformer
s
:"
,
st_main_score
)
print
(
"Difference:"
,
st_main_score
-
vllm_main_score
)
assert
st_main_score
==
pytest
.
approx
(
vllm_main_score
,
abs
=
MTEB_EMBED_TOL
)
tests/models/language/pooling/test_classification.py
View file @
6aa8f9a4
...
...
@@ -43,6 +43,6 @@ def test_models(
# the tolerance value of 1e-2 is selected based on the
# half datatype tests in
# tests/models/
embedding/
language/test_embedding.py
# tests/models/language
/pooling
/test_embedding.py
assert
torch
.
allclose
(
hf_output
,
vllm_output
,
1e-3
if
dtype
==
"float"
else
1e-2
)
tests/models/language/pooling/test_embedding.py
View file @
6aa8f9a4
...
...
@@ -30,13 +30,11 @@ from ...utils import check_embeddings_close
pytest
.
param
(
"sentence-transformers/stsb-roberta-base-v2"
),
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
def
test_models
(
hf_runner
,
vllm_runner
,
example_prompts
,
model
,
dtype
:
str
,
monkeypatch
,
)
->
None
:
...
...
@@ -58,13 +56,11 @@ def test_models(
# So we need to strip the input texts to avoid test failing.
example_prompts
=
[
str
(
s
).
strip
()
for
s
in
example_prompts
]
with
hf_runner
(
model
,
dtype
=
dtype
,
is_sentence_transformer
=
True
)
as
hf_model
:
with
hf_runner
(
model
,
is_sentence_transformer
=
True
)
as
hf_model
:
hf_outputs
=
hf_model
.
encode
(
example_prompts
)
with
vllm_runner
(
model
,
task
=
"embed"
,
dtype
=
dtype
,
max_model_len
=
None
,
**
vllm_extra_kwargs
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
encode
(
example_prompts
)
...
...
tests/models/multimodal/generation/test_whisper.py
View file @
6aa8f9a4
...
...
@@ -100,6 +100,7 @@ def run_test(
with
vllm_runner
(
model
,
dtype
=
"half"
,
max_model_len
=
448
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
...
...
tests/models/multimodal/processing/test_common.py
View file @
6aa8f9a4
...
...
@@ -40,7 +40,7 @@ def _test_processing_correctness(
tokenizer_mode
=
model_info
.
tokenizer_mode
,
trust_remote_code
=
model_info
.
trust_remote_code
,
seed
=
0
,
dtype
=
"
float16
"
,
dtype
=
"
auto
"
,
revision
=
None
,
hf_overrides
=
model_info
.
hf_overrides
,
)
...
...
tests/samplers/test_no_bad_words.py
View file @
6aa8f9a4
...
...
@@ -103,7 +103,7 @@ class TestTwoTokenBadWord:
add_special_tokens
=
False
)[
0
]
def
test_two_token_bad_word
(
self
,
vllm_runner
):
with
vllm_runner
(
self
.
MODEL
)
as
llm
:
with
vllm_runner
(
self
.
MODEL
,
dtype
=
"half"
)
as
llm
:
output_token_ids
=
self
.
_generate
(
llm
)
assert
output_token_ids
[:
2
]
==
[
self
.
target_token_id1
,
self
.
target_token_id2
...
...
tests/test_utils.py
View file @
6aa8f9a4
...
...
@@ -17,7 +17,8 @@ from vllm_test_utils.monitor import monitor
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.utils
import
(
CacheInfo
,
FlexibleArgumentParser
,
LRUCache
,
MemorySnapshot
,
PlaceholderModule
,
StoreBoolean
,
bind_kv_cache
,
deprecate_kwargs
,
get_open_port
,
bind_kv_cache
,
common_broadcastable_dtype
,
deprecate_kwargs
,
get_open_port
,
is_lossless_cast
,
make_zmq_path
,
make_zmq_socket
,
memory_profiling
,
merge_async_iterators
,
sha256
,
split_zmq_path
,
supports_kw
,
swap_dict_values
)
...
...
@@ -567,12 +568,65 @@ def test_lru_cache():
assert
6
in
cache
# Each case is (source dtype, target dtype, whether the cast is expected to
# be reported as lossless). Cases are grouped by the "precision level" that
# `is_lossless_cast` assigns to a dtype: bool = 0, int = 1, float = 2,
# complex = 3.
# yapf: disable
@pytest.mark.parametrize(
    ("src_dtype", "tgt_dtype", "expected_result"),
    [
        # Different precision_levels
        # (moving to a higher level is expected to be lossless,
        # moving to a lower level is expected to be lossy)
        (torch.bool, torch.int8, True),
        (torch.bool, torch.float16, True),
        (torch.bool, torch.complex32, True),
        (torch.int64, torch.bool, False),
        (torch.int64, torch.float16, True),
        (torch.int64, torch.complex32, True),
        (torch.float64, torch.bool, False),
        (torch.float64, torch.int8, False),
        (torch.float64, torch.complex32, True),
        (torch.complex128, torch.bool, False),
        (torch.complex128, torch.int8, False),
        (torch.complex128, torch.float16, False),
        # precision_level=0
        (torch.bool, torch.bool, True),
        # precision_level=1
        # (signed/unsigned of the same width cannot hold each other's range)
        (torch.int8, torch.int16, True),
        (torch.int16, torch.int8, False),
        (torch.uint8, torch.int8, False),
        (torch.int8, torch.uint8, False),
        # precision_level=2
        (torch.float16, torch.float32, True),
        (torch.float32, torch.float16, False),
        (torch.bfloat16, torch.float32, True),
        (torch.float32, torch.bfloat16, False),
        # precision_level=3
        (torch.complex32, torch.complex64, True),
        (torch.complex64, torch.complex32, False),
    ],
)
# yapf: enable
def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result):
    """Check `is_lossless_cast` against the expected verdict for each pair."""
    assert is_lossless_cast(src_dtype, tgt_dtype) == expected_result
# Each case is (list of dtypes, dtype expected to be selected). Every list
# extends the previous one with a dtype of a higher precision level, so the
# expected result is always the last element.
# yapf: disable
@pytest.mark.parametrize(
    ("dtypes", "expected_result"),
    [
        ([torch.bool], torch.bool),
        ([torch.bool, torch.int8], torch.int8),
        ([torch.bool, torch.int8, torch.float16], torch.float16),
        ([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32),  # noqa: E501
    ],
)
# yapf: enable
def test_common_broadcastable_dtype(dtypes, expected_result):
    """Check that `common_broadcastable_dtype` picks the expected dtype."""
    assert common_broadcastable_dtype(dtypes) == expected_result
def
test_placeholder_module_error_handling
():
placeholder
=
PlaceholderModule
(
"placeholder_1234"
)
def
build_ctx
():
return
pytest
.
raises
(
ModuleNotFoundError
,
match
=
"No module named"
)
return
pytest
.
raises
(
ModuleNotFoundError
,
match
=
"No module named"
)
with
build_ctx
():
int
(
placeholder
)
...
...
@@ -608,6 +662,7 @@ def test_placeholder_module_error_handling():
_
=
placeholder_attr
.
module
# yapf: disable
@
pytest
.
mark
.
parametrize
(
"obj,key1,key2"
,
[
...
...
@@ -618,6 +673,7 @@ def test_placeholder_module_error_handling():
# Tests for both keys do not exist
({
1
:
"a"
,
2
:
"b"
},
3
,
4
),
])
# yapf: enable
def
test_swap_dict_values
(
obj
,
key1
,
key2
):
original_obj
=
obj
.
copy
()
swap_dict_values
(
obj
,
key1
,
key2
)
...
...
@@ -631,19 +687,19 @@ def test_swap_dict_values(obj, key1, key2):
assert
key1
not
in
obj
def
test_model_specification
(
parser_with_config
,
cli_config_file
,
def
test_model_specification
(
parser_with_config
,
cli_config_file
,
cli_config_file_with_model
):
# Test model in CLI takes precedence over config
args
=
parser_with_config
.
parse_args
([
'serve'
,
'cli-model'
,
'--config'
,
cli_config_file_with_model
])
args
=
parser_with_config
.
parse_args
(
[
'serve'
,
'cli-model'
,
'--config'
,
cli_config_file_with_model
])
assert
args
.
model_tag
==
'cli-model'
assert
args
.
served_model_name
==
'mymodel'
# Test model from config file works
args
=
parser_with_config
.
parse_args
([
'serve'
,
'--config'
,
cli_config_file_with_model
,
'serve'
,
'--config'
,
cli_config_file_with_model
,
])
assert
args
.
model
==
'config-model'
assert
args
.
served_model_name
==
'mymodel'
...
...
@@ -655,16 +711,18 @@ def test_model_specification(parser_with_config,
# Test using --model option raises error
with
pytest
.
raises
(
ValueError
,
match
=
(
"With `vllm serve`, you should provide the model as a positional "
"argument or in a config file instead of via the `--model` option."
),
match
=
(
"With `vllm serve`, you should provide the model as a positional "
"argument or in a config file instead of via the `--model` option."
),
):
parser_with_config
.
parse_args
([
'serve'
,
'--model'
,
'my-model'
])
# Test other config values are preserved
args
=
parser_with_config
.
parse_args
([
'serve'
,
'cli-model'
,
'--config'
,
cli_config_file_with_model
,
'serve'
,
'cli-model'
,
'--config'
,
cli_config_file_with_model
,
])
assert
args
.
tensor_parallel_size
==
2
assert
args
.
trust_remote_code
is
True
...
...
@@ -682,7 +740,8 @@ def test_sha256(input: tuple, output: int):
assert
hash
!=
0
bytes
=
pickle
.
dumps
(
input
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
assert
hash
==
int
.
from_bytes
(
hashlib
.
sha256
(
bytes
).
digest
(),
byteorder
=
"big"
)
assert
hash
==
int
.
from_bytes
(
hashlib
.
sha256
(
bytes
).
digest
(),
byteorder
=
"big"
)
# hashing again, returns the same value
assert
hash
==
sha256
(
input
)
...
...
@@ -698,8 +757,7 @@ def test_sha256(input: tuple, output: int):
(
"tcp://127.0.0.1:5555"
,
(
"tcp"
,
"127.0.0.1"
,
"5555"
)),
(
"tcp://[::1]:5555"
,
(
"tcp"
,
"::1"
,
"5555"
)),
# IPv6 address
(
"inproc://some_identifier"
,
(
"inproc"
,
"some_identifier"
,
""
)),
]
)
])
def
test_split_zmq_path
(
path
,
expected
):
assert
split_zmq_path
(
path
)
==
expected
...
...
@@ -711,8 +769,7 @@ def test_split_zmq_path(path, expected):
"tcp://127.0.0.1"
,
# Missing port
"tcp://[::1]"
,
# Missing port for IPv6
"tcp://:5555"
,
# Missing host
]
)
])
def
test_split_zmq_path_invalid
(
invalid_path
):
with
pytest
.
raises
(
ValueError
):
split_zmq_path
(
invalid_path
)
...
...
@@ -734,7 +791,8 @@ def test_make_zmq_socket_ipv6():
zsock
:
zmq
.
Socket
=
make_zmq_socket
(
ctx
,
ipv6_path
,
socket_type
)
# Verify that the IPV6 option is set
assert
zsock
.
getsockopt
(
zmq
.
IPV6
)
==
1
,
"IPV6 option should be enabled for IPv6 addresses"
assert
zsock
.
getsockopt
(
zmq
.
IPV6
)
==
1
,
"IPV6 option should be enabled for IPv6 addresses"
# Clean up
zsock
.
close
()
...
...
vllm/config.py
View file @
6aa8f9a4
...
...
@@ -24,6 +24,7 @@ import torch
from
pydantic
import
(
ConfigDict
,
SkipValidation
,
TypeAdapter
,
field_validator
,
model_validator
)
from
pydantic.dataclasses
import
dataclass
from
safetensors.torch
import
_TYPES
as
_SAFETENSORS_TO_TORCH_DTYPE
from
torch.distributed
import
ProcessGroup
,
ReduceOp
from
transformers
import
PretrainedConfig
from
typing_extensions
import
deprecated
,
runtime_checkable
...
...
@@ -42,15 +43,16 @@ from vllm.transformers_utils.config import (
ConfigFormat
,
get_config
,
get_hf_image_processor_config
,
get_hf_text_config
,
get_pooling_config
,
get_sentence_transformer_tokenizer_config
,
is_encoder_decoder
,
try_get_generation_config
,
uses_mrope
)
try_get_generation_config
,
try_get_safetensors_metadata
,
uses_mrope
)
from
vllm.transformers_utils.s3_utils
import
S3Model
from
vllm.transformers_utils.utils
import
is_s3
,
maybe_model_redirect
from
vllm.utils
import
(
DEFAULT_MAX_NUM_BATCHED_TOKENS
,
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS
,
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS
,
GiB_bytes
,
LayerBlockType
,
cuda_device_count_stateless
,
get_cpu_memory
,
get_open_port
,
is_torch_equal_or_newer
,
random_uuid
,
resolve_obj_by_qualname
)
LayerBlockType
,
common_broadcastable_dtype
,
cuda_device_count_stateless
,
get_cpu_memory
,
get_open_port
,
is_torch_equal_or_newer
,
random_uuid
,
resolve_obj_by_qualname
)
if
TYPE_CHECKING
:
from
_typeshed
import
DataclassInstance
...
...
@@ -540,7 +542,24 @@ class ModelConfig:
self
.
encoder_config
=
self
.
_get_encoder_config
()
self
.
hf_image_processor_config
=
get_hf_image_processor_config
(
self
.
model
,
hf_token
=
self
.
hf_token
,
revision
=
self
.
revision
)
self
.
dtype
=
_get_and_verify_dtype
(
self
.
hf_config
,
self
.
dtype
)
supported_tasks
,
task
=
self
.
_resolve_task
(
self
.
task
)
self
.
supported_tasks
=
supported_tasks
self
.
task
=
task
if
self
.
task
in
(
"draft"
,
"generate"
):
self
.
truncation_side
=
"left"
else
:
self
.
truncation_side
=
"right"
self
.
pooler_config
=
self
.
_init_pooler_config
()
self
.
dtype
=
_get_and_verify_dtype
(
self
.
model
,
self
.
hf_config
,
self
.
dtype
,
is_pooling_model
=
self
.
runner_type
==
"pooling"
,
revision
=
self
.
revision
,
)
# Workaround for Gemma 2 which uses interleaved sliding window
# attention, but it's not specified in its config. TODO: remove this
...
...
@@ -597,16 +616,6 @@ class ModelConfig:
raise
ValueError
(
"`override_neuron_config` is only supported on Neuron."
)
supported_tasks
,
task
=
self
.
_resolve_task
(
self
.
task
)
self
.
supported_tasks
=
supported_tasks
self
.
task
=
task
if
self
.
task
in
(
"draft"
,
"generate"
):
self
.
truncation_side
=
"left"
else
:
self
.
truncation_side
=
"right"
self
.
pooler_config
=
self
.
_init_pooler_config
()
self
.
_verify_quantization
()
self
.
_verify_cuda_graph
()
self
.
_verify_bnb_config
()
...
...
@@ -692,7 +701,6 @@ class ModelConfig:
self
.
model
,
self
.
revision
)
def
_init_pooler_config
(
self
)
->
Optional
[
"PoolerConfig"
]:
if
self
.
runner_type
==
"pooling"
:
if
isinstance
(
self
.
override_pooler_config
,
dict
):
self
.
override_pooler_config
=
PoolerConfig
(
...
...
@@ -3074,13 +3082,37 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16"
:
torch
.
bfloat16
,
}
_ROCM_NOT_SUPPORTED_DTYPE
:
list
[
str
]
=
[]
#
# model_type -> reason
# Model types (matched against `config.model_type`) for which float16 is
# rejected, each mapped to the explanation that is embedded in the error
# message shown to the user.
_FLOAT16_NOT_SUPPORTED_MODELS = {
    "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.",
    "gemma3": "Numerical instability. Please use bfloat16 or float32 instead.",
    "plamo2": "Numerical instability. Please use bfloat16 or float32 instead.",
    "glm4": "Numerical instability. Please use bfloat16 or float32 instead.",
}
def
_get_and_verify_dtype
(
def _is_valid_dtype(model_type: str, dtype: torch.dtype):
    """Return whether `dtype` is acceptable for models of type `model_type`.

    The only rejected combination is `torch.float16` together with a model
    type listed in `_FLOAT16_NOT_SUPPORTED_MODELS`; everything else is valid.
    This is the non-raising counterpart of `_check_valid_dtype`.
    """
    fp16_blocked = (model_type in _FLOAT16_NOT_SUPPORTED_MODELS
                    and dtype == torch.float16)
    return not fp16_blocked
def _check_valid_dtype(model_type: str, dtype: torch.dtype):
    """Validate that `dtype` may be used with models of type `model_type`.

    Returns:
        `True` when the combination is allowed.

    Raises:
        ValueError: If `dtype` is `torch.float16` and `model_type` is listed
            in `_FLOAT16_NOT_SUPPORTED_MODELS`. The message includes the
            reason recorded for that model type.
    """
    fp16_blocked = (model_type in _FLOAT16_NOT_SUPPORTED_MODELS
                    and dtype == torch.float16)
    if not fp16_blocked:
        return True

    reason = _FLOAT16_NOT_SUPPORTED_MODELS[model_type]
    raise ValueError(f"The model type {model_type!r} "
                     f"does not support float16. Reason: {reason}")
def
_find_dtype
(
model_id
:
str
,
config
:
PretrainedConfig
,
dtype
:
Union
[
str
,
torch
.
dtype
],
)
->
torch
.
dtype
:
*
,
revision
:
Optional
[
str
],
):
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype
=
getattr
(
config
,
"torch_dtype"
,
None
)
...
...
@@ -3092,75 +3124,111 @@ def _get_and_verify_dtype(
if
config_dtype
is
None
and
hasattr
(
config
,
"vision_config"
):
config_dtype
=
getattr
(
config
.
vision_config
,
"torch_dtype"
,
None
)
# Try to read the dtype of the weights if they are in safetensors format
if
config_dtype
is
None
:
repo_mt
=
try_get_safetensors_metadata
(
model_id
,
revision
=
revision
)
if
repo_mt
and
(
files_mt
:
=
repo_mt
.
files_metadata
):
param_dtypes
:
set
[
torch
.
dtype
]
=
{
_SAFETENSORS_TO_TORCH_DTYPE
[
dtype_str
]
for
file_mt
in
files_mt
.
values
()
for
dtype_str
in
file_mt
.
parameter_count
if
dtype_str
in
_SAFETENSORS_TO_TORCH_DTYPE
}
if
param_dtypes
:
return
common_broadcastable_dtype
(
param_dtypes
)
if
config_dtype
is
None
:
config_dtype
=
torch
.
float32
if
isinstance
(
dtype
,
str
):
dtype
=
dtype
.
lower
()
if
dtype
==
"auto"
:
# Set default dtype from model config
if
config_dtype
==
torch
.
float32
:
# Following common practice, we use float16 for float32 models
torch_dtype
=
torch
.
float16
else
:
torch_dtype
=
config_dtype
return
config_dtype
if
config
.
model_type
==
"plamo2"
:
logger
.
warning
(
"For PLaMo2, we cast models to bfloat16 instead of using "
"float16 by default. This is because float16 does not work."
)
torch_dtype
=
torch
.
bfloat16
# Deal with torch dtype fallback for device compatibility.
def
_resolve_auto_dtype
(
model_type
:
str
,
config_dtype
:
torch
.
dtype
,
*
,
is_pooling_model
:
bool
,
):
from
vllm.platforms
import
current_platform
if
torch_dtype
not
in
current_platform
.
supported_dtypes
:
supported_dtypes
=
[
dtype
for
dtype
in
current_platform
.
supported_dtypes
if
_is_valid_dtype
(
model_type
,
dtype
)
]
if
is_pooling_model
and
torch
.
float16
in
supported_dtypes
:
preferred_dtype
=
torch
.
float16
else
:
preferred_dtype
=
supported_dtypes
[
0
]
# Downcast for float32 models
if
config_dtype
==
torch
.
float32
:
config_dtype
=
preferred_dtype
if
config_dtype
in
supported_dtypes
:
return
config_dtype
# Ensure device compatibility
device_name
=
current_platform
.
get_device_name
()
device_capability
=
current_platform
.
get_device_capability
()
if
((
capability
:
=
current_platform
.
get_device_capability
())
is
None
):
compute_str
=
""
if
device_capability
is
None
:
device_str
=
f
"
{
device_name
!
r
}
"
else
:
version_str
=
capability
.
as_version_str
()
compute_str
=
f
"
(with compute capability
{
version_str
}
)"
fallback_dtype
=
current_platform
.
supported_dtypes
[
0
]
version_str
=
device_
capability
.
as_version_str
()
device_str
=
f
"
{
device_name
!
r
}
(with compute capability
{
version_str
}
)"
logger
.
warning
(
"Your
%s
device%s doesn't support %s. "
\
"Your device
%s doesn't support %s. "
"Falling back to %s for compatibility."
,
device_name
,
compute_str
,
torch_dtype
,
fallback_dtype
device_str
,
config_dtype
,
preferred_dtype
,
)
torch_dtype
=
fallback_dtype
if
current_platform
.
is_hpu
()
and
torch_dtype
==
torch
.
float16
:
logger
.
warning
(
"For HPU, we cast models to bfloat16 instead of "
"using float16 by default. Please specify `dtype` if you "
"want to use float16."
)
torch_dtype
=
torch
.
bfloat16
elif
dtype
==
"float16"
and
config
.
model_type
==
"plamo2"
:
logger
.
warning
(
"For PLaMo2, using float16 is unstable and might cause "
"unexpected behavior. Please use bfloat16 or float32 instead."
)
torch_dtype
=
torch
.
float16
return
preferred_dtype
def
_get_and_verify_dtype
(
model_id
:
str
,
config
:
PretrainedConfig
,
dtype
:
Union
[
str
,
torch
.
dtype
],
*
,
is_pooling_model
:
bool
,
revision
:
Optional
[
str
]
=
None
,
)
->
torch
.
dtype
:
config_dtype
=
_find_dtype
(
model_id
,
config
,
revision
=
revision
)
model_type
=
config
.
model_type
if
isinstance
(
dtype
,
str
):
dtype
=
dtype
.
lower
()
if
dtype
==
"auto"
:
# Set default dtype from model config
torch_dtype
=
_resolve_auto_dtype
(
model_type
,
config_dtype
,
is_pooling_model
=
is_pooling_model
,
)
else
:
if
dtype
not
in
_STR_DTYPE_TO_TORCH_DTYPE
:
raise
ValueError
(
f
"Unknown dtype:
{
dtype
}
"
)
raise
ValueError
(
f
"Unknown dtype:
{
dtype
!
r
}
"
)
torch_dtype
=
_STR_DTYPE_TO_TORCH_DTYPE
[
dtype
]
elif
isinstance
(
dtype
,
torch
.
dtype
):
torch_dtype
=
dtype
else
:
raise
ValueError
(
f
"Unknown dtype:
{
dtype
}
"
)
# Verify the dtype.
_check_valid_dtype
(
model_type
,
torch_dtype
)
if
torch_dtype
!=
config_dtype
:
if
torch_dtype
==
torch
.
float32
:
# Upcasting to float32 is allowed.
logger
.
info
(
"Upcasting %s to %s."
,
config_dtype
,
torch_dtype
)
pass
elif
config_dtype
==
torch
.
float32
:
# Downcasting from float32 to float16 or bfloat16 is allowed.
logger
.
info
(
"Downcasting %s to %s."
,
config_dtype
,
torch_dtype
)
pass
else
:
# Casting between float16 and bfloat16 is allowed with a warning.
logger
.
warning
(
"Casting %s to %s."
,
config_dtype
,
torch_dtype
)
...
...
vllm/platforms/cpu.py
View file @
6aa8f9a4
...
...
@@ -28,7 +28,7 @@ class CpuPlatform(Platform):
dispatch_key
:
str
=
"CPU"
@
property
def
supported_dtypes
(
self
)
->
list
:
def
supported_dtypes
(
self
)
->
list
[
torch
.
dtype
]
:
if
self
.
get_cpu_architecture
()
==
CpuArchEnum
.
POWERPC
:
return
[
torch
.
bfloat16
,
torch
.
float32
]
elif
sys
.
platform
.
startswith
(
...
...
vllm/transformers_utils/config.py
View file @
6aa8f9a4
...
...
@@ -4,12 +4,12 @@ import enum
import
json
import
os
import
time
from
functools
import
cache
from
functools
import
cache
,
partial
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Literal
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Literal
,
Optional
,
TypeVar
,
Union
import
huggingface_hub
from
huggingface_hub
import
hf_hub_download
from
huggingface_hub
import
get_safetensors_metadata
,
hf_hub_download
from
huggingface_hub
import
list_repo_files
as
hf_list_repo_files
from
huggingface_hub
import
try_to_load_from_cache
from
huggingface_hub.utils
import
(
EntryNotFoundError
,
HfHubHTTPError
,
...
...
@@ -93,10 +93,15 @@ class ConfigFormat(str, enum.Enum):
MISTRAL
=
"mistral"
def
with_retry
(
func
:
Callable
[[],
Any
],
_R
=
TypeVar
(
"_R"
)
def
with_retry
(
func
:
Callable
[[],
_R
],
log_msg
:
str
,
max_retries
:
int
=
2
,
retry_delay
:
int
=
2
):
retry_delay
:
int
=
2
,
)
->
_R
:
for
attempt
in
range
(
max_retries
):
try
:
return
func
()
...
...
@@ -109,6 +114,8 @@ def with_retry(func: Callable[[], Any],
time
.
sleep
(
retry_delay
)
retry_delay
*=
2
raise
AssertionError
(
"Should not be reached"
)
# @cache doesn't cache exceptions
@
cache
...
...
@@ -840,3 +847,22 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
return
resolve_obj_by_qualname
(
function_name
)()
else
:
return
nn
.
Sigmoid
()
if
config
.
num_labels
==
1
else
nn
.
Identity
()
def try_get_safetensors_metadata(
    model: str,
    *,
    revision: Optional[str] = None,
):
    """Fetch the safetensors metadata of `model`, or `None` on failure.

    The lookup goes through `with_retry` and authenticates with the token
    found in the `HF_TOKEN` environment variable, if any.

    Args:
        model: Model identifier forwarded to `get_safetensors_metadata`.
        revision: Optional revision forwarded to `get_safetensors_metadata`.

    Returns:
        Whatever `get_safetensors_metadata` returns, or `None` if the
        retried call ends up raising any `Exception`.
    """
    hf_token = os.getenv('HF_TOKEN')
    fetch_metadata = partial(
        get_safetensors_metadata,
        model,
        revision=revision,
        token=hf_token,
    )

    # Best-effort lookup: callers treat a missing result as "unknown".
    try:
        metadata = with_retry(fetch_metadata, "Error retrieving safetensors")
    except Exception:
        metadata = None
    return metadata
vllm/utils.py
View file @
6aa8f9a4
...
...
@@ -37,8 +37,8 @@ from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser,
_ArgumentGroup
)
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Task
from
collections
import
UserDict
,
defaultdict
from
collections.abc
import
(
AsyncGenerator
,
Awaitable
,
Generator
,
Hashable
,
Iterable
,
Iterator
,
KeysView
,
Mapping
)
from
collections.abc
import
(
AsyncGenerator
,
Awaitable
,
Collection
,
Generator
,
Hashable
,
Iterable
,
Iterator
,
KeysView
,
Mapping
)
from
concurrent.futures.process
import
ProcessPoolExecutor
from
dataclasses
import
dataclass
,
field
from
functools
import
cache
,
lru_cache
,
partial
,
wraps
...
...
@@ -979,6 +979,53 @@ def get_dtype_size(dtype: torch.dtype) -> int:
return
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
# bool = 0, int = 1, float = 2, complex = 3
def
_get_precision_level
(
dtype
:
torch
.
dtype
)
->
int
:
# NOTE: Complex dtypes return `is_floating_point=False`
return
((
dtype
!=
torch
.
bool
)
+
dtype
.
is_floating_point
+
dtype
.
is_complex
*
2
)
def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype):
    """
    Test whether it is lossless to cast a tensor from
    `src_dtype` to `tgt_dtype`.

    Identical dtypes are always lossless. Dtypes of different precision
    levels (see `_get_precision_level`) are decided by the level alone:
    casting to a higher level counts as lossless, to a lower level as lossy.
    Dtypes on the same level are compared through their `torch.iinfo` /
    `torch.finfo` limits.
    """
    if src_dtype == tgt_dtype:
        return True

    src_level = _get_precision_level(src_dtype)
    tgt_level = _get_precision_level(tgt_dtype)
    if src_level != tgt_level:
        return src_level < tgt_level

    # Same level from here on, so the source dtype decides which kind of
    # limits applies to both sides.
    is_integral = not (src_dtype.is_floating_point or src_dtype.is_complex)
    limits_of = torch.iinfo if is_integral else torch.finfo
    src_limits = limits_of(src_dtype)
    tgt_limits = limits_of(tgt_dtype)

    # The target must cover the whole value range of the source.
    fits_range = (src_limits.min >= tgt_limits.min
                  and src_limits.max <= tgt_limits.max)
    if is_integral:
        return fits_range

    # For floating-point/complex types the target must also be at least as
    # finely resolved as the source.
    return fits_range and src_limits.resolution >= tgt_limits.resolution
def common_broadcastable_dtype(dtypes: Collection[torch.dtype]):
    """
    Get the common `dtype` where all of the other `dtypes` can be
    cast to it without losing any information.

    The result is always one of the given `dtypes`: the one that the largest
    number of entries (itself included) can be cast to according to
    `is_lossless_cast`. If no entry can hold all the others, the best
    candidate is still returned. Ties are resolved in favour of the
    candidate that comes first when iterating over `dtypes`, so the outcome
    of a tie is arbitrary for unordered collections such as sets.

    Args:
        dtypes: Non-empty collection of candidate dtypes.

    Returns:
        The selected `torch.dtype`.

    Raises:
        ValueError: If `dtypes` is empty.
    """
    # Fail with an explicit message instead of the generic error that
    # `max()` raises for an empty iterable.
    if not dtypes:
        raise ValueError("`dtypes` must contain at least one dtype")

    def num_lossless_sources(candidate: torch.dtype) -> int:
        # Number of entries that can be cast to `candidate` losslessly.
        return sum(is_lossless_cast(dt, candidate) for dt in dtypes)

    return max(dtypes, key=num_lossless_sources)
# `collections` helpers
def
is_list_of
(
value
:
object
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment