Unverified Commit c15f9375 authored by amyeroberts, committed by GitHub

Move common properties to BackboneMixin (#21855)

* Move common properties to BackboneMixin

* Fix failing tests

* Update ConvNextV2 backbone
parent cd73b9a8
@@ -968,12 +968,30 @@ class ModuleUtilsMixin:
 class BackboneMixin:
+    @property
+    def out_feature_channels(self):
+        # the current backbones will output the number of channels for each stage
+        # even if that stage is not in the out_features list.
+        return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}
+
+    @property
+    def channels(self):
+        return [self.out_feature_channels[name] for name in self.out_features]
+
     def forward_with_filtered_kwargs(self, *args, **kwargs):
         signature = dict(inspect.signature(self.forward).parameters)
         filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
         return self(*args, **filtered_kwargs)
+
+    def forward(
+        self,
+        pixel_values: Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        raise NotImplementedError("This method should be implemented by the derived class.")


 class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
     r"""
......
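Both shared properties are derived purely from `self.stage_names`, `self.num_features` and `self.out_features`, which the concrete backbones below now set in their `__init__`. A minimal standalone sketch of that derivation (the class name and channel numbers are illustrative, not part of this commit):

    # Hypothetical demo class; only shows how the mixin's properties resolve.
    class DemoBackbone:
        def __init__(self):
            self.stage_names = ["stem", "stage1", "stage2", "stage3"]
            self.num_features = [32, 64, 128, 256]  # one channel count per stage name
            self.out_features = ["stage1", "stage3"]

        @property
        def out_feature_channels(self):
            # channel count for every stage, requested or not
            return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}

        @property
        def channels(self):
            # channel counts restricted to the requested out_features, in order
            return [self.out_feature_channels[name] for name in self.out_features]

    backbone = DemoBackbone()
    print(backbone.out_feature_channels)  # {'stem': 32, 'stage1': 64, 'stage2': 128, 'stage3': 256}
    print(backbone.channels)              # [64, 256]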
@@ -849,21 +849,11 @@ class BitBackbone(BitPreTrainedModel, BackboneMixin):
         self.bit = BitModel(config)
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
+        self.num_features = [config.embedding_size] + config.hidden_sizes
-        out_feature_channels = {}
-        out_feature_channels["stem"] = config.embedding_size
-        for idx, stage in enumerate(self.stage_names[1:]):
-            out_feature_channels[stage] = config.hidden_sizes[idx]
-        self.out_feature_channels = out_feature_channels

         # initialize weights and apply final processing
         self.post_init()

-    @property
-    def channels(self):
-        return [self.out_feature_channels[name] for name in self.out_features]
-
     @add_start_docstrings_to_model_forward(BIT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
......
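With the dict construction and the local `channels` property removed, each stage's channel count now comes straight from `num_features` through the mixin. A quick worked example with hypothetical Bit-style config values (not taken from an actual checkpoint):

    # Hypothetical config values, only to illustrate the new num_features convention.
    embedding_size = 64
    hidden_sizes = [256, 512, 1024, 2048]
    stage_names = ["stem", "stage1", "stage2", "stage3", "stage4"]

    num_features = [embedding_size] + hidden_sizes          # [64, 256, 512, 1024, 2048]
    out_feature_channels = dict(zip(stage_names, num_features))
    out_features = ["stage1", "stage4"]                     # e.g. what a detection head might request
    channels = [out_feature_channels[name] for name in out_features]
    print(channels)                                         # [256, 2048]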
@@ -486,13 +486,7 @@ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
         self.encoder = ConvNextEncoder(config)
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
+        self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
-        out_feature_channels = {}
-        out_feature_channels["stem"] = config.hidden_sizes[0]
-        for idx, stage in enumerate(self.stage_names[1:]):
-            out_feature_channels[stage] = config.hidden_sizes[idx]
-        self.out_feature_channels = out_feature_channels

         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
@@ -503,10 +497,6 @@ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
         # initialize weights and apply final processing
         self.post_init()

-    @property
-    def channels(self):
-        return [self.out_feature_channels[name] for name in self.out_features]
-
     @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
......
@@ -509,13 +509,7 @@ class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
         self.encoder = ConvNextV2Encoder(config)
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
+        self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
-        out_feature_channels = {}
-        out_feature_channels["stem"] = config.hidden_sizes[0]
-        for idx, stage in enumerate(self.stage_names[1:]):
-            out_feature_channels[stage] = config.hidden_sizes[idx]
-        self.out_feature_channels = out_feature_channels

         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
@@ -526,10 +520,6 @@ class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
         # initialize weights and apply final processing
         self.post_init()

-    @property
-    def channels(self):
-        return [self.out_feature_channels[name] for name in self.out_features]
-
     @add_start_docstrings_to_model_forward(CONVNEXTV2_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
......
@@ -891,12 +891,7 @@ class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
         self.encoder = DinatEncoder(config)
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
+        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
-        num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
-        self.out_feature_channels = {}
-        self.out_feature_channels["stem"] = config.embed_dim
-        for i, stage in enumerate(self.stage_names[1:]):
-            self.out_feature_channels[stage] = num_features[i]

         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
@@ -910,10 +905,6 @@ class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
     def get_input_embeddings(self):
         return self.embeddings.patch_embeddings

-    @property
-    def channels(self):
-        return [self.out_feature_channels[name] for name in self.out_features]
-
     @add_start_docstrings_to_model_forward(DINAT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
......
@@ -859,20 +859,12 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
         if "stem" in self.out_features:
             raise ValueError("This backbone does not support 'stem' in the `out_features`.")
-        num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
-        self.out_feature_channels = {}
-        for i, stage in enumerate(self.stage_names[1:]):
-            self.out_feature_channels[stage] = num_features[i]
+        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
         self.hidden_states_norms = nn.ModuleList([nn.LayerNorm(num_channels) for num_channels in self.channels])

         # Initialize weights and apply final processing
         self.post_init()

-    @property
-    def channels(self):
-        return [self.out_feature_channels[name] for name in self.out_features]
-
     def forward(
         self,
         pixel_values: Tensor,
......
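Note that in this backbone `self.channels` (now supplied by the mixin) is read inside `__init__` itself to size the per-feature LayerNorms, so `num_features` and `out_features` must be assigned before that line runs. A minimal sketch of that dependency, reusing the illustrative channel list from the earlier example:

    import torch.nn as nn

    channels = [256, 2048]  # illustrative values, as derived in the earlier sketch
    hidden_states_norms = nn.ModuleList([nn.LayerNorm(dim) for dim in channels])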
@@ -869,12 +869,7 @@ class NatBackbone(NatPreTrainedModel, BackboneMixin):
         self.encoder = NatEncoder(config)
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
+        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
-        num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
-        self.out_feature_channels = {}
-        self.out_feature_channels["stem"] = config.embed_dim
-        for i, stage in enumerate(self.stage_names[1:]):
-            self.out_feature_channels[stage] = num_features[i]

         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
@@ -888,10 +883,6 @@ class NatBackbone(NatPreTrainedModel, BackboneMixin):
     def get_input_embeddings(self):
         return self.embeddings.patch_embeddings

-    @property
-    def channels(self):
-        return [self.out_feature_channels[name] for name in self.out_features]
-
     @add_start_docstrings_to_model_forward(NAT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
......
@@ -437,21 +437,11 @@ class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):
         self.encoder = ResNetEncoder(config)
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
+        self.num_features = [config.embedding_size] + config.hidden_sizes
-        out_feature_channels = {}
-        out_feature_channels["stem"] = config.embedding_size
-        for idx, stage in enumerate(self.stage_names[1:]):
-            out_feature_channels[stage] = config.hidden_sizes[idx]
-        self.out_feature_channels = out_feature_channels

         # initialize weights and apply final processing
         self.post_init()

-    @property
-    def channels(self):
-        return [self.out_feature_channels[name] for name in self.out_features]
-
     @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
......
@@ -1255,12 +1255,7 @@ class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
         self.encoder = SwinEncoder(config, self.embeddings.patch_grid)
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
+        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
-        num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
-        self.out_feature_channels = {}
-        self.out_feature_channels["stem"] = config.embed_dim
-        for i, stage in enumerate(self.stage_names[1:]):
-            self.out_feature_channels[stage] = num_features[i]

         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
@@ -1274,10 +1269,6 @@ class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
     def get_input_embeddings(self):
         return self.embeddings.patch_embeddings

-    @property
-    def channels(self):
-        return [self.out_feature_channels[name] for name in self.out_features]
-
     def forward(
         self,
         pixel_values: torch.Tensor,
......