"git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "c4ae273738c65165d470a6cc9e3de420becbf104"
Unverified commit 3e3e41ae, authored by Patrick von Platen, committed by GitHub

Pytorch - Lazy initialization of models (#11471)



* lazy_init_weights

* remove ipdb

* save int

* add necessary code

* remove unnecessary utils

* Update src/transformers/models/t5/modeling_t5.py

* clean

* add tests

* correct

* finish tests

* finish tests

* fix some more tests

* fix xlnet & transfo-xl

* fix more tests

* make sure tests are independent

* fix tests more

* finish tests

* final touches

* Update src/transformers/modeling_utils.py

* Apply suggestions from code review

* Update src/transformers/modeling_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/modeling_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* clean tests

* give arg positive name

* add more mock weights to xlnet
Co-authored-by: default avatarStas Bekman <stas00@users.noreply.github.com>
parent 8fa8e194
...@@ -195,6 +195,7 @@ class ExamplesTests(TestCasePlus):
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=2
            --num_train_epochs={epochs}
            --seed 7
        """.split()
        if torch_device != "cuda":
...@@ -18,6 +18,7 @@ import inspect
import os
import re
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

...@@ -50,6 +51,26 @@ from .utils import logging
logger = logging.get_logger(__name__)
_init_weights = True
@contextmanager
def no_init_weights(_enable=True):
    """
    Context manager to globally disable weight initialization to speed up loading large models.

    TODO(Patrick): Delete safety argument `_enable=True` at next major version.
    """
    global _init_weights
    if _enable:
        _init_weights = False
    try:
        yield
    finally:
        _init_weights = True
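For context, a minimal sketch of how this context manager is meant to be used; `ToyModel` is an illustrative stand-in, not part of this diff:

import torch.nn as nn

# toy stand-in for a PreTrainedModel subclass (illustrative only)
class ToyModel(nn.Module):
    def __init__(self, hidden_size=4):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.init_weights()  # real models call this at the end of __init__

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)

    def init_weights(self):
        if _init_weights:  # the module-level flag toggled by no_init_weights
            self.apply(self._init_weights)

with no_init_weights():
    model = ToyModel()  # construction skips the (now pointless) random init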
try:
    from torch.nn import Identity
except ImportError:

...@@ -768,17 +789,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
    def init_weights(self):
        """
        If needed, prunes and maybe initializes weights.
        """
        # Prune heads if needed
        if self.config.pruned_heads:
            self.prune_heads(self.config.pruned_heads)

        if _init_weights:
            # Initialize weights
            self.apply(self._init_weights)

            # Tying of weights is skipped when not initializing all weights,
            # since from_pretrained(...) calls tie_weights() anyway.
            self.tie_weights()
    def prune_heads(self, heads_to_prune: Dict[int, List[int]]):
        """

...@@ -956,6 +979,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                Please refer to the mirror site for more information.
            _fast_init (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether to use fast weight initialization, i.e. skip the random initialization of all weights that
                are overwritten by the checkpoint and only initialize the ones that are actually missing from it.

                .. warning::

                    One should only disable `_fast_init` to ensure backwards compatibility with
                    ``transformers.__version__ < 4.6.0`` for seeded model initialization. This argument will be
                    removed at the next major version. See `pull request 11471
                    <https://github.com/huggingface/transformers/pull/11471>`__ for more information.
            kwargs (remaining dictionary of keyword arguments, `optional`):
                Can be used to update the configuration object (after it has been loaded) and initialize the model
                (e.g., :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
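As a usage sketch (the checkpoint name is only an example):

from transformers import BertModel

# default (_fast_init=True): only weights missing from the checkpoint are initialized
model = BertModel.from_pretrained("bert-base-uncased")

# opt out to reproduce the seeded initialization behavior of transformers < 4.6.0
model = BertModel.from_pretrained("bert-base-uncased", _fast_init=False)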
...@@ -1012,6 +1045,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        mirror = kwargs.pop("mirror", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)
        _fast_init = kwargs.pop("_fast_init", True)

        user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
        if from_pipeline is not None:
...@@ -1119,7 +1153,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        config.name_or_path = pretrained_model_name_or_path

        # Instantiate model.
        if is_deepspeed_zero3_enabled():
            import deepspeed

...@@ -1127,23 +1160,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
            # this immediately partitions the model across all gpus, to avoid the overhead in time
            # and memory copying it on CPU or each GPU first
            with deepspeed.zero.Init(config=deepspeed_config()):
                with no_init_weights(_enable=_fast_init):
                    model = cls(config, *model_args, **model_kwargs)
        else:
            with no_init_weights(_enable=_fast_init):
                model = cls(config, *model_args, **model_kwargs)
        if from_tf:
            if resolved_archive_file.endswith(".index"):

...@@ -1173,102 +1194,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                )
                raise
        else:
            if state_dict is None:
                try:
                    state_dict = torch.load(resolved_archive_file, map_location="cpu")
                except Exception:
                    raise OSError(
                        f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
                        f"at '{resolved_archive_file}'"
                        "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
                    )

            model, missing_keys, unexpected_keys, error_msgs = cls._load_state_dict_into_model(
                model, state_dict, pretrained_model_name_or_path
            )
        # make sure token embedding weights are still tied if needed
        model.tie_weights()

...@@ -1285,6 +1224,142 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        return model
    @classmethod
    def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path):

        # Convert old format to new format if needed from a PyTorch state_dict
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if "gamma" in key:
                new_key = key.replace("gamma", "weight")
            if "beta" in key:
                new_key = key.replace("beta", "bias")
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)
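        # Illustrative example (hypothetical keys): an old-format checkpoint entry
        # "encoder.layer.0.LayerNorm.gamma" becomes "encoder.layer.0.LayerNorm.weight",
        # and ".beta" becomes ".bias", before any loading happens.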
        # Retrieve missing & unexpected_keys
        expected_keys = list(model.state_dict().keys())
        loaded_keys = list(state_dict.keys())
        prefix = model.base_model_prefix

        has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
        expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)

        remove_prefix = not has_prefix_module and expects_prefix_module
        add_prefix = has_prefix_module and not expects_prefix_module

        if remove_prefix:
            expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
        elif add_prefix:
            expected_keys = [".".join([prefix, s]) for s in expected_keys]

        missing_keys = list(set(expected_keys) - set(loaded_keys))
        unexpected_keys = list(set(loaded_keys) - set(expected_keys))
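        # Illustrative (hypothetical) example of the prefix handling above: loading a
        # `BertModel` checkpoint (keys like "embeddings.word_embeddings.weight") into a
        # `BertForSequenceClassification` (keys like "bert.embeddings...") sets
        # `remove_prefix`, so both key sets use the same naming scheme before the
        # missing/unexpected sets are computed.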
        # Some models may have keys that are not in the state by design, removing them before needlessly warning
        # the user.
        if cls._keys_to_ignore_on_load_missing is not None:
            for pat in cls._keys_to_ignore_on_load_missing:
                missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

        if cls._keys_to_ignore_on_load_unexpected is not None:
            for pat in cls._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

        # Retrieve the uninitialized modules and initialize them. This is the core of fast
        # init: only modules with at least one missing weight get randomly initialized.
        uninitialized_modules = model.retrieve_modules_from_names(
            missing_keys, add_prefix=add_prefix, remove_prefix=remove_prefix
        )
        for module in uninitialized_modules:
            model._init_weights(module)
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        error_msgs = []
        # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
        # so we need to apply the function recursively.
        def load(module: nn.Module, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
            if is_deepspeed_zero3_enabled():
                import deepspeed

                # because zero3 puts placeholders in model params, this context
                # manager gathers (unpartitions) the params of the current layer, then loads from
                # the state dict and then re-partitions them again
                with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
                    if torch.distributed.get_rank() == 0:
                        module._load_from_state_dict(*args)
            else:
                module._load_from_state_dict(*args)

            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")
        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ""
        model_to_load = model
        if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
            start_prefix = cls.base_model_prefix + "."
        if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
            model_to_load = getattr(model, cls.base_model_prefix)

        load(model_to_load, prefix=start_prefix)
        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        else:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )
        if len(error_msgs) > 0:
            error_msg = "\n\t".join(error_msgs)
            raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

        return model, missing_keys, unexpected_keys, error_msgs
    def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
        module_keys = set([".".join(key.split(".")[:-1]) for key in names])

        retrieved_modules = []
        # retrieve all modules that have at least one missing weight name
        for name, module in self.named_modules():
            if remove_prefix:
                name = ".".join(name.split(".")[1:]) if name.startswith(self.base_model_prefix) else name
            elif add_prefix:
                name = ".".join([self.base_model_prefix, name])

            if name in module_keys:
                retrieved_modules.append(module)

        return retrieved_modules
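Taken together, a minimal sketch of the name-to-module mapping these helpers rely on; `ToyHeadModel` and the key names are illustrative, not from the diff:

import torch.nn as nn

# hypothetical model whose classifier head is absent from the checkpoint
class ToyHeadModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        self.classifier = nn.Linear(4, 2)

model = ToyHeadModel()
missing_keys = ["classifier.weight", "classifier.bias"]

# map weight names to the names of their owning modules, as retrieve_modules_from_names does
module_keys = set(".".join(k.split(".")[:-1]) for k in missing_keys)
to_init = [m for n, m in model.named_modules() if n in module_keys]
assert to_init == [model.classifier]  # only the head would be (re)initialized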
class Conv1D(nn.Module):
    """

...@@ -177,6 +177,103 @@ class ModelTesterMixin:
        for k in _keys_to_ignore_on_save:
            self.assertNotIn(k, state_dict_saved)
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)
    def test_save_load_fast_init_from_base(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = MODEL_MAPPING[config.__class__]

        if isinstance(base_class, tuple):
            base_class = base_class[0]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            # make a copy of model class to not break future tests
            # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
            class CopyClass(model_class):
                pass

            model_class_copy = CopyClass

            # make sure that all keys are expected for test
            model_class_copy._keys_to_ignore_on_load_missing = []

            # make init deterministic, but make sure that
            # non-initialized weights throw errors nevertheless
            model_class_copy._init_weights = self._mock_init_weights

            model = base_class(config)
            state_dict = model.state_dict()

            # this will often delete a single weight of a multi-weight module
            # to test an edge case
            random_key_to_del = random.choice(list(state_dict.keys()))
            del state_dict[random_key_to_del]

            # save the base model and the pruned state dict, then reload the head model
            # with and without fast init and compare the resulting weights
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))

                model_fast_init = model_class_copy.from_pretrained(tmpdirname)
                model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)

                for key in model_fast_init.state_dict().keys():
                    max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
    def test_save_load_fast_init_to_base(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = MODEL_MAPPING[config.__class__]

        if isinstance(base_class, tuple):
            base_class = base_class[0]

        for model_class in self.all_model_classes:
            if model_class == base_class:
                continue

            # make a copy of model class to not break future tests
            # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
            class CopyClass(base_class):
                pass

            base_class_copy = CopyClass

            # make sure that all keys are expected for test
            base_class_copy._keys_to_ignore_on_load_missing = []

            # make init deterministic, but make sure that
            # non-initialized weights throw errors nevertheless
            base_class_copy._init_weights = self._mock_init_weights

            model = model_class(config)
            state_dict = model.state_dict()

            # this will often delete a single weight of a multi-weight module
            # to test an edge case
            random_key_to_del = random.choice(list(state_dict.keys()))
            del state_dict[random_key_to_del]

            # save the config and the pruned head-model state dict, then reload the base model
            # with and without fast init and compare the resulting weights
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.config.save_pretrained(tmpdirname)
                torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))

                model_fast_init = base_class_copy.from_pretrained(tmpdirname)
                model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False)

                for key in model_fast_init.state_dict().keys():
                    max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

...@@ -400,6 +400,18 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
    # overwrite from test_modeling_common
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)

        for param in ["r_w_bias", "r_r_bias", "r_kernel", "r_s_bias", "seg_embed"]:
            if hasattr(module, param) and getattr(module, param) is not None:
                weight = getattr(module, param)
                weight.data.fill_(3)
@require_torch
class FunnelBaseModelTest(ModelTesterMixin, unittest.TestCase):

...@@ -443,6 +455,18 @@ class FunnelBaseModelTest(ModelTesterMixin, unittest.TestCase):
        loss = model(**inputs).loss
        loss.backward()
    # overwrite from test_modeling_common
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)

        for param in ["r_w_bias", "r_r_bias", "r_kernel", "r_s_bias", "seg_embed"]:
            if hasattr(module, param) and getattr(module, param) is not None:
                weight = getattr(module, param)
                weight.data.fill_(3)
@require_torch
@require_sentencepiece

...@@ -348,6 +348,31 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
            [expected_shape] * len(iter_hidden_states),
        )
    # overwrite from test_modeling_common
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "cluster_weight") and module.cluster_weight is not None:
            module.cluster_weight.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)
        if hasattr(module, "cluster_bias") and module.cluster_bias is not None:
            module.cluster_bias.data.fill_(3)

        if hasattr(module, "emb_projs"):
            for i in range(len(module.emb_projs)):
                if module.emb_projs[i] is not None:
                    torch.nn.init.constant_(module.emb_projs[i], 0.0003)
        if hasattr(module, "out_projs"):
            for i in range(len(module.out_projs)):
                if module.out_projs[i] is not None:
                    torch.nn.init.constant_(module.out_projs[i], 0.0003)

        for param in ["r_emb", "r_w_bias", "r_r_bias", "r_bias"]:
            if hasattr(module, param) and getattr(module, param) is not None:
                weight = getattr(module, param)
                weight.data.fill_(3)
@require_torch
class TransfoXLModelLanguageGenerationTest(unittest.TestCase):

...@@ -329,6 +329,15 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
                    msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                )
    # overwrite from test_modeling_common
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "weight_g") and module.weight_g is not None:
            module.weight_g.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)
    @slow
    def test_model_from_pretrained(self):
        model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

...@@ -446,6 +455,15 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
                    msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                )
    # overwrite from test_modeling_common
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "weight_g") and module.weight_g is not None:
            module.weight_g.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)
    @slow
    def test_model_from_pretrained(self):
        model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

...@@ -594,6 +594,18 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
        # xlnet cannot keep gradients in attentions or hidden states
        return
    # overwrite from test_modeling_common
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)

        for param in ["q", "k", "v", "o", "r", "r_r_bias", "r_s_bias", "r_w_bias", "seg_embed", "mask_emb"]:
            if hasattr(module, param) and getattr(module, param) is not None:
                weight = getattr(module, param)
                weight.data.fill_(3)
    def _check_hidden_states_for_generate(
        self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
    ):