Unverified Commit 559a45d1 authored by amyeroberts, committed by GitHub

Backbone add out indices (#22493)

* Add out_indices to backbones, deprecate out_features

* Update - can specify either out_features or out_indices but not both

* Can specify both

* Fix copies

* Add out_indices to convnextv2 configuration
parent db803b69
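
The change is the same across every config touched here: `out_indices` becomes an accepted argument alongside `out_features`, and the two are reconciled in `__init__`. A minimal sketch of the resulting behaviour, using `BitConfig` (the other configs below behave identically):

```python
from transformers import BitConfig

# The default BitConfig has stage_names == ["stem", "stage1", "stage2", "stage3", "stage4"]
config = BitConfig(out_indices=[1, 2])
print(config.out_features)  # ['stage1', 'stage2'] (derived from out_indices)

config = BitConfig(out_features=["stage1", "stage2"])
print(config.out_indices)   # [1, 2] (derived from out_features)

config = BitConfig()        # neither set: both default to the last stage
print(config.out_features, config.out_indices)  # ['stage4'] [4]

# Both may be set, but only if they describe the same stages:
config = BitConfig(out_features=["stage1"], out_indices=[2])  # raises ValueError
```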
@@ -63,7 +63,12 @@ class BitConfig(PretrainedConfig):
             The width factor for the model.
         out_features (`List[str]`, *optional*):
             If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
-            (depending on how many stages the model has). Will default to the last stage if unset.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage.
+        out_indices (`List[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage.

     Example:
     ```python
@@ -98,6 +103,7 @@ class BitConfig(PretrainedConfig):
         output_stride=32,
         width_factor=1,
         out_features=None,
+        out_indices=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -122,6 +128,21 @@ class BitConfig(PretrainedConfig):
         self.width_factor = width_factor
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
+
+        if out_features is not None and out_indices is not None:
+            if len(out_features) != len(out_indices):
+                raise ValueError("out_features and out_indices should have the same length if both are set")
+            elif out_features != [self.stage_names[idx] for idx in out_indices]:
+                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
+
+        if out_features is None and out_indices is not None:
+            out_features = [self.stage_names[idx] for idx in out_indices]
+        elif out_features is not None and out_indices is None:
+            out_indices = [self.stage_names.index(feature) for feature in out_features]
+        elif out_features is None and out_indices is None:
+            out_features = [self.stage_names[-1]]
+            out_indices = [len(self.stage_names) - 1]
+
         if out_features is not None:
             if not isinstance(out_features, list):
                 raise ValueError("out_features should be a list")
@@ -130,4 +151,12 @@ class BitConfig(PretrainedConfig):
                     raise ValueError(
                         f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
                     )
+        if out_indices is not None:
+            if not isinstance(out_indices, (list, tuple)):
+                raise ValueError("out_indices should be a list or tuple")
+            for idx in out_indices:
+                if idx >= len(self.stage_names):
+                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
+
         self.out_features = out_features
+        self.out_indices = out_indices
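
The reconciliation block above is copied verbatim into each config in this commit (hence the "Fix copies" step). As a standalone, runnable sketch of the same logic (the helper name is ours, not part of the library):

```python
def resolve_out_features_and_indices(stage_names, out_features=None, out_indices=None):
    # Mirrors the block added to each config's __init__ in this commit.
    if out_features is not None and out_indices is not None:
        if len(out_features) != len(out_indices):
            raise ValueError("out_features and out_indices should have the same length if both are set")
        elif out_features != [stage_names[idx] for idx in out_indices]:
            raise ValueError("out_features and out_indices should correspond to the same stages if both are set")

    if out_features is None and out_indices is not None:
        out_features = [stage_names[idx] for idx in out_indices]
    elif out_features is not None and out_indices is None:
        out_indices = [stage_names.index(feature) for feature in out_features]
    elif out_features is None and out_indices is None:
        out_features = [stage_names[-1]]
        out_indices = [len(stage_names) - 1]
    return out_features, out_indices


names = ["stem", "stage1", "stage2", "stage3"]
print(resolve_out_features_and_indices(names, out_indices=[0, 2]))
# (['stem', 'stage2'], [0, 2])
print(resolve_out_features_and_indices(names))
# (['stage3'], [3])
```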
@@ -850,6 +850,10 @@ class BitBackbone(BitPreTrainedModel, BackboneMixin):
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
         self.num_features = [config.embedding_size] + config.hidden_sizes
+        if config.out_indices is not None:
+            self.out_indices = config.out_indices
+        else:
+            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)

         # initialize weights and apply final processing
         self.post_init()
...
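In the backbone classes, `config.out_indices` is used directly when present; otherwise the indices are recovered from `out_features`. Since the config now always resolves both, the fallback mainly guards configs created before this change. What the fallback expression computes, as a quick sketch:

```python
# Recovering indices from feature names, as the backbones in this commit do:
stage_names = ["stem", "stage1", "stage2", "stage3", "stage4"]
out_features = ["stage1", "stage3"]

out_indices = tuple(i for i, layer in enumerate(stage_names) if layer in out_features)
print(out_indices)  # (1, 3)
```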
@@ -66,7 +66,12 @@ class ConvNextConfig(PretrainedConfig):
             The drop rate for stochastic depth.
         out_features (`List[str]`, *optional*):
             If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
-            (depending on how many stages the model has). Will default to the last stage if unset.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage.
+        out_indices (`List[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage.

     Example:
     ```python
@@ -97,6 +102,7 @@ class ConvNextConfig(PretrainedConfig):
         drop_path_rate=0.0,
         image_size=224,
         out_features=None,
+        out_indices=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -113,6 +119,21 @@ class ConvNextConfig(PretrainedConfig):
         self.drop_path_rate = drop_path_rate
         self.image_size = image_size
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
+
+        if out_features is not None and out_indices is not None:
+            if len(out_features) != len(out_indices):
+                raise ValueError("out_features and out_indices should have the same length if both are set")
+            elif out_features != [self.stage_names[idx] for idx in out_indices]:
+                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
+
+        if out_features is None and out_indices is not None:
+            out_features = [self.stage_names[idx] for idx in out_indices]
+        elif out_features is not None and out_indices is None:
+            out_indices = [self.stage_names.index(feature) for feature in out_features]
+        elif out_features is None and out_indices is None:
+            out_features = [self.stage_names[-1]]
+            out_indices = [len(self.stage_names) - 1]
+
         if out_features is not None:
             if not isinstance(out_features, list):
                 raise ValueError("out_features should be a list")
@@ -121,7 +142,15 @@ class ConvNextConfig(PretrainedConfig):
                     raise ValueError(
                         f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
                     )
+        if out_indices is not None:
+            if not isinstance(out_indices, (list, tuple)):
+                raise ValueError("out_indices should be a list or tuple")
+            for idx in out_indices:
+                if idx >= len(self.stage_names):
+                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
+
         self.out_features = out_features
+        self.out_indices = out_indices


 class ConvNextOnnxConfig(OnnxConfig):
...
@@ -487,6 +487,10 @@ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
         self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
+        if config.out_indices is not None:
+            self.out_indices = config.out_indices
+        else:
+            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)

         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
...
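End to end, selecting stages by index yields multi-scale feature maps from the backbone. A small sketch with randomly initialized weights; the printed shapes assume the default `ConvNextConfig` (hidden sizes 96/192/384/768, patch size 4):

```python
import torch
from transformers import ConvNextBackbone, ConvNextConfig

config = ConvNextConfig(out_indices=[1, 2, 3, 4])  # all four stages, skipping the stem
model = ConvNextBackbone(config)

pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    outputs = model(pixel_values)

# One feature map per selected stage, at decreasing spatial resolution
for name, feature_map in zip(model.out_features, outputs.feature_maps):
    print(name, tuple(feature_map.shape))
# stage1 (1, 96, 56, 56)
# stage2 (1, 192, 28, 28)
# stage3 (1, 384, 14, 14)
# stage4 (1, 768, 7, 7)
```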
@@ -58,7 +58,12 @@ class ConvNextV2Config(PretrainedConfig):
             The drop rate for stochastic depth.
         out_features (`List[str]`, *optional*):
             If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
-            (depending on how many stages the model has). Will default to the last stage if unset.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage.
+        out_indices (`List[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage.

     Example:
     ```python
@@ -88,6 +93,7 @@ class ConvNextV2Config(PretrainedConfig):
         drop_path_rate=0.0,
         image_size=224,
         out_features=None,
+        out_indices=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -103,6 +109,21 @@ class ConvNextV2Config(PretrainedConfig):
         self.drop_path_rate = drop_path_rate
         self.image_size = image_size
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
+
+        if out_features is not None and out_indices is not None:
+            if len(out_features) != len(out_indices):
+                raise ValueError("out_features and out_indices should have the same length if both are set")
+            elif out_features != [self.stage_names[idx] for idx in out_indices]:
+                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
+
+        if out_features is None and out_indices is not None:
+            out_features = [self.stage_names[idx] for idx in out_indices]
+        elif out_features is not None and out_indices is None:
+            out_indices = [self.stage_names.index(feature) for feature in out_features]
+        elif out_features is None and out_indices is None:
+            out_features = [self.stage_names[-1]]
+            out_indices = [len(self.stage_names) - 1]
+
         if out_features is not None:
             if not isinstance(out_features, list):
                 raise ValueError("out_features should be a list")
@@ -111,4 +132,12 @@ class ConvNextV2Config(PretrainedConfig):
                     raise ValueError(
                         f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
                     )
+        if out_indices is not None:
+            if not isinstance(out_indices, (list, tuple)):
+                raise ValueError("out_indices should be a list or tuple")
+            for idx in out_indices:
+                if idx >= len(self.stage_names):
+                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
+
         self.out_features = out_features
+        self.out_indices = out_indices
...
@@ -510,6 +510,10 @@ class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
         self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
+        if config.out_indices is not None:
+            self.out_indices = config.out_indices
+        else:
+            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)

         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
...
@@ -72,7 +72,12 @@ class DinatConfig(PretrainedConfig):
             The initial value for the layer scale. Disabled if <=0.
         out_features (`List[str]`, *optional*):
             If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
-            (depending on how many stages the model has). Will default to the last stage if unset.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage.
+        out_indices (`List[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage.

     Example:
@@ -114,6 +119,7 @@ class DinatConfig(PretrainedConfig):
         layer_norm_eps=1e-5,
         layer_scale_init_value=0.0,
         out_features=None,
+        out_indices=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -139,6 +145,21 @@ class DinatConfig(PretrainedConfig):
         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
         self.layer_scale_init_value = layer_scale_init_value
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
+
+        if out_features is not None and out_indices is not None:
+            if len(out_features) != len(out_indices):
+                raise ValueError("out_features and out_indices should have the same length if both are set")
+            elif out_features != [self.stage_names[idx] for idx in out_indices]:
+                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
+
+        if out_features is None and out_indices is not None:
+            out_features = [self.stage_names[idx] for idx in out_indices]
+        elif out_features is not None and out_indices is None:
+            out_indices = [self.stage_names.index(feature) for feature in out_features]
+        elif out_features is None and out_indices is None:
+            out_features = [self.stage_names[-1]]
+            out_indices = [len(self.stage_names) - 1]
+
         if out_features is not None:
             if not isinstance(out_features, list):
                 raise ValueError("out_features should be a list")
@@ -147,4 +168,12 @@ class DinatConfig(PretrainedConfig):
                     raise ValueError(
                         f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
                     )
+        if out_indices is not None:
+            if not isinstance(out_indices, (list, tuple)):
+                raise ValueError("out_indices should be a list or tuple")
+            for idx in out_indices:
+                if idx >= len(self.stage_names):
+                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
+
         self.out_features = out_features
+        self.out_indices = out_indices
...
@@ -891,6 +891,10 @@ class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
         self.encoder = DinatEncoder(config)
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
+        if config.out_indices is not None:
+            self.out_indices = config.out_indices
+        else:
+            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)

         self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]

         # Add layer norms to hidden states of out_features
...
@@ -68,7 +68,12 @@ class MaskFormerSwinConfig(PretrainedConfig):
             The epsilon used by the layer normalization layers.
         out_features (`List[str]`, *optional*):
             If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
-            (depending on how many stages the model has). Will default to the last stage if unset.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage.
+        out_indices (`List[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage.

     Example:
@@ -110,6 +115,7 @@ class MaskFormerSwinConfig(PretrainedConfig):
         initializer_range=0.02,
         layer_norm_eps=1e-5,
         out_features=None,
+        out_indices=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -135,6 +141,21 @@ class MaskFormerSwinConfig(PretrainedConfig):
         # this indicates the channel dimension after the last stage of the model
         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
+
+        if out_features is not None and out_indices is not None:
+            if len(out_features) != len(out_indices):
+                raise ValueError("out_features and out_indices should have the same length if both are set")
+            elif out_features != [self.stage_names[idx] for idx in out_indices]:
+                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
+
+        if out_features is None and out_indices is not None:
+            out_features = [self.stage_names[idx] for idx in out_indices]
+        elif out_features is not None and out_indices is None:
+            out_indices = [self.stage_names.index(feature) for feature in out_features]
+        elif out_features is None and out_indices is None:
+            out_features = [self.stage_names[-1]]
+            out_indices = [len(self.stage_names) - 1]
+
         if out_features is not None:
             if not isinstance(out_features, list):
                 raise ValueError("out_features should be a list")
@@ -143,4 +164,12 @@ class MaskFormerSwinConfig(PretrainedConfig):
                     raise ValueError(
                         f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
                     )
+        if out_indices is not None:
+            if not isinstance(out_indices, (list, tuple)):
+                raise ValueError("out_indices should be a list or tuple")
+            for idx in out_indices:
+                if idx >= len(self.stage_names):
+                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
+
         self.out_features = out_features
+        self.out_indices = out_indices
...
@@ -859,6 +859,10 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
         if "stem" in self.out_features:
             raise ValueError("This backbone does not support 'stem' in the `out_features`.")
+        if config.out_indices is not None:
+            self.out_indices = config.out_indices
+        else:
+            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)

         self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
         self.hidden_states_norms = nn.ModuleList([nn.LayerNorm(num_channels) for num_channels in self.channels])
...
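Note how this interacts with the `'stem'` restriction above: `out_indices=[0]` now resolves to `out_features=["stem"]` inside the config, which `MaskFormerSwinBackbone` then rejects at construction time. A sketch:

```python
from transformers import MaskFormerSwinConfig

config = MaskFormerSwinConfig(out_indices=[0])
print(config.out_features)  # ['stem'], which MaskFormerSwinBackbone refuses with a ValueError
```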
@@ -70,7 +70,12 @@ class NatConfig(PretrainedConfig):
             The initial value for the layer scale. Disabled if <=0.
         out_features (`List[str]`, *optional*):
             If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
-            (depending on how many stages the model has). Will default to the last stage if unset.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage.
+        out_indices (`List[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage.

     Example:
@@ -111,6 +116,7 @@ class NatConfig(PretrainedConfig):
         layer_norm_eps=1e-5,
         layer_scale_init_value=0.0,
         out_features=None,
+        out_indices=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -135,6 +141,21 @@ class NatConfig(PretrainedConfig):
         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
         self.layer_scale_init_value = layer_scale_init_value
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
+
+        if out_features is not None and out_indices is not None:
+            if len(out_features) != len(out_indices):
+                raise ValueError("out_features and out_indices should have the same length if both are set")
+            elif out_features != [self.stage_names[idx] for idx in out_indices]:
+                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
+
+        if out_features is None and out_indices is not None:
+            out_features = [self.stage_names[idx] for idx in out_indices]
+        elif out_features is not None and out_indices is None:
+            out_indices = [self.stage_names.index(feature) for feature in out_features]
+        elif out_features is None and out_indices is None:
+            out_features = [self.stage_names[-1]]
+            out_indices = [len(self.stage_names) - 1]
+
         if out_features is not None:
             if not isinstance(out_features, list):
                 raise ValueError("out_features should be a list")
@@ -143,4 +164,12 @@ class NatConfig(PretrainedConfig):
                     raise ValueError(
                         f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
                     )
+        if out_indices is not None:
+            if not isinstance(out_indices, (list, tuple)):
+                raise ValueError("out_indices should be a list or tuple")
+            for idx in out_indices:
+                if idx >= len(self.stage_names):
+                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
+
         self.out_features = out_features
+        self.out_indices = out_indices
...
@@ -869,6 +869,10 @@ class NatBackbone(NatPreTrainedModel, BackboneMixin):
         self.encoder = NatEncoder(config)
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
+        if config.out_indices is not None:
+            self.out_indices = config.out_indices
+        else:
+            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)

         self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]

         # Add layer norms to hidden states of out_features
...
@@ -60,7 +60,12 @@ class ResNetConfig(PretrainedConfig):
             If `True`, the first stage will downsample the inputs using a `stride` of 2.
         out_features (`List[str]`, *optional*):
             If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
-            (depending on how many stages the model has). Will default to the last stage if unset.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage.
+        out_indices (`List[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage.

     Example:
     ```python
@@ -89,6 +94,7 @@ class ResNetConfig(PretrainedConfig):
         hidden_act="relu",
         downsample_in_first_stage=False,
         out_features=None,
+        out_indices=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -102,6 +108,21 @@ class ResNetConfig(PretrainedConfig):
         self.hidden_act = hidden_act
         self.downsample_in_first_stage = downsample_in_first_stage
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
+
+        if out_features is not None and out_indices is not None:
+            if len(out_features) != len(out_indices):
+                raise ValueError("out_features and out_indices should have the same length if both are set")
+            elif out_features != [self.stage_names[idx] for idx in out_indices]:
+                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
+
+        if out_features is None and out_indices is not None:
+            out_features = [self.stage_names[idx] for idx in out_indices]
+        elif out_features is not None and out_indices is None:
+            out_indices = [self.stage_names.index(feature) for feature in out_features]
+        elif out_features is None and out_indices is None:
+            out_features = [self.stage_names[-1]]
+            out_indices = [len(self.stage_names) - 1]
+
         if out_features is not None:
             if not isinstance(out_features, list):
                 raise ValueError("out_features should be a list")
@@ -110,7 +131,15 @@ class ResNetConfig(PretrainedConfig):
                     raise ValueError(
                         f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
                     )
+        if out_indices is not None:
+            if not isinstance(out_indices, (list, tuple)):
+                raise ValueError("out_indices should be a list or tuple")
+            for idx in out_indices:
+                if idx >= len(self.stage_names):
+                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
+
         self.out_features = out_features
+        self.out_indices = out_indices


 class ResNetOnnxConfig(OnnxConfig):
...
@@ -437,6 +437,10 @@ class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):
         self.encoder = ResNetEncoder(config)
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
+        if config.out_indices is not None:
+            self.out_indices = config.out_indices
+        else:
+            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)

         self.num_features = [config.embedding_size] + config.hidden_sizes

         # initialize weights and apply final processing
...
@@ -83,7 +83,12 @@ class SwinConfig(PretrainedConfig):
             Factor to increase the spatial resolution by in the decoder head for masked image modeling.
         out_features (`List[str]`, *optional*):
             If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
-            (depending on how many stages the model has). Will default to the last stage if unset.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage.
+        out_indices (`List[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage.

     Example:
@@ -126,6 +131,7 @@ class SwinConfig(PretrainedConfig):
         layer_norm_eps=1e-5,
         encoder_stride=32,
         out_features=None,
+        out_indices=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -152,6 +158,21 @@ class SwinConfig(PretrainedConfig):
         # this indicates the channel dimension after the last stage of the model
         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
+
+        if out_features is not None and out_indices is not None:
+            if len(out_features) != len(out_indices):
+                raise ValueError("out_features and out_indices should have the same length if both are set")
+            elif out_features != [self.stage_names[idx] for idx in out_indices]:
+                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
+
+        if out_features is None and out_indices is not None:
+            out_features = [self.stage_names[idx] for idx in out_indices]
+        elif out_features is not None and out_indices is None:
+            out_indices = [self.stage_names.index(feature) for feature in out_features]
+        elif out_features is None and out_indices is None:
+            out_features = [self.stage_names[-1]]
+            out_indices = [len(self.stage_names) - 1]
+
         if out_features is not None:
             if not isinstance(out_features, list):
                 raise ValueError("out_features should be a list")
@@ -160,7 +181,15 @@ class SwinConfig(PretrainedConfig):
                     raise ValueError(
                         f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
                     )
+        if out_indices is not None:
+            if not isinstance(out_indices, (list, tuple)):
+                raise ValueError("out_indices should be a list or tuple")
+            for idx in out_indices:
+                if idx >= len(self.stage_names):
+                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
+
         self.out_features = out_features
+        self.out_indices = out_indices


 class SwinOnnxConfig(OnnxConfig):
...
@@ -1255,6 +1255,10 @@ class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
         self.encoder = SwinEncoder(config, self.embeddings.patch_grid)
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
+        if config.out_indices is not None:
+            self.out_indices = config.out_indices
+        else:
+            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)

         self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]

         # Add layer norms to hidden states of out_features
...
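Because `out_indices` is stored as a regular config attribute, it round-trips through serialization just as `out_features` does. A sketch with `SwinConfig`:

```python
from transformers import SwinConfig

config = SwinConfig(out_indices=[2, 3])
data = config.to_dict()
print(data["out_features"], data["out_indices"])  # ['stage2', 'stage3'] [2, 3]

restored = SwinConfig.from_dict(data)
print(restored.out_features, restored.out_indices)  # ['stage2', 'stage3'] [2, 3]
```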