Unverified Commit 6ef42587 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

[NAT, DiNAT] Add backbone class (#20654)



* Add first draft

* Add out_features attribute to config

* Add corresponding test

* Add Dinat backbone

* Add BackboneMixin

* Add Backbone mixin, improve tests

* Fix embeddings

* Fix bug

* Improve backbones

* Fix Nat backbone tests

* Fix Dinat backbone tests

* Apply suggestions
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 30d8919a
...@@ -1274,6 +1274,7 @@ else: ...@@ -1274,6 +1274,7 @@ else:
_import_structure["models.dinat"].extend( _import_structure["models.dinat"].extend(
[ [
"DINAT_PRETRAINED_MODEL_ARCHIVE_LIST", "DINAT_PRETRAINED_MODEL_ARCHIVE_LIST",
"DinatBackbone",
"DinatForImageClassification", "DinatForImageClassification",
"DinatModel", "DinatModel",
"DinatPreTrainedModel", "DinatPreTrainedModel",
...@@ -1769,6 +1770,7 @@ else: ...@@ -1769,6 +1770,7 @@ else:
_import_structure["models.nat"].extend( _import_structure["models.nat"].extend(
[ [
"NAT_PRETRAINED_MODEL_ARCHIVE_LIST", "NAT_PRETRAINED_MODEL_ARCHIVE_LIST",
"NatBackbone",
"NatForImageClassification", "NatForImageClassification",
"NatModel", "NatModel",
"NatPreTrainedModel", "NatPreTrainedModel",
...@@ -4388,6 +4390,7 @@ if TYPE_CHECKING: ...@@ -4388,6 +4390,7 @@ if TYPE_CHECKING:
) )
from .models.dinat import ( from .models.dinat import (
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST, DINAT_PRETRAINED_MODEL_ARCHIVE_LIST,
DinatBackbone,
DinatForImageClassification, DinatForImageClassification,
DinatModel, DinatModel,
DinatPreTrainedModel, DinatPreTrainedModel,
...@@ -4784,6 +4787,7 @@ if TYPE_CHECKING: ...@@ -4784,6 +4787,7 @@ if TYPE_CHECKING:
) )
from .models.nat import ( from .models.nat import (
NAT_PRETRAINED_MODEL_ARCHIVE_LIST, NAT_PRETRAINED_MODEL_ARCHIVE_LIST,
NatBackbone,
NatForImageClassification, NatForImageClassification,
NatModel, NatModel,
NatPreTrainedModel, NatPreTrainedModel,
......
...@@ -865,7 +865,9 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict( ...@@ -865,7 +865,9 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
[ [
# Backbone mapping # Backbone mapping
("bit", "BitBackbone"), ("bit", "BitBackbone"),
("dinat", "DinatBackbone"),
("maskformer-swin", "MaskFormerSwinBackbone"), ("maskformer-swin", "MaskFormerSwinBackbone"),
("nat", "NatBackbone"),
("resnet", "ResNetBackbone"), ("resnet", "ResNetBackbone"),
] ]
) )
......
...@@ -35,6 +35,7 @@ else: ...@@ -35,6 +35,7 @@ else:
"DinatForImageClassification", "DinatForImageClassification",
"DinatModel", "DinatModel",
"DinatPreTrainedModel", "DinatPreTrainedModel",
"DinatBackbone",
] ]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -48,6 +49,7 @@ if TYPE_CHECKING: ...@@ -48,6 +49,7 @@ if TYPE_CHECKING:
else: else:
from .modeling_dinat import ( from .modeling_dinat import (
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST, DINAT_PRETRAINED_MODEL_ARCHIVE_LIST,
DinatBackbone,
DinatForImageClassification, DinatForImageClassification,
DinatModel, DinatModel,
DinatPreTrainedModel, DinatPreTrainedModel,
......
...@@ -72,6 +72,9 @@ class DinatConfig(PretrainedConfig): ...@@ -72,6 +72,9 @@ class DinatConfig(PretrainedConfig):
The epsilon used by the layer normalization layers. The epsilon used by the layer normalization layers.
layer_scale_init_value (`float`, *optional*, defaults to 0.0): layer_scale_init_value (`float`, *optional*, defaults to 0.0):
The initial value for the layer scale. Disabled if <=0. The initial value for the layer scale. Disabled if <=0.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). Will default to the last stage if unset.
Example: Example:
...@@ -113,6 +116,7 @@ class DinatConfig(PretrainedConfig): ...@@ -113,6 +116,7 @@ class DinatConfig(PretrainedConfig):
initializer_range=0.02, initializer_range=0.02,
layer_norm_eps=1e-5, layer_norm_eps=1e-5,
layer_scale_init_value=0.0, layer_scale_init_value=0.0,
out_features=None,
**kwargs **kwargs
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -138,3 +142,13 @@ class DinatConfig(PretrainedConfig): ...@@ -138,3 +142,13 @@ class DinatConfig(PretrainedConfig):
# this indicates the channel dimension after the last stage of the model # this indicates the channel dimension after the last stage of the model
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
self.layer_scale_init_value = layer_scale_init_value self.layer_scale_init_value = layer_scale_init_value
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
if out_features is not None:
if not isinstance(out_features, list):
raise ValueError("out_features should be a list")
for feature in out_features:
if feature not in self.stage_names:
raise ValueError(
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
)
self.out_features = out_features
...@@ -25,7 +25,8 @@ from torch import nn ...@@ -25,7 +25,8 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel from ...modeling_outputs import BackboneOutput
from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ( from ...utils import (
ModelOutput, ModelOutput,
...@@ -35,6 +36,7 @@ from ...utils import ( ...@@ -35,6 +36,7 @@ from ...utils import (
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_natten_available, is_natten_available,
logging, logging,
replace_return_docstrings,
requires_backends, requires_backends,
) )
from .configuration_dinat import DinatConfig from .configuration_dinat import DinatConfig
...@@ -555,14 +557,11 @@ class DinatStage(nn.Module): ...@@ -555,14 +557,11 @@ class DinatStage(nn.Module):
layer_outputs = layer_module(hidden_states, output_attentions) layer_outputs = layer_module(hidden_states, output_attentions)
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
hidden_states_before_downsampling = hidden_states
if self.downsample is not None: if self.downsample is not None:
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 hidden_states = self.downsample(hidden_states_before_downsampling)
output_dimensions = (height, width, height_downsampled, width_downsampled)
hidden_states = self.downsample(layer_outputs[0])
else:
output_dimensions = (height, width, height, width)
stage_outputs = (hidden_states, output_dimensions) stage_outputs = (hidden_states, hidden_states_before_downsampling)
if output_attentions: if output_attentions:
stage_outputs += layer_outputs[1:] stage_outputs += layer_outputs[1:]
...@@ -596,6 +595,7 @@ class DinatEncoder(nn.Module): ...@@ -596,6 +595,7 @@ class DinatEncoder(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False, output_hidden_states: Optional[bool] = False,
output_hidden_states_before_downsampling: Optional[bool] = False,
return_dict: Optional[bool] = True, return_dict: Optional[bool] = True,
) -> Union[Tuple, DinatEncoderOutput]: ) -> Union[Tuple, DinatEncoderOutput]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
...@@ -612,8 +612,14 @@ class DinatEncoder(nn.Module): ...@@ -612,8 +612,14 @@ class DinatEncoder(nn.Module):
layer_outputs = layer_module(hidden_states, output_attentions) layer_outputs = layer_module(hidden_states, output_attentions)
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
hidden_states_before_downsampling = layer_outputs[1]
if output_hidden_states: if output_hidden_states and output_hidden_states_before_downsampling:
# rearrange b h w c -> b c h w
reshaped_hidden_state = hidden_states_before_downsampling.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states_before_downsampling,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
elif output_hidden_states and not output_hidden_states_before_downsampling:
# rearrange b h w c -> b c h w # rearrange b h w c -> b c h w
reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2) reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
...@@ -871,3 +877,120 @@ class DinatForImageClassification(DinatPreTrainedModel): ...@@ -871,3 +877,120 @@ class DinatForImageClassification(DinatPreTrainedModel):
attentions=outputs.attentions, attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states, reshaped_hidden_states=outputs.reshaped_hidden_states,
) )
@add_start_docstrings(
    # Fixed: this is the DiNAT backbone, not NAT (copy-paste from modeling_nat.py).
    "DiNAT backbone, to be used with frameworks like DETR and MaskFormer.",
    DINAT_START_DOCSTRING,
)
class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
    """DiNAT model exposing per-stage feature maps for downstream vision frameworks."""

    def __init__(self, config):
        super().__init__(config)
        # DiNAT relies on the `natten` package for neighborhood attention kernels.
        requires_backends(self, ["natten"])

        self.stage_names = config.stage_names

        self.embeddings = DinatEmbeddings(config)
        self.encoder = DinatEncoder(config)

        # Default to the deepest stage when the config does not request specific features.
        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]

        # Channel dimension doubles at every stage; the "stem" carries the raw embedding dim.
        num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
        self.out_feature_channels = {}
        self.out_feature_channels["stem"] = config.embed_dim
        for i, stage in enumerate(self.stage_names[1:]):
            self.out_feature_channels[stage] = num_features[i]

        # Add layer norms to hidden states of out_features
        hidden_states_norms = {}
        for stage, num_channels in zip(self.out_features, self.channels):
            hidden_states_norms[stage] = nn.LayerNorm(num_channels)
        self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the patch embedding module that maps pixel values to tokens."""
        return self.embeddings.patch_embeddings

    @property
    def channels(self):
        """Channel counts of the requested output features, in `out_features` order."""
        return [self.out_feature_channels[name] for name in self.out_features]

    @add_start_docstrings_to_model_forward(DINAT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("shi-labs/dinat-mini-in1k-224")
        >>> model = AutoBackbone.from_pretrained(
        ...     "shi-labs/dinat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)

        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 512, 7, 7]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        embedding_output = self.embeddings(pixel_values)

        # Always request pre-downsampling hidden states: feature maps must keep the
        # spatial resolution of each stage, before the patch-merging halves it.
        outputs = self.encoder(
            embedding_output,
            output_attentions=output_attentions,
            output_hidden_states=True,
            output_hidden_states_before_downsampling=True,
            return_dict=True,
        )

        hidden_states = outputs.reshaped_hidden_states

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                # Flatten (B, C, H, W) -> (B, H*W, C) so LayerNorm normalizes over channels,
                # then restore the channels-first feature-map layout.
                batch_size, num_channels, height, width = hidden_state.shape
                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
                hidden_state = self.hidden_states_norms[stage](hidden_state)
                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                feature_maps += (hidden_state,)

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output += (outputs.hidden_states,)
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
...@@ -35,6 +35,7 @@ else: ...@@ -35,6 +35,7 @@ else:
"NatForImageClassification", "NatForImageClassification",
"NatModel", "NatModel",
"NatPreTrainedModel", "NatPreTrainedModel",
"NatBackbone",
] ]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -48,6 +49,7 @@ if TYPE_CHECKING: ...@@ -48,6 +49,7 @@ if TYPE_CHECKING:
else: else:
from .modeling_nat import ( from .modeling_nat import (
NAT_PRETRAINED_MODEL_ARCHIVE_LIST, NAT_PRETRAINED_MODEL_ARCHIVE_LIST,
NatBackbone,
NatForImageClassification, NatForImageClassification,
NatModel, NatModel,
NatPreTrainedModel, NatPreTrainedModel,
......
...@@ -70,6 +70,9 @@ class NatConfig(PretrainedConfig): ...@@ -70,6 +70,9 @@ class NatConfig(PretrainedConfig):
The epsilon used by the layer normalization layers. The epsilon used by the layer normalization layers.
layer_scale_init_value (`float`, *optional*, defaults to 0.0): layer_scale_init_value (`float`, *optional*, defaults to 0.0):
The initial value for the layer scale. Disabled if <=0. The initial value for the layer scale. Disabled if <=0.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). Will default to the last stage if unset.
Example: Example:
...@@ -110,6 +113,7 @@ class NatConfig(PretrainedConfig): ...@@ -110,6 +113,7 @@ class NatConfig(PretrainedConfig):
initializer_range=0.02, initializer_range=0.02,
layer_norm_eps=1e-5, layer_norm_eps=1e-5,
layer_scale_init_value=0.0, layer_scale_init_value=0.0,
out_features=None,
**kwargs **kwargs
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -134,3 +138,13 @@ class NatConfig(PretrainedConfig): ...@@ -134,3 +138,13 @@ class NatConfig(PretrainedConfig):
# this indicates the channel dimension after the last stage of the model # this indicates the channel dimension after the last stage of the model
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
self.layer_scale_init_value = layer_scale_init_value self.layer_scale_init_value = layer_scale_init_value
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
if out_features is not None:
if not isinstance(out_features, list):
raise ValueError("out_features should be a list")
for feature in out_features:
if feature not in self.stage_names:
raise ValueError(
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
)
self.out_features = out_features
...@@ -25,7 +25,8 @@ from torch import nn ...@@ -25,7 +25,8 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel from ...modeling_outputs import BackboneOutput
from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ( from ...utils import (
ModelOutput, ModelOutput,
...@@ -35,6 +36,7 @@ from ...utils import ( ...@@ -35,6 +36,7 @@ from ...utils import (
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_natten_available, is_natten_available,
logging, logging,
replace_return_docstrings,
requires_backends, requires_backends,
) )
from .configuration_nat import NatConfig from .configuration_nat import NatConfig
...@@ -536,14 +538,11 @@ class NatStage(nn.Module): ...@@ -536,14 +538,11 @@ class NatStage(nn.Module):
layer_outputs = layer_module(hidden_states, output_attentions) layer_outputs = layer_module(hidden_states, output_attentions)
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
hidden_states_before_downsampling = hidden_states
if self.downsample is not None: if self.downsample is not None:
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 hidden_states = self.downsample(hidden_states_before_downsampling)
output_dimensions = (height, width, height_downsampled, width_downsampled)
hidden_states = self.downsample(layer_outputs[0])
else:
output_dimensions = (height, width, height, width)
stage_outputs = (hidden_states, output_dimensions) stage_outputs = (hidden_states, hidden_states_before_downsampling)
if output_attentions: if output_attentions:
stage_outputs += layer_outputs[1:] stage_outputs += layer_outputs[1:]
...@@ -575,6 +574,7 @@ class NatEncoder(nn.Module): ...@@ -575,6 +574,7 @@ class NatEncoder(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False, output_hidden_states: Optional[bool] = False,
output_hidden_states_before_downsampling: Optional[bool] = False,
return_dict: Optional[bool] = True, return_dict: Optional[bool] = True,
) -> Union[Tuple, NatEncoderOutput]: ) -> Union[Tuple, NatEncoderOutput]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
...@@ -591,8 +591,14 @@ class NatEncoder(nn.Module): ...@@ -591,8 +591,14 @@ class NatEncoder(nn.Module):
layer_outputs = layer_module(hidden_states, output_attentions) layer_outputs = layer_module(hidden_states, output_attentions)
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
hidden_states_before_downsampling = layer_outputs[1]
if output_hidden_states: if output_hidden_states and output_hidden_states_before_downsampling:
# rearrange b h w c -> b c h w
reshaped_hidden_state = hidden_states_before_downsampling.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states_before_downsampling,)
all_reshaped_hidden_states += (reshaped_hidden_state,)
elif output_hidden_states and not output_hidden_states_before_downsampling:
# rearrange b h w c -> b c h w # rearrange b h w c -> b c h w
reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2) reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
...@@ -849,3 +855,121 @@ class NatForImageClassification(NatPreTrainedModel): ...@@ -849,3 +855,121 @@ class NatForImageClassification(NatPreTrainedModel):
attentions=outputs.attentions, attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states, reshaped_hidden_states=outputs.reshaped_hidden_states,
) )
@add_start_docstrings(
    "NAT backbone, to be used with frameworks like DETR and MaskFormer.",
    NAT_START_DOCSTRING,
)
class NatBackbone(NatPreTrainedModel, BackboneMixin):
    """NAT model exposing per-stage feature maps for downstream vision frameworks."""

    def __init__(self, config):
        super().__init__(config)
        # NAT relies on the `natten` package for neighborhood attention kernels.
        requires_backends(self, ["natten"])

        self.stage_names = config.stage_names

        self.embeddings = NatEmbeddings(config)
        self.encoder = NatEncoder(config)

        # Default to the deepest stage when the config does not request specific features.
        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]

        # Channel dimension doubles at every stage; the "stem" carries the raw embedding dim.
        num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
        self.out_feature_channels = {}
        self.out_feature_channels["stem"] = config.embed_dim
        for i, stage in enumerate(self.stage_names[1:]):
            self.out_feature_channels[stage] = num_features[i]

        # Add layer norms to hidden states of out_features
        hidden_states_norms = {}
        for stage, num_channels in zip(self.out_features, self.channels):
            hidden_states_norms[stage] = nn.LayerNorm(num_channels)
        self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the patch embedding module that maps pixel values to tokens."""
        return self.embeddings.patch_embeddings

    @property
    def channels(self):
        """Channel counts of the requested output features, in `out_features` order."""
        return [self.out_feature_channels[name] for name in self.out_features]

    @add_start_docstrings_to_model_forward(NAT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
        >>> model = AutoBackbone.from_pretrained(
        ...     "shi-labs/nat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)

        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 512, 7, 7]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        embedding_output = self.embeddings(pixel_values)

        # Always request pre-downsampling hidden states: feature maps must keep the
        # spatial resolution of each stage, before the patch-merging halves it.
        outputs = self.encoder(
            embedding_output,
            output_attentions=output_attentions,
            output_hidden_states=True,
            output_hidden_states_before_downsampling=True,
            return_dict=True,
        )

        hidden_states = outputs.reshaped_hidden_states

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                # TODO can we simplify this?
                # Flatten (B, C, H, W) -> (B, H*W, C) so LayerNorm normalizes over channels,
                # then restore the channels-first feature-map layout.
                batch_size, num_channels, height, width = hidden_state.shape
                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
                hidden_state = self.hidden_states_norms[stage](hidden_state)
                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                feature_maps += (hidden_state,)

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output += (outputs.hidden_states,)
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
...@@ -1856,6 +1856,13 @@ class DeiTPreTrainedModel(metaclass=DummyObject): ...@@ -1856,6 +1856,13 @@ class DeiTPreTrainedModel(metaclass=DummyObject):
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST = None DINAT_PRETRAINED_MODEL_ARCHIVE_LIST = None
# Auto-generated placeholder: instantiating it calls requires_backends, which
# raises when PyTorch is not installed.
class DinatBackbone(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
class DinatForImageClassification(metaclass=DummyObject): class DinatForImageClassification(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
...@@ -3913,6 +3920,13 @@ class MvpPreTrainedModel(metaclass=DummyObject): ...@@ -3913,6 +3920,13 @@ class MvpPreTrainedModel(metaclass=DummyObject):
NAT_PRETRAINED_MODEL_ARCHIVE_LIST = None NAT_PRETRAINED_MODEL_ARCHIVE_LIST = None
# Auto-generated placeholder: instantiating it calls requires_backends, which
# raises when PyTorch is not installed.
class NatBackbone(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
class NatForImageClassification(metaclass=DummyObject): class NatForImageClassification(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
...@@ -30,7 +30,7 @@ if is_torch_available(): ...@@ -30,7 +30,7 @@ if is_torch_available():
import torch import torch
from torch import nn from torch import nn
from transformers import DinatForImageClassification, DinatModel from transformers import DinatBackbone, DinatForImageClassification, DinatModel
from transformers.models.dinat.modeling_dinat import DINAT_PRETRAINED_MODEL_ARCHIVE_LIST from transformers.models.dinat.modeling_dinat import DINAT_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available(): if is_vision_available():
...@@ -64,8 +64,8 @@ class DinatModelTester: ...@@ -64,8 +64,8 @@ class DinatModelTester:
is_training=True, is_training=True,
scope=None, scope=None,
use_labels=True, use_labels=True,
type_sequence_label_size=10, num_labels=10,
encoder_stride=8, out_features=["stage1", "stage2"],
): ):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
...@@ -89,15 +89,15 @@ class DinatModelTester: ...@@ -89,15 +89,15 @@ class DinatModelTester:
self.is_training = is_training self.is_training = is_training
self.scope = scope self.scope = scope
self.use_labels = use_labels self.use_labels = use_labels
self.type_sequence_label_size = type_sequence_label_size self.num_labels = num_labels
self.encoder_stride = encoder_stride self.out_features = out_features
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None labels = None
if self.use_labels: if self.use_labels:
labels = ids_tensor([self.batch_size], self.type_sequence_label_size) labels = ids_tensor([self.batch_size], self.num_labels)
config = self.get_config() config = self.get_config()
...@@ -105,6 +105,7 @@ class DinatModelTester: ...@@ -105,6 +105,7 @@ class DinatModelTester:
def get_config(self): def get_config(self):
return DinatConfig( return DinatConfig(
num_labels=self.num_labels,
image_size=self.image_size, image_size=self.image_size,
patch_size=self.patch_size, patch_size=self.patch_size,
num_channels=self.num_channels, num_channels=self.num_channels,
...@@ -122,7 +123,7 @@ class DinatModelTester: ...@@ -122,7 +123,7 @@ class DinatModelTester:
patch_norm=self.patch_norm, patch_norm=self.patch_norm,
layer_norm_eps=self.layer_norm_eps, layer_norm_eps=self.layer_norm_eps,
initializer_range=self.initializer_range, initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride, out_features=self.out_features,
) )
def create_and_check_model(self, config, pixel_values, labels): def create_and_check_model(self, config, pixel_values, labels):
...@@ -139,12 +140,11 @@ class DinatModelTester: ...@@ -139,12 +140,11 @@ class DinatModelTester:
) )
def create_and_check_for_image_classification(self, config, pixel_values, labels): def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = DinatForImageClassification(config) model = DinatForImageClassification(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
result = model(pixel_values, labels=labels) result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
# test greyscale images # test greyscale images
config.num_channels = 1 config.num_channels = 1
...@@ -154,7 +154,34 @@ class DinatModelTester: ...@@ -154,7 +154,34 @@ class DinatModelTester:
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values) result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_backbone(self, config, pixel_values, labels):
    """Instantiate DinatBackbone and check feature maps and channels, both with
    the tester's `out_features` and with the default (`out_features=None`)."""
    model = DinatBackbone(config=config)
    model.to(torch_device)
    model.eval()
    result = model(pixel_values)

    # verify hidden states
    self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
    # NOTE(review): the 16x16 spatial size follows from the tester's image/patch
    # configuration — confirm against the fixture values if they change.
    self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], 16, 16])

    # verify channels
    self.parent.assertEqual(len(model.channels), len(config.out_features))

    # verify backbone works with out_features=None
    config.out_features = None
    model = DinatBackbone(config=config)
    model.to(torch_device)
    model.eval()
    result = model(pixel_values)

    # verify feature maps: defaults to the last stage only
    self.parent.assertEqual(len(result.feature_maps), 1)
    self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[-1], 4, 4])

    # verify channels
    self.parent.assertEqual(len(model.channels), 1)
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
...@@ -167,7 +194,15 @@ class DinatModelTester: ...@@ -167,7 +194,15 @@ class DinatModelTester:
@require_torch @require_torch
class DinatModelTest(ModelTesterMixin, unittest.TestCase): class DinatModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (DinatModel, DinatForImageClassification) if is_torch_available() else () all_model_classes = (
(
DinatModel,
DinatForImageClassification,
DinatBackbone,
)
if is_torch_available()
else ()
)
fx_compatible = False fx_compatible = False
test_torchscript = False test_torchscript = False
...@@ -199,8 +234,16 @@ class DinatModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -199,8 +234,16 @@ class DinatModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs) self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
def test_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_backbone(*config_and_inputs)
@unittest.skip(reason="Dinat does not use inputs_embeds")
def test_inputs_embeds(self): def test_inputs_embeds(self):
# Dinat does not use inputs_embeds pass
@unittest.skip(reason="Dinat does not use feedforward chunking")
def test_feed_forward_chunking(self):
pass pass
def test_model_common_attributes(self): def test_model_common_attributes(self):
...@@ -257,17 +300,18 @@ class DinatModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -257,17 +300,18 @@ class DinatModelTest(ModelTesterMixin, unittest.TestCase):
[height, width, self.model_tester.embed_dim], [height, width, self.model_tester.embed_dim],
) )
reshaped_hidden_states = outputs.reshaped_hidden_states if model_class.__name__ != "DinatBackbone":
self.assertEqual(len(reshaped_hidden_states), expected_num_layers) reshaped_hidden_states = outputs.reshaped_hidden_states
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
reshaped_hidden_states = ( batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
reshaped_hidden_states[0].view(batch_size, num_channels, height, width).permute(0, 2, 3, 1) reshaped_hidden_states = (
) reshaped_hidden_states[0].view(batch_size, num_channels, height, width).permute(0, 2, 3, 1)
self.assertListEqual( )
list(reshaped_hidden_states.shape[-3:]), self.assertListEqual(
[height, width, self.model_tester.embed_dim], list(reshaped_hidden_states.shape[-3:]),
) [height, width, self.model_tester.embed_dim],
)
def test_hidden_states_output(self): def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
...@@ -30,7 +30,7 @@ if is_torch_available(): ...@@ -30,7 +30,7 @@ if is_torch_available():
import torch import torch
from torch import nn from torch import nn
from transformers import NatForImageClassification, NatModel from transformers import NatBackbone, NatForImageClassification, NatModel
from transformers.models.nat.modeling_nat import NAT_PRETRAINED_MODEL_ARCHIVE_LIST from transformers.models.nat.modeling_nat import NAT_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available(): if is_vision_available():
...@@ -63,8 +63,8 @@ class NatModelTester: ...@@ -63,8 +63,8 @@ class NatModelTester:
is_training=True, is_training=True,
scope=None, scope=None,
use_labels=True, use_labels=True,
type_sequence_label_size=10, num_labels=10,
encoder_stride=8, out_features=["stage1", "stage2"],
): ):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
...@@ -87,15 +87,15 @@ class NatModelTester: ...@@ -87,15 +87,15 @@ class NatModelTester:
self.is_training = is_training self.is_training = is_training
self.scope = scope self.scope = scope
self.use_labels = use_labels self.use_labels = use_labels
self.type_sequence_label_size = type_sequence_label_size self.num_labels = num_labels
self.encoder_stride = encoder_stride self.out_features = out_features
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None labels = None
if self.use_labels: if self.use_labels:
labels = ids_tensor([self.batch_size], self.type_sequence_label_size) labels = ids_tensor([self.batch_size], self.num_labels)
config = self.get_config() config = self.get_config()
...@@ -103,6 +103,7 @@ class NatModelTester: ...@@ -103,6 +103,7 @@ class NatModelTester:
def get_config(self): def get_config(self):
return NatConfig( return NatConfig(
num_labels=self.num_labels,
image_size=self.image_size, image_size=self.image_size,
patch_size=self.patch_size, patch_size=self.patch_size,
num_channels=self.num_channels, num_channels=self.num_channels,
...@@ -119,7 +120,7 @@ class NatModelTester: ...@@ -119,7 +120,7 @@ class NatModelTester:
patch_norm=self.patch_norm, patch_norm=self.patch_norm,
layer_norm_eps=self.layer_norm_eps, layer_norm_eps=self.layer_norm_eps,
initializer_range=self.initializer_range, initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride, out_features=self.out_features,
) )
def create_and_check_model(self, config, pixel_values, labels): def create_and_check_model(self, config, pixel_values, labels):
...@@ -136,12 +137,11 @@ class NatModelTester: ...@@ -136,12 +137,11 @@ class NatModelTester:
) )
def create_and_check_for_image_classification(self, config, pixel_values, labels): def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = NatForImageClassification(config) model = NatForImageClassification(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
result = model(pixel_values, labels=labels) result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
# test greyscale images # test greyscale images
config.num_channels = 1 config.num_channels = 1
...@@ -151,7 +151,34 @@ class NatModelTester: ...@@ -151,7 +151,34 @@ class NatModelTester:
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values) result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_backbone(self, config, pixel_values, labels):
model = NatBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify hidden states
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], 16, 16])
# verify channels
self.parent.assertEqual(len(model.channels), len(config.out_features))
# verify backbone works with out_features=None
config.out_features = None
model = NatBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), 1)
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[-1], 4, 4])
# verify channels
self.parent.assertEqual(len(model.channels), 1)
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
...@@ -164,7 +191,15 @@ class NatModelTester: ...@@ -164,7 +191,15 @@ class NatModelTester:
@require_torch @require_torch
class NatModelTest(ModelTesterMixin, unittest.TestCase): class NatModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (NatModel, NatForImageClassification) if is_torch_available() else () all_model_classes = (
(
NatModel,
NatForImageClassification,
NatBackbone,
)
if is_torch_available()
else ()
)
fx_compatible = False fx_compatible = False
test_torchscript = False test_torchscript = False
...@@ -196,8 +231,16 @@ class NatModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -196,8 +231,16 @@ class NatModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs) self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
def test_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_backbone(*config_and_inputs)
@unittest.skip(reason="Nat does not use inputs_embeds")
def test_inputs_embeds(self): def test_inputs_embeds(self):
# Nat does not use inputs_embeds pass
@unittest.skip(reason="Nat does not use feedforward chunking")
def test_feed_forward_chunking(self):
pass pass
def test_model_common_attributes(self): def test_model_common_attributes(self):
...@@ -254,17 +297,18 @@ class NatModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -254,17 +297,18 @@ class NatModelTest(ModelTesterMixin, unittest.TestCase):
[height, width, self.model_tester.embed_dim], [height, width, self.model_tester.embed_dim],
) )
reshaped_hidden_states = outputs.reshaped_hidden_states if model_class.__name__ != "NatBackbone":
self.assertEqual(len(reshaped_hidden_states), expected_num_layers) reshaped_hidden_states = outputs.reshaped_hidden_states
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
reshaped_hidden_states = ( batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
reshaped_hidden_states[0].view(batch_size, num_channels, height, width).permute(0, 2, 3, 1) reshaped_hidden_states = (
) reshaped_hidden_states[0].view(batch_size, num_channels, height, width).permute(0, 2, 3, 1)
self.assertListEqual( )
list(reshaped_hidden_states.shape[-3:]), self.assertListEqual(
[height, width, self.model_tester.embed_dim], list(reshaped_hidden_states.shape[-3:]),
) [height, width, self.model_tester.embed_dim],
)
def test_hidden_states_output(self): def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
...@@ -677,6 +677,8 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [ ...@@ -677,6 +677,8 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
"MaskFormerSwinBackbone", "MaskFormerSwinBackbone",
"ResNetBackbone", "ResNetBackbone",
"AutoBackbone", "AutoBackbone",
"DinatBackbone",
"NatBackbone",
"MaskFormerSwinConfig", "MaskFormerSwinConfig",
"MaskFormerSwinModel", "MaskFormerSwinModel",
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment