Unverified commit 90e8263d authored by amyeroberts, committed by GitHub

Add methods to update and verify out_features out_indices (#23031)

* Add methods to update and verify out_features out_indices

* Safe update for config attributes

* Fix function names

* Save config correctly

* PR comments - use property setters

* PR comment - directly set attributes

* Update test

* Add updates to recently merged focalnet backbone
parent 78b7debf
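
Editorial note: every config touched below previously inlined the same ~30 lines of out_features/out_indices validation and alignment; this commit replaces each copy with a single call to get_aligned_output_features_output_indices. As a reading aid, here is a minimal sketch of what that helper does, reconstructed directly from the inline logic it replaces in the hunks below (the canonical implementation lives in src/transformers/utils/backbone_utils.py and also performs the validation checks shown in the removed blocks):

def get_aligned_output_features_output_indices(out_features, out_indices, stage_names):
    # Sketch only: the type/length/stage-agreement validation visible in the
    # removed inline blocks below is omitted here for brevity.
    if out_features is None and out_indices is not None:
        # derive the feature names from the requested stage indices
        out_features = [stage_names[idx] for idx in out_indices]
    elif out_features is not None and out_indices is None:
        # derive the stage indices from the requested feature names
        out_indices = [stage_names.index(feature) for feature in out_features]
    elif out_features is None and out_indices is None:
        # default: output only the last stage
        out_features = [stage_names[-1]]
        out_indices = [len(stage_names) - 1]
    return out_features, out_indices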
...@@ -1006,32 +1006,6 @@ class ModuleUtilsMixin: ...@@ -1006,32 +1006,6 @@ class ModuleUtilsMixin:
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
class BackboneMixin:
@property
def out_feature_channels(self):
# the current backbones will output the number of channels for each stage
# even if that stage is not in the out_features list.
return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}
@property
def channels(self):
return [self.out_feature_channels[name] for name in self.out_features]
def forward_with_filtered_kwargs(self, *args, **kwargs):
signature = dict(inspect.signature(self.forward).parameters)
filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
return self(*args, **filtered_kwargs)
def forward(
self,
pixel_values: Tensor,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
raise NotImplementedError("This method should be implemented by the derived class.")
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin): class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
r""" r"""
Base class for all models. Base class for all models.
......
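
BackboneMixin is moved rather than deleted: every modeling file below now imports it from ...utils.backbone_utils. Alongside it, the configs gain BackboneConfigMixin, which (per the "use property setters" item in the commit message above) exposes the private _out_features/_out_indices attributes through properties so that reassigning either side re-derives the other. A hedged sketch of that pattern, building on the helper sketched earlier and assuming the setters simply re-run the alignment (the shipped version may differ in details):

class BackboneConfigMixin:
    # Sketch only: the configs below store aligned values in _out_features /
    # _out_indices in __init__; these properties keep the pair consistent
    # whenever one side is updated after construction.
    @property
    def out_features(self):
        return self._out_features

    @out_features.setter
    def out_features(self, out_features):
        # re-derive out_indices from the new feature names
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=None, stage_names=self.stage_names
        )

    @property
    def out_indices(self):
        return self._out_indices

    @out_indices.setter
    def out_indices(self, out_indices):
        # re-derive out_features from the new stage indices
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=None, out_indices=out_indices, stage_names=self.stage_names
        )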
src/transformers/models/bit/configuration_bit.py
@@ -16,6 +16,7 @@
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -25,7 +26,7 @@ BIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 }
 
 
-class BitConfig(PretrainedConfig):
+class BitConfig(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`BitModel`]. It is used to instantiate an BiT
     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
@@ -128,35 +129,6 @@ class BitConfig(PretrainedConfig):
         self.width_factor = width_factor
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
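
With the mixin in place, a config built from either argument stays self-consistent. A small usage sketch, assuming a transformers build that includes this PR (stage names follow the ["stem", "stage1", ...] convention built in __init__ above):

from transformers import BitConfig

config = BitConfig(out_indices=[2, 3])
print(config.out_features)  # derived from the indices: ["stage2", "stage3"]

config.out_features = ["stage1"]
print(config.out_indices)   # re-derived by the property setter: [1]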
src/transformers/models/bit/modeling_bit.py
@@ -31,7 +31,7 @@ from ...modeling_outputs import (
     BaseModelOutputWithPoolingAndNoAttention,
     ImageClassifierOutputWithNoAttention,
 )
-from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -39,6 +39,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
+from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
 from .configuration_bit import BitConfig
@@ -848,12 +849,10 @@ class BitBackbone(BitPreTrainedModel, BackboneMixin):
         self.stage_names = config.stage_names
         self.bit = BitModel(config)
 
-        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
         self.num_features = [config.embedding_size] + config.hidden_sizes
-        if config.out_indices is not None:
-            self.out_indices = config.out_indices
-        else:
-            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            config.out_features, config.out_indices, self.stage_names
+        )
 
         # initialize weights and apply final processing
         self.post_init()
src/transformers/models/convnext/configuration_convnext.py
@@ -22,6 +22,7 @@ from packaging import version
 from ...configuration_utils import PretrainedConfig
 from ...onnx import OnnxConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -32,7 +33,7 @@ CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 }
 
 
-class ConvNextConfig(PretrainedConfig):
+class ConvNextConfig(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`ConvNextModel`]. It is used to instantiate an
     ConvNeXT model according to the specified arguments, defining the model architecture. Instantiating a configuration
@@ -119,38 +120,9 @@ class ConvNextConfig(PretrainedConfig):
         self.drop_path_rate = drop_path_rate
         self.image_size = image_size
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
 
 
 class ConvNextOnnxConfig(OnnxConfig):
src/transformers/models/convnext/modeling_convnext.py
@@ -29,7 +29,7 @@ from ...modeling_outputs import (
     BaseModelOutputWithPoolingAndNoAttention,
     ImageClassifierOutputWithNoAttention,
 )
-from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -37,6 +37,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
+from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
 from .configuration_convnext import ConvNextConfig
@@ -485,16 +486,14 @@ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
         self.embeddings = ConvNextEmbeddings(config)
         self.encoder = ConvNextEncoder(config)
 
-        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
         self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
-        if config.out_indices is not None:
-            self.out_indices = config.out_indices
-        else:
-            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            config.out_features, config.out_indices, self.stage_names
+        )
 
         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
-        for stage, num_channels in zip(self.out_features, self.channels):
+        for stage, num_channels in zip(self._out_features, self.channels):
             hidden_states_norms[stage] = ConvNextLayerNorm(num_channels, data_format="channels_first")
         self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
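
The self.channels consumed in the loop above comes from the BackboneMixin properties shown in the removed block at the top of this diff. A worked toy example of that lookup, with assumed ConvNeXt-T-style sizes (the numbers are illustrative, not taken from the PR):

# Toy numbers (assumed); mirrors out_feature_channels / channels from BackboneMixin.
stage_names = ["stem", "stage1", "stage2", "stage3", "stage4"]
num_features = [96, 96, 192, 384, 768]  # [hidden_sizes[0]] + hidden_sizes
out_features = ["stage2", "stage4"]

# channels for every stage, even stages not in out_features
out_feature_channels = {stage: num_features[i] for i, stage in enumerate(stage_names)}
# channels only for the selected stages
channels = [out_feature_channels[name] for name in out_features]
print(channels)  # [192, 768] -> one ConvNextLayerNorm per selected stage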
src/transformers/models/convnextv2/configuration_convnextv2.py
@@ -17,6 +17,7 @@
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -26,7 +27,7 @@ CONVNEXTV2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 }
 
 
-class ConvNextV2Config(PretrainedConfig):
+class ConvNextV2Config(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`ConvNextV2Model`]. It is used to instantiate an
     ConvNeXTV2 model according to the specified arguments, defining the model architecture. Instantiating a
@@ -109,35 +110,6 @@ class ConvNextV2Config(PretrainedConfig):
         self.drop_path_rate = drop_path_rate
         self.image_size = image_size
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
src/transformers/models/convnextv2/modeling_convnextv2.py
@@ -29,7 +29,7 @@ from ...modeling_outputs import (
     BaseModelOutputWithPoolingAndNoAttention,
     ImageClassifierOutputWithNoAttention,
 )
-from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -37,6 +37,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
+from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
 from .configuration_convnextv2 import ConvNextV2Config
@@ -508,16 +509,14 @@ class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
         self.embeddings = ConvNextV2Embeddings(config)
         self.encoder = ConvNextV2Encoder(config)
 
-        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
         self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
-        if config.out_indices is not None:
-            self.out_indices = config.out_indices
-        else:
-            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            config.out_features, config.out_indices, self.stage_names
+        )
 
         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
-        for stage, num_channels in zip(self.out_features, self.channels):
+        for stage, num_channels in zip(self._out_features, self.channels):
             hidden_states_norms[stage] = ConvNextV2LayerNorm(num_channels, data_format="channels_first")
         self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
src/transformers/models/dinat/configuration_dinat.py
@@ -16,6 +16,7 @@
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -26,7 +27,7 @@ DINAT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 }
 
 
-class DinatConfig(PretrainedConfig):
+class DinatConfig(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`DinatModel`]. It is used to instantiate a Dinat
     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
@@ -145,35 +146,6 @@ class DinatConfig(PretrainedConfig):
         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
         self.layer_scale_init_value = layer_scale_init_value
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
...@@ -26,7 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss ...@@ -26,7 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput from ...modeling_outputs import BackboneOutput
from ...modeling_utils import BackboneMixin, PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ( from ...utils import (
ModelOutput, ModelOutput,
...@@ -39,6 +39,7 @@ from ...utils import ( ...@@ -39,6 +39,7 @@ from ...utils import (
replace_return_docstrings, replace_return_docstrings,
requires_backends, requires_backends,
) )
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from .configuration_dinat import DinatConfig from .configuration_dinat import DinatConfig
...@@ -890,16 +891,14 @@ class DinatBackbone(DinatPreTrainedModel, BackboneMixin): ...@@ -890,16 +891,14 @@ class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
self.embeddings = DinatEmbeddings(config) self.embeddings = DinatEmbeddings(config)
self.encoder = DinatEncoder(config) self.encoder = DinatEncoder(config)
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]] self._out_features, self._out_indices = get_aligned_output_features_output_indices(
if config.out_indices is not None: config.out_features, config.out_indices, self.stage_names
self.out_indices = config.out_indices )
else:
self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))] self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
# Add layer norms to hidden states of out_features # Add layer norms to hidden states of out_features
hidden_states_norms = {} hidden_states_norms = {}
for stage, num_channels in zip(self.out_features, self.channels): for stage, num_channels in zip(self._out_features, self.channels):
hidden_states_norms[stage] = nn.LayerNorm(num_channels) hidden_states_norms[stage] = nn.LayerNorm(num_channels)
self.hidden_states_norms = nn.ModuleDict(hidden_states_norms) self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
......
src/transformers/models/focalnet/configuration_focalnet.py
@@ -16,6 +16,7 @@
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -25,7 +26,7 @@ FOCALNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 }
 
 
-class FocalNetConfig(PretrainedConfig):
+class FocalNetConfig(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`FocalNetModel`]. It is used to instantiate a
     FocalNet model according to the specified arguments, defining the model architecture. Instantiating a configuration
@@ -156,35 +157,6 @@ class FocalNetConfig(PretrainedConfig):
         self.layer_norm_eps = layer_norm_eps
         self.encoder_stride = encoder_stride
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
...@@ -27,7 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss ...@@ -27,7 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput from ...modeling_outputs import BackboneOutput
from ...modeling_utils import BackboneMixin, PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import ( from ...utils import (
ModelOutput, ModelOutput,
add_code_sample_docstrings, add_code_sample_docstrings,
...@@ -36,6 +36,7 @@ from ...utils import ( ...@@ -36,6 +36,7 @@ from ...utils import (
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from .configuration_focalnet import FocalNetConfig from .configuration_focalnet import FocalNetConfig
...@@ -987,11 +988,9 @@ class FocalNetBackbone(FocalNetPreTrainedModel, BackboneMixin): ...@@ -987,11 +988,9 @@ class FocalNetBackbone(FocalNetPreTrainedModel, BackboneMixin):
self.focalnet = FocalNetModel(config) self.focalnet = FocalNetModel(config)
self.num_features = [config.embed_dim] + config.hidden_sizes self.num_features = [config.embed_dim] + config.hidden_sizes
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]] self._out_features, self._out_indices = get_aligned_output_features_output_indices(
if config.out_indices is not None: config.out_features, config.out_indices, self.stage_names
self.out_indices = config.out_indices )
else:
self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
# initialize weights and apply final processing # initialize weights and apply final processing
self.post_init() self.post_init()
......
...@@ -16,12 +16,13 @@ ...@@ -16,12 +16,13 @@
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
class MaskFormerSwinConfig(PretrainedConfig): class MaskFormerSwinConfig(BackboneConfigMixin, PretrainedConfig):
r""" r"""
This is the configuration class to store the configuration of a [`MaskFormerSwinModel`]. It is used to instantiate This is the configuration class to store the configuration of a [`MaskFormerSwinModel`]. It is used to instantiate
a Donut model according to the specified arguments, defining the model architecture. Instantiating a configuration a Donut model according to the specified arguments, defining the model architecture. Instantiating a configuration
...@@ -141,35 +142,6 @@ class MaskFormerSwinConfig(PretrainedConfig): ...@@ -141,35 +142,6 @@ class MaskFormerSwinConfig(PretrainedConfig):
# this indicates the channel dimension after the last stage of the model # this indicates the channel dimension after the last stage of the model
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
if out_features is not None and out_indices is not None: out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
if len(out_features) != len(out_indices): )
raise ValueError("out_features and out_indices should have the same length if both are set")
elif out_features != [self.stage_names[idx] for idx in out_indices]:
raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
if out_features is None and out_indices is not None:
out_features = [self.stage_names[idx] for idx in out_indices]
elif out_features is not None and out_indices is None:
out_indices = [self.stage_names.index(feature) for feature in out_features]
elif out_features is None and out_indices is None:
out_features = [self.stage_names[-1]]
out_indices = [len(self.stage_names) - 1]
if out_features is not None:
if not isinstance(out_features, list):
raise ValueError("out_features should be a list")
for feature in out_features:
if feature not in self.stage_names:
raise ValueError(
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
)
if out_indices is not None:
if not isinstance(out_indices, (list, tuple)):
raise ValueError("out_indices should be a list or tuple")
for idx in out_indices:
if idx >= len(self.stage_names):
raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
self.out_features = out_features
self.out_indices = out_indices
...@@ -27,8 +27,9 @@ from torch import Tensor, nn ...@@ -27,8 +27,9 @@ from torch import Tensor, nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...file_utils import ModelOutput from ...file_utils import ModelOutput
from ...modeling_outputs import BackboneOutput from ...modeling_outputs import BackboneOutput
from ...modeling_utils import BackboneMixin, PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from .configuration_maskformer_swin import MaskFormerSwinConfig from .configuration_maskformer_swin import MaskFormerSwinConfig
...@@ -855,14 +856,13 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin): ...@@ -855,14 +856,13 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
self.stage_names = config.stage_names self.stage_names = config.stage_names
self.model = MaskFormerSwinModel(config) self.model = MaskFormerSwinModel(config)
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]] self._out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
if "stem" in self.out_features: if "stem" in self.out_features:
raise ValueError("This backbone does not support 'stem' in the `out_features`.") raise ValueError("This backbone does not support 'stem' in the `out_features`.")
if config.out_indices is not None: self._out_features, self._out_indices = get_aligned_output_features_output_indices(
self.out_indices = config.out_indices config.out_features, config.out_indices, self.stage_names
else: )
self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))] self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
self.hidden_states_norms = nn.ModuleList( self.hidden_states_norms = nn.ModuleList(
[nn.LayerNorm(num_channels) for num_channels in self.num_features[1:]] [nn.LayerNorm(num_channels) for num_channels in self.num_features[1:]]
......
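
Note the one asymmetry with the other backbones: MaskFormerSwin rejects "stem" in out_features before calling the shared helper, so the first _out_features assignment above exists only to feed that check and is then overwritten by the aligned pair.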
src/transformers/models/nat/configuration_nat.py
@@ -16,6 +16,7 @@
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -26,7 +27,7 @@ NAT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 }
 
 
-class NatConfig(PretrainedConfig):
+class NatConfig(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`NatModel`]. It is used to instantiate a Nat model
     according to the specified arguments, defining the model architecture. Instantiating a configuration with the
@@ -141,35 +142,6 @@ class NatConfig(PretrainedConfig):
         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
         self.layer_scale_init_value = layer_scale_init_value
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
...@@ -26,7 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss ...@@ -26,7 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput from ...modeling_outputs import BackboneOutput
from ...modeling_utils import BackboneMixin, PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ( from ...utils import (
ModelOutput, ModelOutput,
...@@ -39,6 +39,7 @@ from ...utils import ( ...@@ -39,6 +39,7 @@ from ...utils import (
replace_return_docstrings, replace_return_docstrings,
requires_backends, requires_backends,
) )
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from .configuration_nat import NatConfig from .configuration_nat import NatConfig
...@@ -868,11 +869,9 @@ class NatBackbone(NatPreTrainedModel, BackboneMixin): ...@@ -868,11 +869,9 @@ class NatBackbone(NatPreTrainedModel, BackboneMixin):
self.embeddings = NatEmbeddings(config) self.embeddings = NatEmbeddings(config)
self.encoder = NatEncoder(config) self.encoder = NatEncoder(config)
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]] self._out_features, self._out_indices = get_aligned_output_features_output_indices(
if config.out_indices is not None: config.out_features, config.out_indices, self.stage_names
self.out_indices = config.out_indices )
else:
self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))] self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
# Add layer norms to hidden states of out_features # Add layer norms to hidden states of out_features
......
...@@ -22,6 +22,7 @@ from packaging import version ...@@ -22,6 +22,7 @@ from packaging import version
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -31,7 +32,7 @@ RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP = { ...@@ -31,7 +32,7 @@ RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
} }
class ResNetConfig(PretrainedConfig): class ResNetConfig(BackboneConfigMixin, PretrainedConfig):
r""" r"""
This is the configuration class to store the configuration of a [`ResNetModel`]. It is used to instantiate an This is the configuration class to store the configuration of a [`ResNetModel`]. It is used to instantiate an
ResNet model according to the specified arguments, defining the model architecture. Instantiating a configuration ResNet model according to the specified arguments, defining the model architecture. Instantiating a configuration
...@@ -108,38 +109,9 @@ class ResNetConfig(PretrainedConfig): ...@@ -108,38 +109,9 @@ class ResNetConfig(PretrainedConfig):
self.hidden_act = hidden_act self.hidden_act = hidden_act
self.downsample_in_first_stage = downsample_in_first_stage self.downsample_in_first_stage = downsample_in_first_stage
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
if out_features is not None and out_indices is not None: out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
if len(out_features) != len(out_indices): )
raise ValueError("out_features and out_indices should have the same length if both are set")
elif out_features != [self.stage_names[idx] for idx in out_indices]:
raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
if out_features is None and out_indices is not None:
out_features = [self.stage_names[idx] for idx in out_indices]
elif out_features is not None and out_indices is None:
out_indices = [self.stage_names.index(feature) for feature in out_features]
elif out_features is None and out_indices is None:
out_features = [self.stage_names[-1]]
out_indices = [len(self.stage_names) - 1]
if out_features is not None:
if not isinstance(out_features, list):
raise ValueError("out_features should be a list")
for feature in out_features:
if feature not in self.stage_names:
raise ValueError(
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
)
if out_indices is not None:
if not isinstance(out_indices, (list, tuple)):
raise ValueError("out_indices should be a list or tuple")
for idx in out_indices:
if idx >= len(self.stage_names):
raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
self.out_features = out_features
self.out_indices = out_indices
class ResNetOnnxConfig(OnnxConfig): class ResNetOnnxConfig(OnnxConfig):
......
...@@ -28,7 +28,7 @@ from ...modeling_outputs import ( ...@@ -28,7 +28,7 @@ from ...modeling_outputs import (
BaseModelOutputWithPoolingAndNoAttention, BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention, ImageClassifierOutputWithNoAttention,
) )
from ...modeling_utils import BackboneMixin, PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import ( from ...utils import (
add_code_sample_docstrings, add_code_sample_docstrings,
add_start_docstrings, add_start_docstrings,
...@@ -36,6 +36,7 @@ from ...utils import ( ...@@ -36,6 +36,7 @@ from ...utils import (
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from .configuration_resnet import ResNetConfig from .configuration_resnet import ResNetConfig
...@@ -436,11 +437,9 @@ class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin): ...@@ -436,11 +437,9 @@ class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):
self.embedder = ResNetEmbeddings(config) self.embedder = ResNetEmbeddings(config)
self.encoder = ResNetEncoder(config) self.encoder = ResNetEncoder(config)
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]] self._out_features, self._out_indices = get_aligned_output_features_output_indices(
if config.out_indices is not None: config.out_features, config.out_indices, self.stage_names
self.out_indices = config.out_indices )
else:
self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
self.num_features = [config.embedding_size] + config.hidden_sizes self.num_features = [config.embedding_size] + config.hidden_sizes
# initialize weights and apply final processing # initialize weights and apply final processing
......
...@@ -22,6 +22,7 @@ from packaging import version ...@@ -22,6 +22,7 @@ from packaging import version
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -34,7 +35,7 @@ SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP = { ...@@ -34,7 +35,7 @@ SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
} }
class SwinConfig(PretrainedConfig): class SwinConfig(BackboneConfigMixin, PretrainedConfig):
r""" r"""
This is the configuration class to store the configuration of a [`SwinModel`]. It is used to instantiate a Swin This is the configuration class to store the configuration of a [`SwinModel`]. It is used to instantiate a Swin
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
...@@ -158,38 +159,9 @@ class SwinConfig(PretrainedConfig): ...@@ -158,38 +159,9 @@ class SwinConfig(PretrainedConfig):
# this indicates the channel dimension after the last stage of the model # this indicates the channel dimension after the last stage of the model
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
if out_features is not None and out_indices is not None: out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
if len(out_features) != len(out_indices): )
raise ValueError("out_features and out_indices should have the same length if both are set")
elif out_features != [self.stage_names[idx] for idx in out_indices]:
raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
if out_features is None and out_indices is not None:
out_features = [self.stage_names[idx] for idx in out_indices]
elif out_features is not None and out_indices is None:
out_indices = [self.stage_names.index(feature) for feature in out_features]
elif out_features is None and out_indices is None:
out_features = [self.stage_names[-1]]
out_indices = [len(self.stage_names) - 1]
if out_features is not None:
if not isinstance(out_features, list):
raise ValueError("out_features should be a list")
for feature in out_features:
if feature not in self.stage_names:
raise ValueError(
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
)
if out_indices is not None:
if not isinstance(out_indices, (list, tuple)):
raise ValueError("out_indices should be a list or tuple")
for idx in out_indices:
if idx >= len(self.stage_names):
raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
self.out_features = out_features
self.out_indices = out_indices
class SwinOnnxConfig(OnnxConfig): class SwinOnnxConfig(OnnxConfig):
......
...@@ -28,7 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss ...@@ -28,7 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput from ...modeling_outputs import BackboneOutput
from ...modeling_utils import BackboneMixin, PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import ( from ...utils import (
ModelOutput, ModelOutput,
...@@ -38,6 +38,7 @@ from ...utils import ( ...@@ -38,6 +38,7 @@ from ...utils import (
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from .configuration_swin import SwinConfig from .configuration_swin import SwinConfig
...@@ -1264,16 +1265,14 @@ class SwinBackbone(SwinPreTrainedModel, BackboneMixin): ...@@ -1264,16 +1265,14 @@ class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
self.embeddings = SwinEmbeddings(config) self.embeddings = SwinEmbeddings(config)
self.encoder = SwinEncoder(config, self.embeddings.patch_grid) self.encoder = SwinEncoder(config, self.embeddings.patch_grid)
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]] self._out_features, self._out_indices = get_aligned_output_features_output_indices(
if config.out_indices is not None: config.out_features, config.out_indices, self.stage_names
self.out_indices = config.out_indices )
else:
self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))] self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
# Add layer norms to hidden states of out_features # Add layer norms to hidden states of out_features
hidden_states_norms = {} hidden_states_norms = {}
for stage, num_channels in zip(self.out_features, self.channels): for stage, num_channels in zip(self._out_features, self.channels):
hidden_states_norms[stage] = nn.LayerNorm(num_channels) hidden_states_norms[stage] = nn.LayerNorm(num_channels)
self.hidden_states_norms = nn.ModuleDict(hidden_states_norms) self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
......
...@@ -22,8 +22,9 @@ from torch.nn import CrossEntropyLoss ...@@ -22,8 +22,9 @@ from torch.nn import CrossEntropyLoss
from ... import AutoBackbone from ... import AutoBackbone
from ...modeling_outputs import SemanticSegmenterOutput from ...modeling_outputs import SemanticSegmenterOutput
from ...modeling_utils import BackboneMixin, PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...utils.backbone_utils import BackboneMixin
from .configuration_upernet import UperNetConfig from .configuration_upernet import UperNetConfig
......