"...py_test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "96a5e4dd795b675210b0d18f5e9fab69ec69bb6e"
Unverified Commit 13b3b90a authored by Pavel Iakubovskii's avatar Pavel Iakubovskii Committed by GitHub
Browse files

Fix DETA save_pretrained (#30326)

* Add class_embed to tied weights for DETA

* Fix test_tied_weights_keys for DETA model

* Replace error raise with assert statement
parent 6c7335e0
......@@ -1888,7 +1888,7 @@ class DetaModel(DetaPreTrainedModel):
)
class DetaForObjectDetection(DetaPreTrainedModel):
# When using clones, all layers > 0 will be clones, but layer 0 *is* required
_tied_weights_keys = [r"bbox_embed\.\d+"]
_tied_weights_keys = [r"bbox_embed\.\d+", r"class_embed\.\d+"]
# We can't initialize the model on meta device as some weights are modified during the initialization
_no_split_modules = None
......
......@@ -15,8 +15,10 @@
""" Testing suite for the PyTorch DETA model. """
import collections
import inspect
import math
import re
import unittest
from transformers import DetaConfig, ResNetConfig, is_torch_available, is_torchvision_available, is_vision_available
......@@ -32,6 +34,8 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers.pytorch_utils import id_tensor_storage
if is_torchvision_available():
from transformers import DetaForObjectDetection, DetaModel
......@@ -520,6 +524,43 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
# Inspired by tests.test_modeling_common.ModelTesterMixin.test_tied_weights_keys
def test_tied_weights_keys(self):
for model_class in self.all_model_classes:
# We need to pass model class name to correctly initialize the config.
# If we don't pass it, the config for `DetaForObjectDetection`` will be initialized
# with `two_stage=False` and the test will fail because for that case `class_embed`
# weights are not tied.
config, _ = self.model_tester.prepare_config_and_inputs_for_common(model_class_name=model_class.__name__)
config.tie_word_embeddings = True
model_tied = model_class(config)
ptrs = collections.defaultdict(list)
for name, tensor in model_tied.state_dict().items():
ptrs[id_tensor_storage(tensor)].append(name)
# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
# Detect we get a hit for each key
for key in tied_weight_keys:
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
# Removed tied weights found from tied params -> there should only be one left after
for key in tied_weight_keys:
for i in range(len(tied_params)):
tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
tied_params = [group for group in tied_params if len(group) > 1]
self.assertListEqual(
tied_params,
[],
f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
)
TOLERANCE = 1e-4
......
......@@ -2025,8 +2025,8 @@ class ModelTesterMixin:
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
# Detect we get a hit for each key
for key in tied_weight_keys:
if not any(re.search(key, p) for group in tied_params for p in group):
raise ValueError(f"{key} is not a tied weight key for {model_class}.")
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
# Removed tied weights found from tied params -> there should only be one left after
for key in tied_weight_keys:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment