Unverified Commit 9a6c6ef9 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

[Backbones] Improve out features (#20675)



* Improve ResNet backbone

* Improve Bit backbone

* Improve docstrings

* Fix default stage

* Apply suggestions from code review
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 9e56aff5
...@@ -63,7 +63,7 @@ class BitConfig(PretrainedConfig): ...@@ -63,7 +63,7 @@ class BitConfig(PretrainedConfig):
The width factor for the model. The width factor for the model.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). Will default to the last stage if unset.
Example: Example:
```python ```python
......
...@@ -851,7 +851,7 @@ class BitBackbone(BitPreTrainedModel, BackboneMixin): ...@@ -851,7 +851,7 @@ class BitBackbone(BitPreTrainedModel, BackboneMixin):
self.stage_names = config.stage_names self.stage_names = config.stage_names
self.bit = BitModel(config) self.bit = BitModel(config)
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
out_feature_channels = {} out_feature_channels = {}
out_feature_channels["stem"] = config.embedding_size out_feature_channels["stem"] = config.embedding_size
......
...@@ -69,7 +69,8 @@ class MaskFormerSwinConfig(PretrainedConfig): ...@@ -69,7 +69,8 @@ class MaskFormerSwinConfig(PretrainedConfig):
layer_norm_eps (`float`, *optional*, defaults to 1e-12): layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers. The epsilon used by the layer normalization layers.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). Will default to the last stage if unset.
Example: Example:
......
...@@ -855,7 +855,7 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin): ...@@ -855,7 +855,7 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
self.stage_names = config.stage_names self.stage_names = config.stage_names
self.model = MaskFormerSwinModel(config) self.model = MaskFormerSwinModel(config)
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
if "stem" in self.out_features: if "stem" in self.out_features:
raise ValueError("This backbone does not support 'stem' in the `out_features`.") raise ValueError("This backbone does not support 'stem' in the `out_features`.")
......
...@@ -59,8 +59,8 @@ class ResNetConfig(PretrainedConfig): ...@@ -59,8 +59,8 @@ class ResNetConfig(PretrainedConfig):
downsample_in_first_stage (`bool`, *optional*, defaults to `False`): downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
If `True`, the first stage will downsample the inputs using a `stride` of 2. If `True`, the first stage will downsample the inputs using a `stride` of 2.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). Will default to the last stage if unset.
Example: Example:
```python ```python
......
...@@ -267,7 +267,7 @@ class ResNetPreTrainedModel(PreTrainedModel): ...@@ -267,7 +267,7 @@ class ResNetPreTrainedModel(PreTrainedModel):
nn.init.constant_(module.bias, 0) nn.init.constant_(module.bias, 0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, ResNetEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing = value
...@@ -439,7 +439,7 @@ class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin): ...@@ -439,7 +439,7 @@ class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):
self.embedder = ResNetEmbeddings(config) self.embedder = ResNetEmbeddings(config)
self.encoder = ResNetEncoder(config) self.encoder = ResNetEncoder(config)
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
out_feature_channels = {} out_feature_channels = {}
out_feature_channels["stem"] = config.embedding_size out_feature_channels["stem"] = config.embedding_size
......
...@@ -119,7 +119,7 @@ class BitModelTester: ...@@ -119,7 +119,7 @@ class BitModelTester:
model.eval() model.eval()
result = model(pixel_values) result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), len(config.out_features)) self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4]) self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])
...@@ -127,6 +127,21 @@ class BitModelTester: ...@@ -127,6 +127,21 @@ class BitModelTester:
self.parent.assertEqual(len(model.channels), len(config.out_features)) self.parent.assertEqual(len(model.channels), len(config.out_features))
self.parent.assertListEqual(model.channels, config.hidden_sizes[1:]) self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])
# verify backbone works with out_features=None
config.out_features = None
model = BitBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), 1)
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1])
# verify channels
self.parent.assertEqual(len(model.channels), 1)
self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs config, pixel_values, labels = config_and_inputs
......
...@@ -119,7 +119,7 @@ class ResNetModelTester: ...@@ -119,7 +119,7 @@ class ResNetModelTester:
model.eval() model.eval()
result = model(pixel_values) result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), len(config.out_features)) self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4]) self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])
...@@ -127,6 +127,21 @@ class ResNetModelTester: ...@@ -127,6 +127,21 @@ class ResNetModelTester:
self.parent.assertEqual(len(model.channels), len(config.out_features)) self.parent.assertEqual(len(model.channels), len(config.out_features))
self.parent.assertListEqual(model.channels, config.hidden_sizes[1:]) self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])
# verify backbone works with out_features=None
config.out_features = None
model = ResNetBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), 1)
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1])
# verify channels
self.parent.assertEqual(len(model.channels), 1)
self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs config, pixel_values, labels = config_and_inputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment