Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xdb4_94051
vllm
Commits
005ba458
Unverified
Commit
005ba458
authored
Sep 06, 2023
by
Antoni Baum
Committed by
GitHub
Sep 07, 2023
Browse files
Set torch default dtype in a context manager (#971)
Signed-off-by:
Antoni Baum
<
antoni.baum@protonmail.com
>
parent
320a622e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
15 deletions
+24
-15
vllm/model_executor/model_loader.py
vllm/model_executor/model_loader.py
+24
-15
No files found.
vllm/model_executor/model_loader.py
View file @
005ba458
"""Utilities for selecting and loading models."""
import
contextlib
from
typing
import
Type
import
torch
...
...
@@ -30,6 +31,15 @@ _MODEL_REGISTRY = {
}
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily set the default torch dtype to ``dtype``.

    The dtype that was the default on entry is restored when the ``with``
    block exits, including when the block raises.

    Args:
        dtype: The dtype to install as torch's default while the context
            is active.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore even if the managed block raised; otherwise a failure
        # inside the block would leave the process-wide default changed.
        torch.set_default_dtype(old_dtype)
def
_get_model_architecture
(
config
:
PretrainedConfig
)
->
Type
[
nn
.
Module
]:
architectures
=
getattr
(
config
,
"architectures"
,
[])
for
arch
in
architectures
:
...
...
@@ -42,19 +52,18 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
def get_model(model_config: ModelConfig) -> nn.Module:
    """Build the model described by ``model_config`` and return it in eval mode.

    The model class is looked up from ``model_config.hf_config`` and
    instantiated while ``model_config.dtype`` is the default torch dtype.
    The default dtype is only changed for the duration of construction and
    weight initialization/loading.

    Args:
        model_config: Supplies ``hf_config``, ``dtype``, ``use_dummy_weights``
            and, for real weights, ``model``, ``download_dir`` and
            ``use_np_weights``.

    Returns:
        The model after ``.cuda()`` and ``.eval()`` have been applied.
    """
    model_class = _get_model_architecture(model_config.hf_config)
    with _set_default_torch_dtype(model_config.dtype):
        # Create a model instance.
        # The weights will be initialized as empty tensors.
        model = model_class(model_config.hf_config)
        if model_config.use_dummy_weights:
            model = model.cuda()
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        else:
            # Load the weights from the cached or downloaded files.
            model.load_weights(model_config.model, model_config.download_dir,
                               model_config.use_np_weights)
            model = model.cuda()
    return model.eval()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment