"docs/source/vscode:/vscode.git/clone" did not exist on "33da2db5eafc83c9fa4e5ae9abd2f5636a3c6616"
Unverified Commit 9d98706b authored by Yih-Dar, committed by GitHub

Fix failed tests in #31851 (#31879)

* Revert "Revert "Fix `_init_weights` for `ResNetPreTrainedModel`" (#31868)"

This reverts commit b45dd5de.

* fix

* [test_all] check

* fix

* [test_all] check

* fix

* [test_all] check

* fix

* [test_all] check

* fix

* [test_all] check

* fix

* [test_all] check

* fix

* [test_all] check

* fix

* [test_all] check

* fix

* [test_all] check

* fix

* [test_all] check

* fix

* [test_all] check

* fix

* [test_all] check

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent a0a3e2f4
@@ -660,6 +660,13 @@ class BitPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         if isinstance(module, nn.Conv2d):
             nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
+        # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
+        elif isinstance(module, nn.Linear):
+            nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
+            if module.bias is not None:
+                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
+                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+                nn.init.uniform_(module.bias, -bound, bound)
         elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
             nn.init.constant_(module.weight, 1)
             nn.init.constant_(module.bias, 0)
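For context, the branch added above mirrors the default initialization that `torch.nn.Linear.reset_parameters` applies. Below is a minimal standalone sketch of that scheme; it is not part of the diff, and the layer sizes are made up for illustration:

import math

from torch import nn

# A toy layer, purely for illustration.
linear = nn.Linear(512, 10)

# Kaiming-uniform weight init, as in torch's Linear.reset_parameters.
nn.init.kaiming_uniform_(linear.weight, a=math.sqrt(5))

# The bias bound is derived from the weight's fan-in: uniform in [-1/sqrt(fan_in), 1/sqrt(fan_in)].
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(linear.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(linear.bias, -bound, bound)

# The resulting parameter means are small but generally not exactly 0.0 or 1.0, which is why
# the test change further down relaxes its tolerance for these models.
print(linear.weight.mean().item(), linear.bias.mean().item())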
@@ -14,6 +14,7 @@
 # limitations under the License.
 """PyTorch RegNet model."""
 
+import math
 from typing import Optional
 
 import torch
@@ -284,6 +285,13 @@ class RegNetPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         if isinstance(module, nn.Conv2d):
             nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
+        # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
+        elif isinstance(module, nn.Linear):
+            nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
+            if module.bias is not None:
+                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
+                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+                nn.init.uniform_(module.bias, -bound, bound)
         elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
             nn.init.constant_(module.weight, 1)
             nn.init.constant_(module.bias, 0)
@@ -14,6 +14,7 @@
 # limitations under the License.
 """PyTorch ResNet model."""
 
+import math
 from typing import Optional
 
 import torch
@@ -274,6 +275,13 @@ class ResNetPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         if isinstance(module, nn.Conv2d):
             nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
+        # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
+        elif isinstance(module, nn.Linear):
+            nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
+            if module.bias is not None:
+                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
+                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+                nn.init.uniform_(module.bias, -bound, bound)
         elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
             nn.init.constant_(module.weight, 1)
             nn.init.constant_(module.bias, 0)
@@ -17,6 +17,7 @@ PyTorch RTDetr specific ResNet model. The main difference between hugginface Res
 See https://github.com/lyuwenyu/RT-DETR/blob/5b628eaa0a2fc25bdafec7e6148d5296b144af85/rtdetr_pytorch/src/nn/backbone/presnet.py#L126 for details.
 """
 
+import math
 from typing import Optional
 
 from torch import Tensor, nn
@@ -323,6 +324,13 @@ class RTDetrResNetPreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         if isinstance(module, nn.Conv2d):
             nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
+        # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
+        elif isinstance(module, nn.Linear):
+            nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
+            if module.bias is not None:
+                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
+                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+                nn.init.uniform_(module.bias, -bound, bound)
         elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
             nn.init.constant_(module.weight, 1)
             nn.init.constant_(module.bias, 0)
@@ -3167,9 +3167,68 @@ class ModelTesterMixin:
         configs_no_init = _config_zero_init(config)
         for model_class in self.all_model_classes:
-            if model_class.__name__ not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
+            mappings = [
+                MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
+                MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+                MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
+                MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES,
+            ]
+            is_classication_model = any(model_class.__name__ in get_values(mapping) for mapping in mappings)
+            if not is_classication_model:
                 continue
+
+            # TODO: ydshieh
+            is_special_classes = model_class.__name__ in [
+                "wav2vec2.masked_spec_embed",
+                "Wav2Vec2ForSequenceClassification",
+                "CLIPForImageClassification",
+                "RegNetForImageClassification",
+                "ResNetForImageClassification",
+                "UniSpeechSatForSequenceClassification",
+                "Wav2Vec2BertForSequenceClassification",
+                "PvtV2ForImageClassification",
+                "Wav2Vec2ConformerForSequenceClassification",
+                "WavLMForSequenceClassification",
+                "SwiftFormerForImageClassification",
+                "SEWForSequenceClassification",
+                "BitForImageClassification",
+                "SEWDForSequenceClassification",
+                "SiglipForImageClassification",
+                "HubertForSequenceClassification",
+                "Swinv2ForImageClassification",
+                "Data2VecAudioForSequenceClassification",
+                "UniSpeechForSequenceClassification",
+                "PvtForImageClassification",
+            ]
+            special_param_names = [
+                r"^bit\.",
+                r"^classifier\.weight",
+                r"^classifier\.bias",
+                r"^classifier\..+\.weight",
+                r"^classifier\..+\.bias",
+                r"^data2vec_audio\.",
+                r"^dist_head\.",
+                r"^head\.",
+                r"^hubert\.",
+                r"^pvt\.",
+                r"^pvt_v2\.",
+                r"^regnet\.",
+                r"^resnet\.",
+                r"^sew\.",
+                r"^sew_d\.",
+                r"^swiftformer\.",
+                r"^swinv2\.",
+                r"^transformers\.models\.swiftformer\.",
+                r"^unispeech\.",
+                r"^unispeech_sat\.",
+                r"^vision_model\.",
+                r"^wav2vec2\.",
+                r"^wav2vec2_bert\.",
+                r"^wav2vec2_conformer\.",
+                r"^wavlm\.",
+            ]
+
             with self.subTest(msg=f"Testing {model_class}"):
                 with tempfile.TemporaryDirectory() as tmp_dir:
                     model = model_class(configs_no_init)
@@ -3177,23 +3236,41 @@ class ModelTesterMixin:
                     # Fails when we don't set ignore_mismatched_sizes=True
                     with self.assertRaises(RuntimeError):
-                        new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
+                        new_model = model_class.from_pretrained(tmp_dir, num_labels=42)
                     logger = logging.get_logger("transformers.modeling_utils")
                     with CaptureLogger(logger) as cl:
-                        new_model = AutoModelForSequenceClassification.from_pretrained(
-                            tmp_dir, num_labels=42, ignore_mismatched_sizes=True
-                        )
+                        new_model = model_class.from_pretrained(tmp_dir, num_labels=42, ignore_mismatched_sizes=True)
                     self.assertIn("the shapes did not match", cl.out)
                     for name, param in new_model.named_parameters():
                         if param.requires_grad:
-                            self.assertIn(
-                                ((param.data.mean() * 1e9).round() / 1e9).item(),
-                                [0.0, 1.0],
-                                msg=f"Parameter {name} of model {model_class} seems not properly initialized",
-                            )
+                            param_mean = ((param.data.mean() * 1e9).round() / 1e9).item()
+                            if not (
+                                is_special_classes
+                                and any(len(re.findall(target, name)) > 0 for target in special_param_names)
+                            ):
+                                self.assertIn(
+                                    param_mean,
+                                    [0.0, 1.0],
+                                    msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                                )
+                            else:
+                                # Here we allow the parameters' mean to be in the range [-5.0, 5.0] instead of being
+                                # either `0.0` or `1.0`, because their initializations are not using
+                                # `config.initializer_factor` (or something similar). The purpose of this test is simply
+                                # to make sure they are properly initialized (to avoid very large value or even `nan`).
+                                self.assertGreaterEqual(
+                                    param_mean,
+                                    -5.0,
+                                    msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                                )
+                                self.assertLessEqual(
+                                    param_mean,
+                                    5.0,
+                                    msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                                )
 
     def test_matched_shapes_have_loaded_weights_when_some_mismatched_shapes_exist(self):
         # 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
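The relaxed branch above can be hard to follow inside the diff, so here is a small standalone sketch, not part of the commit, that mimics the test's parameter-mean check on a plain nn.Linear head; the regex patterns and module names are illustrative stand-ins for the real `special_param_names` and model classes:

import re

from torch import nn

# Illustrative stand-ins for the test's special-case patterns.
special_param_names = [r"^classifier\.weight", r"^classifier\.bias"]

# A hypothetical classification head initialized with torch's default Linear scheme.
model = nn.Sequential()
model.add_module("classifier", nn.Linear(2048, 1000))

for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    # Same rounding the test uses to turn the mean into a stable scalar.
    param_mean = ((param.data.mean() * 1e9).round() / 1e9).item()
    if any(re.findall(pattern, name) for pattern in special_param_names):
        # Relaxed check: only require the mean to stay in [-5.0, 5.0] (no blow-ups, no NaN).
        assert -5.0 <= param_mean <= 5.0, f"{name} seems not properly initialized"
    else:
        # Strict check: with a zero-init config, means should be exactly 0.0 or 1.0.
        assert param_mean in (0.0, 1.0), f"{name} seems not properly initialized"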