"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d4e92f1a21c0e4ca1721721e4c7e7a0c32439d64"
Unverified Commit 13b3b90a authored by Pavel Iakubovskii, committed by GitHub

Fix DETA save_pretrained (#30326)

* Add class_embed to tied weights for DETA

* Fix test_tied_weights_keys for DETA model

* Replace error raise with assert statement
parent 6c7335e0
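For context, a minimal sketch of the mechanism in plain PyTorch (not the transformers save/load code; the `SharedHeads` module and its field names below are hypothetical). With `two_stage=True`, `DetaForObjectDetection` exposes the same `class_embed` heads both directly and through the decoder, so their parameters appear under two `state_dict` keys backed by one storage. `_tied_weights_keys` must contain a pattern matching every such group so that `save_pretrained` treats the duplicates as intentional ties when it deduplicates shared tensors:

```python
import re
from collections import defaultdict

from torch import nn


class SharedHeads(nn.Module):
    """Hypothetical stand-in: the same heads are reachable both directly and
    through a nested `decoder` module, as in DETA's two-stage setup."""

    def __init__(self):
        super().__init__()
        self.class_embed = nn.ModuleList([nn.Linear(8, 4) for _ in range(2)])
        self.decoder = nn.Module()
        self.decoder.class_embed = self.class_embed  # tied: same parameters, second name


model = SharedHeads()

# Group state_dict keys by data pointer, much like the DETA test below does
# with transformers.pytorch_utils.id_tensor_storage.
ptrs = defaultdict(list)
for name, tensor in model.state_dict().items():
    ptrs[tensor.data_ptr()].append(name)
tied_groups = [names for names in ptrs.values() if len(names) > 1]
print(tied_groups)
# e.g. [['class_embed.0.weight', 'decoder.class_embed.0.weight'], ...]

# Every tied group has to be covered by some pattern in `_tied_weights_keys`,
# which is what adding r"class_embed\.\d+" does for DETA.
pattern = r"class_embed\.\d+"
assert all(any(re.search(pattern, name) for name in group) for group in tied_groups)
```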
modeling_deta.py

@@ -1888,7 +1888,7 @@ class DetaModel(DetaPreTrainedModel):
 )
 class DetaForObjectDetection(DetaPreTrainedModel):
     # When using clones, all layers > 0 will be clones, but layer 0 *is* required
-    _tied_weights_keys = [r"bbox_embed\.\d+"]
+    _tied_weights_keys = [r"bbox_embed\.\d+", r"class_embed\.\d+"]
     # We can't initialize the model on meta device as some weights are modified during the initialization
     _no_split_modules = None
test_modeling_deta.py

@@ -15,8 +15,10 @@
 """ Testing suite for the PyTorch DETA model. """
 
+import collections
 import inspect
 import math
+import re
 import unittest
 
 from transformers import DetaConfig, ResNetConfig, is_torch_available, is_torchvision_available, is_vision_available
@@ -32,6 +34,8 @@ from ...test_pipeline_mixin import PipelineTesterMixin
 if is_torch_available():
     import torch
 
+    from transformers.pytorch_utils import id_tensor_storage
+
 if is_torchvision_available():
     from transformers import DetaForObjectDetection, DetaModel
@@ -520,6 +524,43 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
                         msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                     )
 
+    # Inspired by tests.test_modeling_common.ModelTesterMixin.test_tied_weights_keys
+    def test_tied_weights_keys(self):
+        for model_class in self.all_model_classes:
+            # We need to pass the model class name to correctly initialize the config.
+            # If we don't pass it, the config for `DetaForObjectDetection` will be initialized
+            # with `two_stage=False` and the test will fail because for that case `class_embed`
+            # weights are not tied.
+            config, _ = self.model_tester.prepare_config_and_inputs_for_common(model_class_name=model_class.__name__)
+            config.tie_word_embeddings = True
+
+            model_tied = model_class(config)
+
+            ptrs = collections.defaultdict(list)
+            for name, tensor in model_tied.state_dict().items():
+                ptrs[id_tensor_storage(tensor)].append(name)
+
+            # These are all the pointers of shared tensors.
+            tied_params = [names for _, names in ptrs.items() if len(names) > 1]
+
+            tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
+            # Detect we get a hit for each key
+            for key in tied_weight_keys:
+                is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
+                self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
+
+            # Removed tied weights found from tied params -> there should only be one left after
+            for key in tied_weight_keys:
+                for i in range(len(tied_params)):
+                    tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
+
+            tied_params = [group for group in tied_params if len(group) > 1]
+            self.assertListEqual(
+                tied_params,
+                [],
+                f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
+            )
 
 
 TOLERANCE = 1e-4
test_modeling_common.py

@@ -2025,8 +2025,8 @@ class ModelTesterMixin:
             tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
             # Detect we get a hit for each key
             for key in tied_weight_keys:
-                if not any(re.search(key, p) for group in tied_params for p in group):
-                    raise ValueError(f"{key} is not a tied weight key for {model_class}.")
+                is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
+                self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
 
             # Removed tied weights found from tied params -> there should only be one left after
             for key in tied_weight_keys:
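With both patterns declared, saving and reloading a DETA model should round-trip cleanly. A usage sketch, not taken from this commit (it assumes torch and torchvision are installed, and it builds a full-sized, randomly initialized model from the default config):

```python
import tempfile

from transformers import DetaConfig, DetaForObjectDetection

# two_stage=True / with_box_refine=True is the configuration in which the
# class_embed and bbox_embed heads are shared with the decoder.
config = DetaConfig(two_stage=True, with_box_refine=True)
model = DetaForObjectDetection(config)

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)  # shared head tensors are declared via _tied_weights_keys
    reloaded = DetaForObjectDetection.from_pretrained(tmp_dir)
```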