Unverified Commit ee298a16 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Align backbone stage selection with out_indices & out_features (#27606)

* Iteratre over out_features instead of stage_names

* Update for all backbones

* Add tests

* Fix

* Align timm backbone behaviour with other backbones

* Fix tests

* Stricter checks on set out_features and out_indices

* Revert back stage selection logic

* Remove out-of-order logic

* Document restriction in docstrings
parent 224ab709
...@@ -102,11 +102,13 @@ class BeitConfig(BackboneConfigMixin, PretrainedConfig): ...@@ -102,11 +102,13 @@ class BeitConfig(BackboneConfigMixin, PretrainedConfig):
out_features (`List[str]`, *optional*): out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
corresponding stages. If unset and `out_indices` is unset, will default to the last stage. corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
out_indices (`List[int]`, *optional*): out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
If unset and `out_features` is unset, will default to the last stage. If unset and `out_features` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
add_fpn (`bool`, *optional*, defaults to `False`): add_fpn (`bool`, *optional*, defaults to `False`):
Whether to add a FPN as part of the backbone. Only relevant for [`BeitBackbone`]. Whether to add a FPN as part of the backbone. Only relevant for [`BeitBackbone`].
reshape_hidden_states (`bool`, *optional*, defaults to `True`): reshape_hidden_states (`bool`, *optional*, defaults to `True`):
......
...@@ -65,11 +65,13 @@ class BitConfig(BackboneConfigMixin, PretrainedConfig): ...@@ -65,11 +65,13 @@ class BitConfig(BackboneConfigMixin, PretrainedConfig):
out_features (`List[str]`, *optional*): out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
corresponding stages. If unset and `out_indices` is unset, will default to the last stage. corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
out_indices (`List[int]`, *optional*): out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
If unset and `out_features` is unset, will default to the last stage. If unset and `out_features` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
Example: Example:
```python ```python
......
...@@ -68,11 +68,13 @@ class ConvNextConfig(BackboneConfigMixin, PretrainedConfig): ...@@ -68,11 +68,13 @@ class ConvNextConfig(BackboneConfigMixin, PretrainedConfig):
out_features (`List[str]`, *optional*): out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
corresponding stages. If unset and `out_indices` is unset, will default to the last stage. corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
out_indices (`List[int]`, *optional*): out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
If unset and `out_features` is unset, will default to the last stage. If unset and `out_features` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
Example: Example:
```python ```python
......
...@@ -60,11 +60,13 @@ class ConvNextV2Config(BackboneConfigMixin, PretrainedConfig): ...@@ -60,11 +60,13 @@ class ConvNextV2Config(BackboneConfigMixin, PretrainedConfig):
out_features (`List[str]`, *optional*): out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
corresponding stages. If unset and `out_indices` is unset, will default to the last stage. corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
out_indices (`List[int]`, *optional*): out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
If unset and `out_features` is unset, will default to the last stage. If unset and `out_features` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
Example: Example:
```python ```python
......
...@@ -74,11 +74,13 @@ class DinatConfig(BackboneConfigMixin, PretrainedConfig): ...@@ -74,11 +74,13 @@ class DinatConfig(BackboneConfigMixin, PretrainedConfig):
out_features (`List[str]`, *optional*): out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
corresponding stages. If unset and `out_indices` is unset, will default to the last stage. corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
out_indices (`List[int]`, *optional*): out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
If unset and `out_features` is unset, will default to the last stage. If unset and `out_features` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
Example: Example:
......
...@@ -79,11 +79,13 @@ class Dinov2Config(BackboneConfigMixin, PretrainedConfig): ...@@ -79,11 +79,13 @@ class Dinov2Config(BackboneConfigMixin, PretrainedConfig):
out_features (`List[str]`, *optional*): out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
corresponding stages. If unset and `out_indices` is unset, will default to the last stage. corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
out_indices (`List[int]`, *optional*): out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
If unset and `out_features` is unset, will default to the last stage. If unset and `out_features` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
apply_layernorm (`bool`, *optional*, defaults to `True`): apply_layernorm (`bool`, *optional*, defaults to `True`):
Whether to apply layer normalization to the feature maps in case the model is used as backbone. Whether to apply layer normalization to the feature maps in case the model is used as backbone.
reshape_hidden_states (`bool`, *optional*, defaults to `True`): reshape_hidden_states (`bool`, *optional*, defaults to `True`):
......
...@@ -84,11 +84,13 @@ class FocalNetConfig(BackboneConfigMixin, PretrainedConfig): ...@@ -84,11 +84,13 @@ class FocalNetConfig(BackboneConfigMixin, PretrainedConfig):
out_features (`List[str]`, *optional*): out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
corresponding stages. If unset and `out_indices` is unset, will default to the last stage. corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
out_indices (`List[int]`, *optional*): out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
If unset and `out_features` is unset, will default to the last stage. If unset and `out_features` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
Example: Example:
......
...@@ -70,11 +70,13 @@ class MaskFormerSwinConfig(BackboneConfigMixin, PretrainedConfig): ...@@ -70,11 +70,13 @@ class MaskFormerSwinConfig(BackboneConfigMixin, PretrainedConfig):
out_features (`List[str]`, *optional*): out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
corresponding stages. If unset and `out_indices` is unset, will default to the last stage. corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
out_indices (`List[int]`, *optional*): out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
If unset and `out_features` is unset, will default to the last stage. If unset and `out_features` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
Example: Example:
......
...@@ -72,11 +72,13 @@ class NatConfig(BackboneConfigMixin, PretrainedConfig): ...@@ -72,11 +72,13 @@ class NatConfig(BackboneConfigMixin, PretrainedConfig):
out_features (`List[str]`, *optional*): out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
corresponding stages. If unset and `out_indices` is unset, will default to the last stage. corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
out_indices (`List[int]`, *optional*): out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
If unset and `out_features` is unset, will default to the last stage. If unset and `out_features` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
Example: Example:
......
...@@ -64,11 +64,13 @@ class ResNetConfig(BackboneConfigMixin, PretrainedConfig): ...@@ -64,11 +64,13 @@ class ResNetConfig(BackboneConfigMixin, PretrainedConfig):
out_features (`List[str]`, *optional*): out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
corresponding stages. If unset and `out_indices` is unset, will default to the last stage. corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
out_indices (`List[int]`, *optional*): out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
If unset and `out_features` is unset, will default to the last stage. If unset and `out_features` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
Example: Example:
```python ```python
......
...@@ -85,11 +85,13 @@ class SwinConfig(BackboneConfigMixin, PretrainedConfig): ...@@ -85,11 +85,13 @@ class SwinConfig(BackboneConfigMixin, PretrainedConfig):
out_features (`List[str]`, *optional*): out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
corresponding stages. If unset and `out_indices` is unset, will default to the last stage. corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
out_indices (`List[int]`, *optional*): out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
If unset and `out_features` is unset, will default to the last stage. If unset and `out_features` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
Example: Example:
......
...@@ -80,11 +80,13 @@ class VitDetConfig(BackboneConfigMixin, PretrainedConfig): ...@@ -80,11 +80,13 @@ class VitDetConfig(BackboneConfigMixin, PretrainedConfig):
out_features (`List[str]`, *optional*): out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
corresponding stages. If unset and `out_indices` is unset, will default to the last stage. corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
out_indices (`List[int]`, *optional*): out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
If unset and `out_features` is unset, will default to the last stage. If unset and `out_features` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
Example: Example:
......
...@@ -36,15 +36,32 @@ def verify_out_features_out_indices( ...@@ -36,15 +36,32 @@ def verify_out_features_out_indices(
if out_features is not None: if out_features is not None:
if not isinstance(out_features, (list,)): if not isinstance(out_features, (list,)):
raise ValueError(f"out_features must be a list {type(out_features)}") raise ValueError(f"out_features must be a list got {type(out_features)}")
if any(feat not in stage_names for feat in out_features): if any(feat not in stage_names for feat in out_features):
raise ValueError(f"out_features must be a subset of stage_names: {stage_names} got {out_features}") raise ValueError(f"out_features must be a subset of stage_names: {stage_names} got {out_features}")
if len(out_features) != len(set(out_features)):
raise ValueError(f"out_features must not contain any duplicates, got {out_features}")
if out_features != (sorted_feats := [feat for feat in stage_names if feat in out_features]):
raise ValueError(
f"out_features must be in the same order as stage_names, expected {sorted_feats} got {out_features}"
)
if out_indices is not None: if out_indices is not None:
if not isinstance(out_indices, (list, tuple)): if not isinstance(out_indices, (list, tuple)):
raise ValueError(f"out_indices must be a list or tuple, got {type(out_indices)}") raise ValueError(f"out_indices must be a list or tuple, got {type(out_indices)}")
if any(idx >= len(stage_names) for idx in out_indices): # Convert negative indices to their positive equivalent: [-1,] -> [len(stage_names) - 1,]
positive_indices = tuple(idx % len(stage_names) if idx < 0 else idx for idx in out_indices)
if any(idx for idx in positive_indices if idx not in range(len(stage_names))):
raise ValueError(f"out_indices must be valid indices for stage_names {stage_names}, got {out_indices}") raise ValueError(f"out_indices must be valid indices for stage_names {stage_names}, got {out_indices}")
if len(positive_indices) != len(set(positive_indices)):
msg = f"out_indices must not contain any duplicates, got {out_indices}"
msg += f"(equivalent to {positive_indices}))" if positive_indices != out_indices else ""
raise ValueError(msg)
if positive_indices != tuple(sorted(positive_indices)):
sorted_negative = tuple(idx for _, idx in sorted(zip(positive_indices, out_indices), key=lambda x: x[0]))
raise ValueError(
f"out_indices must be in the same order as stage_names, expected {sorted_negative} got {out_indices}"
)
if out_features is not None and out_indices is not None: if out_features is not None and out_indices is not None:
if len(out_features) != len(out_indices): if len(out_features) != len(out_indices):
......
...@@ -201,3 +201,27 @@ class BackboneTesterMixin: ...@@ -201,3 +201,27 @@ class BackboneTesterMixin:
if self.has_attentions: if self.has_attentions:
outputs = backbone(**inputs_dict, output_attentions=True) outputs = backbone(**inputs_dict, output_attentions=True)
self.assertIsNotNone(outputs.attentions) self.assertIsNotNone(outputs.attentions)
def test_backbone_stage_selection(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
batch_size = inputs_dict["pixel_values"].shape[0]
for backbone_class in self.all_model_classes:
config.out_indices = [-2, -1]
backbone = backbone_class(config)
backbone.to(torch_device)
backbone.eval()
outputs = backbone(**inputs_dict)
# Test number of feature maps returned
self.assertIsInstance(outputs.feature_maps, tuple)
self.assertTrue(len(outputs.feature_maps) == 2)
# Order of channels returned is same as order of channels iterating over stage names
channels_from_stage_names = [
backbone.out_feature_channels[name] for name in backbone.stage_names if name in backbone.out_features
]
self.assertEqual(backbone.channels, channels_from_stage_names)
for feature_map, n_channels in zip(outputs.feature_maps, backbone.channels):
self.assertTrue(feature_map.shape[:2], (batch_size, n_channels))
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
import unittest import unittest
import pytest
from transformers.utils.backbone_utils import ( from transformers.utils.backbone_utils import (
BackboneMixin, BackboneMixin,
get_aligned_output_features_output_indices, get_aligned_output_features_output_indices,
...@@ -47,37 +49,61 @@ class BackboneUtilsTester(unittest.TestCase): ...@@ -47,37 +49,61 @@ class BackboneUtilsTester(unittest.TestCase):
def test_verify_out_features_out_indices(self): def test_verify_out_features_out_indices(self):
# Stage names must be set # Stage names must be set
with self.assertRaises(ValueError): with pytest.raises(ValueError, match="Stage_names must be set for transformers backbones"):
verify_out_features_out_indices(["a", "b"], (0, 1), None) verify_out_features_out_indices(["a", "b"], (0, 1), None)
# Out features must be a list # Out features must be a list
with self.assertRaises(ValueError): with pytest.raises(ValueError, match="out_features must be a list got <class 'tuple'>"):
verify_out_features_out_indices(("a", "b"), (0, 1), ["a", "b"]) verify_out_features_out_indices(("a", "b"), (0, 1), ["a", "b"])
# Out features must be a subset of stage names # Out features must be a subset of stage names
with self.assertRaises(ValueError): with pytest.raises(
ValueError, match=r"out_features must be a subset of stage_names: \['a'\] got \['a', 'b'\]"
):
verify_out_features_out_indices(["a", "b"], (0, 1), ["a"]) verify_out_features_out_indices(["a", "b"], (0, 1), ["a"])
# Out features must contain no duplicates
with pytest.raises(ValueError, match=r"out_features must not contain any duplicates, got \['a', 'a'\]"):
verify_out_features_out_indices(["a", "a"], None, ["a"])
# Out indices must be a list or tuple # Out indices must be a list or tuple
with self.assertRaises(ValueError): with pytest.raises(ValueError, match="out_indices must be a list or tuple, got <class 'int'>"):
verify_out_features_out_indices(None, 0, ["a", "b"]) verify_out_features_out_indices(None, 0, ["a", "b"])
# Out indices must be a subset of stage names # Out indices must be a subset of stage names
with self.assertRaises(ValueError): with pytest.raises(
ValueError, match=r"out_indices must be valid indices for stage_names \['a'\], got \(0, 1\)"
):
verify_out_features_out_indices(None, (0, 1), ["a"]) verify_out_features_out_indices(None, (0, 1), ["a"])
# Out indices must contain no duplicates
with pytest.raises(ValueError, match=r"out_indices must not contain any duplicates, got \(0, 0\)"):
verify_out_features_out_indices(None, (0, 0), ["a"])
# Out features and out indices must be the same length # Out features and out indices must be the same length
with self.assertRaises(ValueError): with pytest.raises(
ValueError, match="out_features and out_indices should have the same length if both are set"
):
verify_out_features_out_indices(["a", "b"], (0,), ["a", "b", "c"]) verify_out_features_out_indices(["a", "b"], (0,), ["a", "b", "c"])
# Out features should match out indices # Out features should match out indices
with self.assertRaises(ValueError): with pytest.raises(
ValueError, match="out_features and out_indices should correspond to the same stages if both are set"
):
verify_out_features_out_indices(["a", "b"], (0, 2), ["a", "b", "c"]) verify_out_features_out_indices(["a", "b"], (0, 2), ["a", "b", "c"])
# Out features and out indices should be in order # Out features and out indices should be in order
with self.assertRaises(ValueError): with pytest.raises(
ValueError,
match=r"out_features must be in the same order as stage_names, expected \['a', 'b'\] got \['b', 'a'\]",
):
verify_out_features_out_indices(["b", "a"], (0, 1), ["a", "b"]) verify_out_features_out_indices(["b", "a"], (0, 1), ["a", "b"])
with pytest.raises(
ValueError, match=r"out_indices must be in the same order as stage_names, expected \(-2, 1\) got \(1, -2\)"
):
verify_out_features_out_indices(["a", "b"], (1, -2), ["a", "b"])
# Check passes with valid inputs # Check passes with valid inputs
verify_out_features_out_indices(["a", "b", "d"], (0, 1, -1), ["a", "b", "c", "d"]) verify_out_features_out_indices(["a", "b", "d"], (0, 1, -1), ["a", "b", "c", "d"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment