Unverified Commit 005ba458 authored by Antoni Baum's avatar Antoni Baum Committed by GitHub
Browse files

Set torch default dtype in a context manager (#971)


Signed-off-by: default avatarAntoni Baum <antoni.baum@protonmail.com>
parent 320a622e
"""Utilities for selecting and loading models.""" """Utilities for selecting and loading models."""
import contextlib
from typing import Type from typing import Type
import torch import torch
...@@ -30,6 +31,15 @@ _MODEL_REGISTRY = { ...@@ -30,6 +31,15 @@ _MODEL_REGISTRY = {
} }
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(old_dtype)
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", []) architectures = getattr(config, "architectures", [])
for arch in architectures: for arch in architectures:
...@@ -42,8 +52,7 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: ...@@ -42,8 +52,7 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
def get_model(model_config: ModelConfig) -> nn.Module: def get_model(model_config: ModelConfig) -> nn.Module:
model_class = _get_model_architecture(model_config.hf_config) model_class = _get_model_architecture(model_config.hf_config)
torch.set_default_dtype(model_config.dtype) with _set_default_torch_dtype(model_config.dtype):
# Create a model instance. # Create a model instance.
# The weights will be initialized as empty tensors. # The weights will be initialized as empty tensors.
model = model_class(model_config.hf_config) model = model_class(model_config.hf_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment