Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f690372b
Unverified
Commit
f690372b
Authored Mar 19, 2025 by Cyrus Leung
Committed by GitHub on Mar 19, 2025
Browse files
[Core] Update dtype detection and defaults (#14858)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent
8b3e94a3
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing 2 changed files with 12 additions and 12 deletions
+12
-12
tests/v1/engine/test_llm_engine.py
tests/v1/engine/test_llm_engine.py
+1
-1
vllm/config.py
vllm/config.py
+11
-11
No files found.
tests/v1/engine/test_llm_engine.py
View file @
f690372b
...
...
@@ -50,7 +50,7 @@ def _get_test_sampling_params(
"""Generate random sampling params for a batch."""
     def get_mostly_n_gt1() -> int:
-        """Mostly n \in [2,20], ~1/3 n=1"""
+        r"""Mostly n \in [2,20], ~1/3 n=1"""
         x = random.randint(0, 28)
         if x < 10:
             return 1
...
...
vllm/config.py
View file @
f690372b
...
...
@@ -347,7 +347,7 @@ class ModelConfig:
         self.encoder_config = self._get_encoder_config()
         self.hf_image_processor_config = get_hf_image_processor_config(
             self.model, revision)
-        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
+        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self.use_async_output_proc = use_async_output_proc
         self.mm_processor_kwargs = mm_processor_kwargs
         self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache
...
...
@@ -2526,6 +2526,14 @@ def _get_and_verify_dtype(
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
     config_dtype = getattr(config, "torch_dtype", None)
+    # Fallbacks for multi-modal models if the root config
+    # does not define torch_dtype
+    if config_dtype is None and hasattr(config, "text_config"):
+        config_dtype = getattr(config.text_config, "torch_dtype", None)
+    if config_dtype is None and hasattr(config, "vision_config"):
+        config_dtype = getattr(config.vision_config, "torch_dtype", None)
     if config_dtype is None:
         config_dtype = torch.float32
...
...
@@ -2533,16 +2541,8 @@ def _get_and_verify_dtype(
         dtype = dtype.lower()
         if dtype == "auto":
             if config_dtype == torch.float32:
-                if config.model_type in ("gemma2", "gemma3", "gemma3_text"):
-                    logger.info(
-                        "For Gemma 2 and 3, we downcast float32 to bfloat16 "
-                        "instead of float16 by default. Please specify `dtype` "
-                        "if you want to use float16.")
-                    torch_dtype = torch.bfloat16
-                else:
-                    # Following the common practice, we use float16 for float32
-                    # models.
-                    torch_dtype = torch.float16
+                # Following common practice, we use float16 for float32 models
+                torch_dtype = torch.float16
             else:
                 torch_dtype = config_dtype
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment.