Unverified Commit 005ba458 authored by Antoni Baum's avatar Antoni Baum Committed by GitHub
Browse files

Set torch default dtype in a context manager (#971)


Signed-off-by: default avatarAntoni Baum <antoni.baum@protonmail.com>
parent 320a622e
"""Utilities for selecting and loading models.""" """Utilities for selecting and loading models."""
import contextlib
from typing import Type from typing import Type
import torch import torch
...@@ -30,6 +31,15 @@ _MODEL_REGISTRY = { ...@@ -30,6 +31,15 @@ _MODEL_REGISTRY = {
} }
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(old_dtype)
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", []) architectures = getattr(config, "architectures", [])
for arch in architectures: for arch in architectures:
...@@ -42,19 +52,18 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: ...@@ -42,19 +52,18 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
def get_model(model_config: ModelConfig) -> nn.Module: def get_model(model_config: ModelConfig) -> nn.Module:
model_class = _get_model_architecture(model_config.hf_config) model_class = _get_model_architecture(model_config.hf_config)
torch.set_default_dtype(model_config.dtype) with _set_default_torch_dtype(model_config.dtype):
# Create a model instance.
# Create a model instance. # The weights will be initialized as empty tensors.
# The weights will be initialized as empty tensors. model = model_class(model_config.hf_config)
model = model_class(model_config.hf_config) if model_config.use_dummy_weights:
if model_config.use_dummy_weights: model = model.cuda()
model = model.cuda() # NOTE(woosuk): For accurate performance evaluation, we assign
# NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights.
# random values to the weights. initialize_dummy_weights(model)
initialize_dummy_weights(model) else:
else: # Load the weights from the cached or downloaded files.
# Load the weights from the cached or downloaded files. model.load_weights(model_config.model, model_config.download_dir,
model.load_weights(model_config.model, model_config.download_dir, model_config.use_np_weights)
model_config.use_np_weights) model = model.cuda()
model = model.cuda()
return model.eval() return model.eval()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment