"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "9972562d33dd6455fddc41d1f164c5b84801a46a"
Unverified commit 7682e977 authored by Stas Bekman, committed by GitHub

[models] respect dtype of the model when instantiating it (#12316)



* [models] respect dtype of the model when instantiating it

* cleanup

* cleanup

* rework to handle non-float dtype

* fix

* switch to fp32 tiny model

* improve

* use dtype.is_floating_point

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix the doc

* recode to use explicit torch_dtype_auto_detect, torch_dtype args

* docs and tweaks

* docs and tweaks

* docs and tweaks

* merge 2 args, add docs

* fix

* fix

* better doc

* better doc
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 31c3e7e7
@@ -1549,6 +1549,8 @@ Note: If the fp16 weights of the model can't fit onto the memory of a single GPU
For full details on this method and other related features please refer to `Constructing Massive Models
<https://deepspeed.readthedocs.io/en/latest/zero3.html#constructing-massive-models>`__.
Also, when loading fp16-pretrained models, you will want to tell ``from_pretrained`` to use
``torch_dtype=torch.float16``. For details, please see :ref:`from_pretrained-torch-dtype`.
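For illustration, a minimal sketch of such a call (the checkpoint name is a placeholder; any fp16 checkpoint works the same way):

.. code-block:: python

    import torch
    from transformers import AutoModelForSeq2SeqLM

    # load the weights directly in fp16 instead of the default fp32
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", torch_dtype=torch.float16)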
Gathering Parameters
...
..
Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
@@ -38,6 +38,37 @@ PreTrainedModel
:members:
.. _from_pretrained-torch-dtype:
Model Instantiation dtype
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Under PyTorch a model normally gets instantiated in ``torch.float32`` format. This can be an issue if one tries to
load a model whose weights are in fp16, since it would then require twice as much memory. To overcome this limitation,
you can either explicitly pass the desired ``dtype`` using the ``torch_dtype`` argument:
.. code-block:: python
model = T5ForConditionalGeneration.from_pretrained("t5", torch_dtype=torch.float16)
or, if you want the model to always load in the most optimal memory pattern, you can use the special value ``"auto"``,
and then ``dtype`` will be automatically derived from the model's weights:
.. code-block:: python
model = T5ForConditionalGeneration.from_pretrained("t5", torch_dtype="auto")
Models instantiated from scratch can also be told which ``dtype`` to use with:
.. code-block:: python
config = T5Config.from_pretrained("t5")
model = AutoModel.from_config(config, torch_dtype=torch.float16)
Due to PyTorch design, this functionality is only available for floating-point dtypes.
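As a rough sketch of this behaviour (mirroring the accompanying tests; the config class is just an example):

.. code-block:: python

    import torch
    from transformers import AutoModel, T5Config

    config = T5Config()
    model = AutoModel.from_config(config, torch_dtype=torch.float16)
    assert model.dtype == torch.float16

    # non-float dtypes are rejected, since torch.set_default_dtype() only accepts floating-point types
    try:
        AutoModel.from_config(config, torch_dtype=torch.int64)
    except ValueError:
        pass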
ModuleUtilsMixin
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
@@ -192,6 +192,12 @@ class PretrainedConfig(PushToHubMixin):
- **tie_word_embeddings** (:obj:`bool`, `optional`, defaults to :obj:`True`) -- Whether the model's input and
output word embeddings should be tied. Note that this is only relevant if the model has an output word
embedding layer.
- **torch_dtype** (:obj:`str`, `optional`) -- The :obj:`dtype` of the weights. This attribute can be used to
initialize the model to a non-default ``dtype`` (which is normally ``float32``) and thus allow for optimal
storage allocation. For example, if the saved model is ``float16``, ideally we want to load it back using the
minimal amount of memory needed to load ``float16`` weights. Since the config object is stored in plain text,
this attribute contains just the floating type string without the ``torch.`` prefix. For example, for
``torch.float16``, ``torch_dtype`` is the ``"float16"`` string.
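For illustration only (a hedged sketch, not part of this change), the stored string can be mapped back to a ``torch.dtype`` with ``getattr``::

import torch
dtype = getattr(torch, "float16")  # -> torch.float16; config.torch_dtype holds just the "float16" string
assert str(dtype).split(".")[1] == "float16"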
TensorFlow specific parameters
@@ -207,6 +213,7 @@ class PretrainedConfig(PushToHubMixin):
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
self.output_attentions = kwargs.pop("output_attentions", False)
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
self.torch_dtype = kwargs.pop("torch_dtype", None) # Only used by PyTorch models
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
self.pruned_heads = kwargs.pop("pruned_heads", {})
self.tie_word_embeddings = kwargs.pop(
...
@@ -111,6 +111,13 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
raise NotImplementedError(f"init method has to be implemented for {self}")
@classmethod
def _from_config(cls, config, **kwargs):
"""
All context managers that the model should be initialized under go here.
"""
return cls(config, **kwargs)
@property
def config(self) -> PretrainedConfig:
return self._config
...
@@ -643,6 +643,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
self.config = config
self.name_or_path = config.name_or_path
@classmethod
def _from_config(cls, config, **kwargs):
"""
All context managers that the model should be initialized under go here.
"""
return cls(config, **kwargs)
@tf.function(
input_signature=[
{
...
@@ -23,7 +23,7 @@ from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
- from torch import Tensor, device, dtype, nn
+ from torch import Tensor, device, nn
from torch.nn import CrossEntropyLoss
from .activations import get_activation
@@ -201,7 +201,7 @@ class ModuleUtilsMixin:
return get_parameter_device(self)
@property
- def dtype(self) -> dtype:
+ def dtype(self) -> torch.dtype:
"""
:obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
"""
@@ -464,6 +464,66 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
self.config = config
self.name_or_path = config.name_or_path
@classmethod
def _from_config(cls, config, **kwargs):
"""
All context managers that the model should be initialized under go here.
Args:
torch_dtype (:obj:`torch.dtype`, `optional`):
Override the default ``torch.dtype`` and load the model under this dtype.
"""
torch_dtype = kwargs.pop("torch_dtype", None)
# override default dtype if needed
dtype_orig = None
if torch_dtype is not None:
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
if is_deepspeed_zero3_enabled():
import deepspeed
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
# this immediately partitions the model across all gpus, to avoid the overhead in time
# and memory copying it on CPU or each GPU first
with deepspeed.zero.Init(config=deepspeed_config()):
model = cls(config, **kwargs)
else:
model = cls(config, **kwargs)
# restore default dtype if it was modified
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
return model
@classmethod
def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
"""
Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
under specific dtype.
Args:
dtype (:obj:`torch.dtype`):
a floating dtype to set to.
Returns:
:obj:`torch.dtype`: the original ``dtype`` that can be used to restore ``torch.set_default_dtype(dtype)``
if it was modified. If it wasn't, returns :obj:`None`.
Note ``set_default_dtype`` currently only works with floating-point types and asserts if, for example,
``torch.int64`` is passed. So if a non-float ``dtype`` is passed, this function will throw an exception.
"""
if not dtype.is_floating_point:
raise ValueError(
f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
)
logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
dtype_orig = torch.get_default_dtype()
torch.set_default_dtype(dtype)
return dtype_orig
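# For reference, a standalone sketch of the default-dtype mechanism this helper relies on
# (illustrative only, not part of this diff):
#
#   import torch
#   prev = torch.get_default_dtype()            # typically torch.float32
#   torch.set_default_dtype(torch.float16)
#   layer = torch.nn.Linear(2, 2)               # new modules now allocate fp16 parameters
#   assert layer.weight.dtype == torch.float16
#   torch.set_default_dtype(prev)               # restore, as _from_config/from_pretrained do afterwards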
@property
def base_model(self) -> nn.Module:
"""
@@ -876,6 +936,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Only save the model itself if we are using distributed training
model_to_save = unwrap_model(self)
# save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
# we currently don't use this setting automatically, but may start to use with v5
dtype = get_parameter_dtype(model_to_save)
model_to_save.config.torch_dtype = str(dtype).split(".")[1]
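# for reference, a minimal sketch of the string conversion used on the line above:
#   str(torch.float32) == "torch.float32"
#   str(torch.float16).split(".")[1] == "float16"   # the value that ends up in config.torch_dtype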
# Attach architecture to the config
model_to_save.config.architectures = [model_to_save.__class__.__name__]
@@ -993,6 +1058,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
Please refer to the mirror site for more information.
_fast_init(:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to disable fast initialization.
torch_dtype (:obj:`str` or :obj:`torch.dtype`, `optional`):
Override the default ``torch.dtype`` and load the model under this dtype. If ``"auto"`` is passed the
dtype will be automatically derived from the model's weights.
.. warning::
@@ -1058,6 +1126,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
_fast_init = kwargs.pop("_fast_init", True)
torch_dtype = kwargs.pop("torch_dtype", None)
from_pt = not (from_tf | from_flax)
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
if from_pipeline is not None: if from_pipeline is not None:
@@ -1162,6 +1233,34 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else:
resolved_archive_file = None
# load pt weights early so that we know which dtype to init the model under
if from_pt:
if state_dict is None:
try:
state_dict = torch.load(resolved_archive_file, map_location="cpu")
except Exception:
raise OSError(
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
f"at '{resolved_archive_file}'"
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
)
# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
# weights entry - we assume all weights are of the same dtype
# we also may have config.torch_dtype available, but we won't rely on it till v5
dtype_orig = None
if torch_dtype is not None:
if isinstance(torch_dtype, str):
if torch_dtype == "auto":
torch_dtype = next(iter(state_dict.values())).dtype
else:
raise ValueError(
f"`torch_dtype` can be either a `torch.dtype` or `auto`, but received {torch_dtype}"
)
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
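# a standalone sketch of the "auto" detection above, using a toy state dict (illustrative only):
#   state_dict = {"encoder.weight": torch.ones(2, 2, dtype=torch.float16)}
#   detected = next(iter(state_dict.values())).dtype   # same assumption: all weights share one dtype
#   detected == torch.float16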
config.name_or_path = pretrained_model_name_or_path
# Instantiate model.
@@ -1178,6 +1277,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
with no_init_weights(_enable=_fast_init):
model = cls(config, *model_args, **model_kwargs)
if from_pt:
# restore default dtype
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
if from_tf:
if resolved_archive_file.endswith(".index"):
# Load from a TensorFlow 1.X checkpoint - provided by original authors
@@ -1205,17 +1309,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
)
raise
- else:
- if state_dict is None:
- try:
- state_dict = torch.load(resolved_archive_file, map_location="cpu")
- except Exception:
- raise OSError(
- f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
- f"at '{resolved_archive_file}'"
- "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
- )
+ elif from_pt:
model, missing_keys, unexpected_keys, error_msgs = cls._load_state_dict_into_model(
model, state_dict, pretrained_model_name_or_path, _fast_init=_fast_init
)
...
@@ -17,7 +17,6 @@
import types
from ...configuration_utils import PretrainedConfig
- from ...deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from ...file_utils import copy_func
from ...utils import logging
from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
@@ -367,16 +366,8 @@ class _BaseAutoModelClass:
def from_config(cls, config, **kwargs):
if type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
- if is_deepspeed_zero3_enabled():
- import deepspeed
- logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
- # this immediately partitions the model across all gpus, to avoid the overhead in time
- # and memory copying it on CPU or each GPU first
- with deepspeed.zero.Init(config=deepspeed_config()):
- return model_class(config, **kwargs)
- else:
- return model_class(config, **kwargs)
+ return model_class._from_config(config, **kwargs)
raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
...
@@ -25,7 +25,7 @@ from typing import Dict, List, Tuple
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
- from transformers import is_torch_available, logging
+ from transformers import AutoModel, is_torch_available, logging
from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
@@ -33,6 +33,7 @@ from transformers.testing_utils import (
PASS,
USER,
CaptureLogger,
TestCasePlus,
is_staging_test,
require_torch,
require_torch_multi_gpu,
@@ -63,6 +64,7 @@ if is_torch_available():
BertModel,
PretrainedConfig,
PreTrainedModel,
T5Config,
T5ForConditionalGeneration,
)
@@ -1574,7 +1576,7 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
@require_torch
- class ModelUtilsTest(unittest.TestCase):
+ class ModelUtilsTest(TestCasePlus):
@slow
def test_model_from_pretrained(self):
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -1607,6 +1609,60 @@ class ModelUtilsTest(unittest.TestCase):
BertModel.from_pretrained(TINY_T5)
self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)
@require_torch
def test_model_from_config_torch_dtype(self):
# test that the model can be instantiated with dtype of user's choice - as long as it's a
# float dtype. To make it happen config.torch_dtype needs to be set before instantiating the
# model from the config object.
config = T5Config.from_pretrained(TINY_T5)
model = AutoModel.from_config(config)
# XXX: isn't supported
# model = T5ForConditionalGeneration.from_config(config)
self.assertEqual(model.dtype, torch.float32)
model = AutoModel.from_config(config, torch_dtype=torch.float16)
self.assertEqual(model.dtype, torch.float16)
# torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
with self.assertRaises(ValueError):
model = AutoModel.from_config(config, torch_dtype=torch.int64)
@require_torch
def test_model_from_pretrained_torch_dtype(self):
# test that the model can be instantiated with dtype of either
# 1. config.torch_dtype setting in the saved model (priority)
# 2. via autodiscovery by looking at model weights
# so if a model.half() was saved, we want it to be instantiated as such.
model_path = self.get_auto_remove_tmp_dir()
# baseline - we know TINY_T5 is fp32 model
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
self.assertEqual(model.dtype, torch.float32)
# test the default fp32 save_pretrained => from_pretrained cycle
model.save_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
self.assertEqual(model.dtype, torch.float32)
# test with auto-detection
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
self.assertEqual(model.dtype, torch.float32)
# test forced loading in fp16 (even though the weights are in fp32)
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
self.assertEqual(model.dtype, torch.float16)
# test fp16 save_pretrained, loaded with auto-detection
model = model.half()
model.save_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
self.assertEqual(model.config.torch_dtype, "float16") # tests `config.torch_dtype` saving
self.assertEqual(model.dtype, torch.float16)
# test fp16 save_pretrained, loaded with the explicit fp16
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
self.assertEqual(model.dtype, torch.float16)
@require_torch
@is_staging_test
...