Unverified Commit 90e8263d authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Add methods to update and verify out_features out_indices (#23031)

* Add methods to update and verify out_features out_indices

* Safe update for config attributes

* Fix function names

* Save config correctly

* PR comments - use property setters

* PR comment - directly set attributes

* Update test

* Add updates to recently merged focalnet backbone
parent 78b7debf
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Collection of utils to be used by backbones and their components."""
import inspect
from typing import Iterable, List, Optional, Tuple, Union
def verify_out_features_out_indices(
out_features: Optional[Iterable[str]], out_indices: Optional[Iterable[int]], stage_names: Optional[Iterable[str]]
):
"""
Verify that out_indices and out_features are valid for the given stage_names.
"""
if stage_names is None:
raise ValueError("Stage_names must be set for transformers backbones")
if out_features is not None:
if not isinstance(out_features, (list,)):
raise ValueError(f"out_features must be a list {type(out_features)}")
if any(feat not in stage_names for feat in out_features):
raise ValueError(f"out_features must be a subset of stage_names: {stage_names} got {out_features}")
if out_indices is not None:
if not isinstance(out_indices, (list, tuple)):
raise ValueError(f"out_indices must be a list or tuple, got {type(out_indices)}")
if any(idx >= len(stage_names) for idx in out_indices):
raise ValueError("out_indices must be valid indices for stage_names {stage_names}, got {out_indices}")
if out_features is not None and out_indices is not None:
if len(out_features) != len(out_indices):
raise ValueError("out_features and out_indices should have the same length if both are set")
if out_features != [stage_names[idx] for idx in out_indices]:
raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
def _align_output_features_output_indices(
out_features: Optional[List[str]],
out_indices: Optional[Union[List[int], Tuple[int]]],
stage_names: List[str],
):
"""
Finds the corresponding `out_features` and `out_indices` for the given `stage_names`.
The logic is as follows:
- `out_features` not set, `out_indices` set: `out_features` is set to the `out_features` corresponding to the
`out_indices`.
- `out_indices` not set, `out_features` set: `out_indices` is set to the `out_indices` corresponding to the
`out_features`.
- `out_indices` and `out_features` not set: `out_indices` and `out_features` are set to the last stage.
- `out_indices` and `out_features` set: input `out_indices` and `out_features` are returned.
Args:
out_features (`List[str]`): The names of the features for the backbone to output.
out_indices (`List[int]` or `Tuple[int]`): The indices of the features for the backbone to output.
stage_names (`List[str]`): The names of the stages of the backbone.
"""
if out_indices is None and out_features is None:
out_indices = [len(stage_names) - 1]
out_features = [stage_names[-1]]
elif out_indices is None and out_features is not None:
out_indices = [stage_names.index(layer) for layer in stage_names if layer in out_features]
elif out_features is None and out_indices is not None:
out_features = [stage_names[idx] for idx in out_indices]
return out_features, out_indices
def get_aligned_output_features_output_indices(
out_features: Optional[List[str]],
out_indices: Optional[Union[List[int], Tuple[int]]],
stage_names: List[str],
) -> Tuple[List[str], List[int]]:
"""
Get the `out_features` and `out_indices` so that they are aligned.
The logic is as follows:
- `out_features` not set, `out_indices` set: `out_features` is set to the `out_features` corresponding to the
`out_indices`.
- `out_indices` not set, `out_features` set: `out_indices` is set to the `out_indices` corresponding to the
`out_features`.
- `out_indices` and `out_features` not set: `out_indices` and `out_features` are set to the last stage.
- `out_indices` and `out_features` set: they are verified to be aligned.
Args:
out_features (`List[str]`): The names of the features for the backbone to output.
out_indices (`List[int]` or `Tuple[int]`): The indices of the features for the backbone to output.
stage_names (`List[str]`): The names of the stages of the backbone.
"""
# First verify that the out_features and out_indices are valid
verify_out_features_out_indices(out_features=out_features, out_indices=out_indices, stage_names=stage_names)
output_features, output_indices = _align_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=stage_names
)
# Verify that the aligned out_features and out_indices are valid
verify_out_features_out_indices(out_features=output_features, out_indices=output_indices, stage_names=stage_names)
return output_features, output_indices
class BackboneMixin:
@property
def out_feature_channels(self):
# the current backbones will output the number of channels for each stage
# even if that stage is not in the out_features list.
return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}
@property
def channels(self):
return [self.out_feature_channels[name] for name in self.out_features]
def forward_with_filtered_kwargs(self, *args, **kwargs):
signature = dict(inspect.signature(self.forward).parameters)
filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
return self(*args, **filtered_kwargs)
def forward(
self,
pixel_values,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
raise NotImplementedError("This method should be implemented by the derived class.")
@property
def out_features(self):
return self._out_features
@out_features.setter
def out_features(self, out_features: List[str]):
"""
Set the out_features attribute. This will also update the out_indices attribute to match the new out_features.
"""
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=None, stage_names=self.stage_names
)
@property
def out_indices(self):
return self._out_indices
@out_indices.setter
def out_indices(self, out_indices: Union[Tuple[int], List[int]]):
"""
Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices.
"""
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=None, out_indices=out_indices, stage_names=self.stage_names
)
class BackboneConfigMixin:
"""
A Mixin to support handling the `out_features` and `out_indices` attributes for the backbone configurations.
"""
@property
def out_features(self):
return self._out_features
@out_features.setter
def out_features(self, out_features: List[str]):
"""
Set the out_features attribute. This will also update the out_indices attribute to match the new out_features.
"""
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=None, stage_names=self.stage_names
)
@property
def out_indices(self):
return self._out_indices
@out_indices.setter
def out_indices(self, out_indices: Union[Tuple[int], List[int]]):
"""
Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices.
"""
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=None, out_indices=out_indices, stage_names=self.stage_names
)
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig` to
include the `out_features` and `out_indices` attributes.
"""
output = super().to_dict()
output["out_features"] = output.pop("_out_features")
output["out_indices"] = output.pop("_out_indices")
return output
......@@ -81,9 +81,15 @@ class BackboneTesterMixin:
out_channels = [num_features[idx] for idx in out_indices]
self.assertListEqual(model.channels, out_channels)
config.out_features = None
config.out_indices = None
model = model_class(config)
new_config = copy.deepcopy(config)
new_config.out_features = None
model = model_class(new_config)
self.assertEqual(len(model.channels), 1)
self.assertListEqual(model.channels, [num_features[-1]])
new_config = copy.deepcopy(config)
new_config.out_indices = None
model = model_class(new_config)
self.assertEqual(len(model.channels), 1)
self.assertListEqual(model.channels, [num_features[-1]])
......@@ -102,6 +108,15 @@ class BackboneTesterMixin:
# Check output of last stage is taken if out_features=None, out_indices=None
modified_config = copy.deepcopy(config)
modified_config.out_features = None
model = model_class(modified_config)
model.to(torch_device)
model.eval()
result = model(**inputs_dict)
self.assertEqual(len(result.feature_maps), 1)
self.assertEqual(len(model.channels), 1)
modified_config = copy.deepcopy(config)
modified_config.out_indices = None
model = model_class(modified_config)
model.to(torch_device)
......
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.utils.backbone_utils import (
BackboneMixin,
get_aligned_output_features_output_indices,
verify_out_features_out_indices,
)
class BackboneUtilsTester(unittest.TestCase):
def test_get_aligned_output_features_output_indices(self):
stage_names = ["a", "b", "c"]
# Defaults to last layer if both are None
out_features, out_indices = get_aligned_output_features_output_indices(None, None, stage_names)
self.assertEqual(out_features, ["c"])
self.assertEqual(out_indices, [2])
# Out indices set to match out features
out_features, out_indices = get_aligned_output_features_output_indices(["a", "c"], None, stage_names)
self.assertEqual(out_features, ["a", "c"])
self.assertEqual(out_indices, [0, 2])
# Out features set to match out indices
out_features, out_indices = get_aligned_output_features_output_indices(None, [0, 2], stage_names)
self.assertEqual(out_features, ["a", "c"])
self.assertEqual(out_indices, [0, 2])
# Out features selected from negative indices
out_features, out_indices = get_aligned_output_features_output_indices(None, [-3, -1], stage_names)
self.assertEqual(out_features, ["a", "c"])
self.assertEqual(out_indices, [-3, -1])
def test_verify_out_features_out_indices(self):
# Stage names must be set
with self.assertRaises(ValueError):
verify_out_features_out_indices(["a", "b"], (0, 1), None)
# Out features must be a list
with self.assertRaises(ValueError):
verify_out_features_out_indices(("a", "b"), (0, 1), ["a", "b"])
# Out features must be a subset of stage names
with self.assertRaises(ValueError):
verify_out_features_out_indices(["a", "b"], (0, 1), ["a"])
# Out indices must be a list or tuple
with self.assertRaises(ValueError):
verify_out_features_out_indices(None, 0, ["a", "b"])
# Out indices must be a subset of stage names
with self.assertRaises(ValueError):
verify_out_features_out_indices(None, (0, 1), ["a"])
# Out features and out indices must be the same length
with self.assertRaises(ValueError):
verify_out_features_out_indices(["a", "b"], (0,), ["a", "b", "c"])
# Out features should match out indices
with self.assertRaises(ValueError):
verify_out_features_out_indices(["a", "b"], (0, 2), ["a", "b", "c"])
# Out features and out indices should be in order
with self.assertRaises(ValueError):
verify_out_features_out_indices(["b", "a"], (0, 1), ["a", "b"])
# Check passes with valid inputs
verify_out_features_out_indices(["a", "b", "d"], (0, 1, -1), ["a", "b", "c", "d"])
def test_backbone_mixin(self):
backbone = BackboneMixin()
backbone.stage_names = ["a", "b", "c"]
backbone._out_features = ["a", "c"]
backbone._out_indices = [0, 2]
# Check that the output features and indices are set correctly
self.assertEqual(backbone.out_features, ["a", "c"])
self.assertEqual(backbone.out_indices, [0, 2])
# Check out features and indices are updated correctly
backbone.out_features = ["a", "b"]
self.assertEqual(backbone.out_features, ["a", "b"])
self.assertEqual(backbone.out_indices, [0, 1])
backbone.out_indices = [-3, -1]
self.assertEqual(backbone.out_features, ["a", "c"])
self.assertEqual(backbone.out_indices, [-3, -1])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment