Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
005ba458
Unverified
Commit
005ba458
authored
Sep 06, 2023
by
Antoni Baum
Committed by
GitHub
Sep 07, 2023
Browse files
Set torch default dtype in a context manager (#971)
Signed-off-by:
Antoni Baum
<
antoni.baum@protonmail.com
>
parent
320a622e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
15 deletions
+24
-15
vllm/model_executor/model_loader.py
vllm/model_executor/model_loader.py
+24
-15
No files found.
vllm/model_executor/model_loader.py
View file @
005ba458
"""Utilities for selecting and loading models."""
"""Utilities for selecting and loading models."""
import
contextlib
from
typing
import
Type
from
typing
import
Type
import
torch
import
torch
...
@@ -30,6 +31,15 @@ _MODEL_REGISTRY = {
...
@@ -30,6 +31,15 @@ _MODEL_REGISTRY = {
}
}
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily set the default torch dtype to ``dtype``.

    Intended for use in a ``with`` statement: the default dtype reported by
    ``torch.get_default_dtype()`` is captured on entry, replaced by ``dtype``
    for the duration of the block, and put back on exit.

    Args:
        dtype: The dtype to install as the torch default inside the block.

    Yields:
        None. The context manager is used only for its side effect.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore even when the with-body raises; otherwise a failure inside
        # the block would leave the process-wide default dtype changed.
        torch.set_default_dtype(old_dtype)
def
_get_model_architecture
(
config
:
PretrainedConfig
)
->
Type
[
nn
.
Module
]:
def
_get_model_architecture
(
config
:
PretrainedConfig
)
->
Type
[
nn
.
Module
]:
architectures
=
getattr
(
config
,
"architectures"
,
[])
architectures
=
getattr
(
config
,
"architectures"
,
[])
for
arch
in
architectures
:
for
arch
in
architectures
:
...
@@ -42,19 +52,18 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
...
@@ -42,19 +52,18 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
def
get_model
(
model_config
:
ModelConfig
)
->
nn
.
Module
:
def
get_model
(
model_config
:
ModelConfig
)
->
nn
.
Module
:
model_class
=
_get_model_architecture
(
model_config
.
hf_config
)
model_class
=
_get_model_architecture
(
model_config
.
hf_config
)
torch
.
set_default_dtype
(
model_config
.
dtype
)
with
_set_default_torch_dtype
(
model_config
.
dtype
):
# Create a model instance.
# Create a model instance.
# The weights will be initialized as empty tensors.
# The weights will be initialized as empty tensors.
model
=
model_class
(
model_config
.
hf_config
)
model
=
model_class
(
model_config
.
hf_config
)
if
model_config
.
use_dummy_weights
:
if
model_config
.
use_dummy_weights
:
model
=
model
.
cuda
()
model
=
model
.
cuda
()
# NOTE(woosuk): For accurate performance evaluation, we assign
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
# random values to the weights.
initialize_dummy_weights
(
model
)
initialize_dummy_weights
(
model
)
else
:
else
:
# Load the weights from the cached or downloaded files.
# Load the weights from the cached or downloaded files.
model
.
load_weights
(
model_config
.
model
,
model_config
.
download_dir
,
model
.
load_weights
(
model_config
.
model
,
model_config
.
download_dir
,
model_config
.
use_np_weights
)
model_config
.
use_np_weights
)
model
=
model
.
cuda
()
model
=
model
.
cuda
()
return
model
.
eval
()
return
model
.
eval
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment