Unverified commit 6b217c52, authored by NielsRogge, committed by GitHub
Browse files

Add AutoBackbone + ResNetBackbone (#20229)



* Add ResNetBackbone

* Define channels and strides as property

* Remove file

* Add test for backbone

* Update BackboneOutput class

* Remove strides property

* Fix docstring

* Add backbones to SHOULD_HAVE_THEIR_OWN_PAGE

* Fix auto mapping name

* Add sanity check for out_features

* Set stage names based on depths

* Update to tuple
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 904ac210
...@@ -920,6 +920,7 @@ else: ...@@ -920,6 +920,7 @@ else:
"MODEL_WITH_LM_HEAD_MAPPING", "MODEL_WITH_LM_HEAD_MAPPING",
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING", "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
"AutoModel", "AutoModel",
"AutoBackbone",
"AutoModelForAudioClassification", "AutoModelForAudioClassification",
"AutoModelForAudioFrameClassification", "AutoModelForAudioFrameClassification",
"AutoModelForAudioXVector", "AutoModelForAudioXVector",
...@@ -1877,6 +1878,7 @@ else: ...@@ -1877,6 +1878,7 @@ else:
"ResNetForImageClassification", "ResNetForImageClassification",
"ResNetModel", "ResNetModel",
"ResNetPreTrainedModel", "ResNetPreTrainedModel",
"ResNetBackbone",
] ]
) )
_import_structure["models.retribert"].extend( _import_structure["models.retribert"].extend(
...@@ -3946,6 +3948,7 @@ if TYPE_CHECKING: ...@@ -3946,6 +3948,7 @@ if TYPE_CHECKING:
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
MODEL_MAPPING, MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING, MODEL_WITH_LM_HEAD_MAPPING,
AutoBackbone,
AutoModel, AutoModel,
AutoModelForAudioClassification, AutoModelForAudioClassification,
AutoModelForAudioFrameClassification, AutoModelForAudioFrameClassification,
...@@ -4730,6 +4733,7 @@ if TYPE_CHECKING: ...@@ -4730,6 +4733,7 @@ if TYPE_CHECKING:
) )
from .models.resnet import ( from .models.resnet import (
RESNET_PRETRAINED_MODEL_ARCHIVE_LIST, RESNET_PRETRAINED_MODEL_ARCHIVE_LIST,
ResNetBackbone,
ResNetForImageClassification, ResNetForImageClassification,
ResNetModel, ResNetModel,
ResNetPreTrainedModel, ResNetPreTrainedModel,
......
...@@ -1263,3 +1263,16 @@ class XVectorOutput(ModelOutput): ...@@ -1263,3 +1263,16 @@ class XVectorOutput(ModelOutput):
embeddings: torch.FloatTensor = None embeddings: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class BackboneOutput(ModelOutput):
    """
    Base class for outputs of backbones.

    Args:
        feature_maps (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`):
            Feature maps of the stages.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden states of the model at the output of each stage, when the backbone returns them.
        attentions (`tuple(torch.FloatTensor)`, *optional*):
            Attention weights, when the backbone returns them.
    """

    feature_maps: Tuple[torch.FloatTensor] = None
    # Optional fields kept for consistency with the other `ModelOutput` subclasses in this
    # module (e.g. `XVectorOutput` above); `ModelOutput` drops `None` entries from `to_tuple`,
    # so existing callers that only set `feature_maps` are unaffected.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
...@@ -73,6 +73,7 @@ else: ...@@ -73,6 +73,7 @@ else:
"MODEL_WITH_LM_HEAD_MAPPING", "MODEL_WITH_LM_HEAD_MAPPING",
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING", "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING",
"AutoModel", "AutoModel",
"AutoBackbone",
"AutoModelForAudioClassification", "AutoModelForAudioClassification",
"AutoModelForAudioFrameClassification", "AutoModelForAudioFrameClassification",
"AutoModelForAudioXVector", "AutoModelForAudioXVector",
...@@ -225,6 +226,7 @@ if TYPE_CHECKING: ...@@ -225,6 +226,7 @@ if TYPE_CHECKING:
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
MODEL_MAPPING, MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING, MODEL_WITH_LM_HEAD_MAPPING,
AutoBackbone,
AutoModel, AutoModel,
AutoModelForAudioClassification, AutoModelForAudioClassification,
AutoModelForAudioFrameClassification, AutoModelForAudioFrameClassification,
......
...@@ -836,6 +836,13 @@ _MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ...@@ -836,6 +836,13 @@ _MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
] ]
) )
# Model-type string -> backbone class name. Used below to build
# MODEL_FOR_BACKBONE_MAPPING, which backs the `AutoBackbone` auto class.
MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
    [
        # Backbone mapping
        ("resnet", "ResNetBackbone"),
    ]
)
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES) MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES) MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
...@@ -903,6 +910,8 @@ MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping( ...@@ -903,6 +910,8 @@ MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
) )
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES) MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)
MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)
class AutoModel(_BaseAutoModelClass): class AutoModel(_BaseAutoModelClass):
_model_mapping = MODEL_MAPPING _model_mapping = MODEL_MAPPING
...@@ -1126,6 +1135,10 @@ class AutoModelForAudioXVector(_BaseAutoModelClass): ...@@ -1126,6 +1135,10 @@ class AutoModelForAudioXVector(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING
class AutoBackbone(_BaseAutoModelClass):
    """Generic auto class that instantiates a backbone model (e.g. `ResNetBackbone`) from a config or checkpoint."""

    _model_mapping = MODEL_FOR_BACKBONE_MAPPING
AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector") AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector")
......
...@@ -36,6 +36,7 @@ else: ...@@ -36,6 +36,7 @@ else:
"ResNetForImageClassification", "ResNetForImageClassification",
"ResNetModel", "ResNetModel",
"ResNetPreTrainedModel", "ResNetPreTrainedModel",
"ResNetBackbone",
] ]
try: try:
...@@ -63,6 +64,7 @@ if TYPE_CHECKING: ...@@ -63,6 +64,7 @@ if TYPE_CHECKING:
else: else:
from .modeling_resnet import ( from .modeling_resnet import (
RESNET_PRETRAINED_MODEL_ARCHIVE_LIST, RESNET_PRETRAINED_MODEL_ARCHIVE_LIST,
ResNetBackbone,
ResNetForImageClassification, ResNetForImageClassification,
ResNetModel, ResNetModel,
ResNetPreTrainedModel, ResNetPreTrainedModel,
......
...@@ -58,6 +58,9 @@ class ResNetConfig(PretrainedConfig): ...@@ -58,6 +58,9 @@ class ResNetConfig(PretrainedConfig):
are supported. are supported.
downsample_in_first_stage (`bool`, *optional*, defaults to `False`): downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
If `True`, the first stage will downsample the inputs using a `stride` of 2. If `True`, the first stage will downsample the inputs using a `stride` of 2.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`,
`"stage3"`, `"stage4"`.
Example: Example:
```python ```python
...@@ -85,6 +88,7 @@ class ResNetConfig(PretrainedConfig): ...@@ -85,6 +88,7 @@ class ResNetConfig(PretrainedConfig):
layer_type="bottleneck", layer_type="bottleneck",
hidden_act="relu", hidden_act="relu",
downsample_in_first_stage=False, downsample_in_first_stage=False,
out_features=None,
**kwargs **kwargs
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -97,6 +101,16 @@ class ResNetConfig(PretrainedConfig): ...@@ -97,6 +101,16 @@ class ResNetConfig(PretrainedConfig):
self.layer_type = layer_type self.layer_type = layer_type
self.hidden_act = hidden_act self.hidden_act = hidden_act
self.downsample_in_first_stage = downsample_in_first_stage self.downsample_in_first_stage = downsample_in_first_stage
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
if out_features is not None:
if not isinstance(out_features, list):
raise ValueError("out_features should be a list")
for feature in out_features:
if feature not in self.stage_names:
raise ValueError(
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
)
self.out_features = out_features
class ResNetOnnxConfig(OnnxConfig): class ResNetOnnxConfig(OnnxConfig):
......
...@@ -23,12 +23,19 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss ...@@ -23,12 +23,19 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_outputs import ( from ...modeling_outputs import (
BackboneOutput,
BaseModelOutputWithNoAttention, BaseModelOutputWithNoAttention,
BaseModelOutputWithPoolingAndNoAttention, BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention, ImageClassifierOutputWithNoAttention,
) )
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_resnet import ResNetConfig from .configuration_resnet import ResNetConfig
...@@ -416,3 +423,69 @@ class ResNetForImageClassification(ResNetPreTrainedModel): ...@@ -416,3 +423,69 @@ class ResNetForImageClassification(ResNetPreTrainedModel):
return (loss,) + output if loss is not None else output return (loss,) + output if loss is not None else output
return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
@add_start_docstrings(
    """
    ResNet backbone, to be used with frameworks like DETR and MaskFormer.
    """,
    RESNET_START_DOCSTRING,
)
class ResNetBackbone(ResNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.stage_names = config.stage_names
        self.resnet = ResNetModel(config)

        # `config.out_features` may legitimately be `None` (it is optional on the config);
        # default to the last stage so `channels` and `forward` never iterate over `None`.
        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]

        # Build the stage-name -> channel-count mapping from the config instead of
        # hard-coding four stages, so configs with a different number of `depths`
        # (and hence a different `stage_names` length) keep working.
        out_feature_channels = {"stem": config.embedding_size}
        for idx, stage in enumerate(self.stage_names[1:]):
            out_feature_channels[stage] = config.hidden_sizes[idx]
        self.out_feature_channels = out_feature_channels

        # initialize weights and apply final processing
        self.post_init()

    @property
    def channels(self):
        """Channel count of each requested feature map, in `out_features` order."""
        return [self.out_feature_channels[name] for name in self.out_features]

    @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(self, pixel_values: Optional[torch.FloatTensor] = None) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
        >>> model = AutoBackbone.from_pretrained("microsoft/resnet-50")

        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        ```"""
        # Hidden states are always requested: the feature maps are selected from them below.
        outputs = self.resnet(pixel_values, output_hidden_states=True, return_dict=True)

        hidden_states = outputs.hidden_states

        # hidden_states[idx] corresponds to stage_names[idx] ("stem" first, then stage1..stageN).
        feature_maps = tuple(
            hidden_states[idx] for idx, stage in enumerate(self.stage_names) if stage in self.out_features
        )

        return BackboneOutput(feature_maps=feature_maps)
...@@ -437,6 +437,13 @@ MODEL_MAPPING = None ...@@ -437,6 +437,13 @@ MODEL_MAPPING = None
MODEL_WITH_LM_HEAD_MAPPING = None MODEL_WITH_LM_HEAD_MAPPING = None
class AutoBackbone(metaclass=DummyObject):
    # Import-time placeholder: raises a helpful error on use when the "torch"
    # backend is not installed, instead of failing at import.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
class AutoModel(metaclass=DummyObject): class AutoModel(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
...@@ -4523,6 +4530,13 @@ def load_tf_weights_in_rembert(*args, **kwargs): ...@@ -4523,6 +4530,13 @@ def load_tf_weights_in_rembert(*args, **kwargs):
RESNET_PRETRAINED_MODEL_ARCHIVE_LIST = None RESNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
class ResNetBackbone(metaclass=DummyObject):
    # Import-time placeholder: raises a helpful error on use when the "torch"
    # backend is not installed, instead of failing at import.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
class ResNetForImageClassification(metaclass=DummyObject): class ResNetForImageClassification(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
...@@ -30,7 +30,7 @@ if is_torch_available(): ...@@ -30,7 +30,7 @@ if is_torch_available():
import torch import torch
from torch import nn from torch import nn
from transformers import ResNetForImageClassification, ResNetModel from transformers import ResNetBackbone, ResNetForImageClassification, ResNetModel
from transformers.models.resnet.modeling_resnet import RESNET_PRETRAINED_MODEL_ARCHIVE_LIST from transformers.models.resnet.modeling_resnet import RESNET_PRETRAINED_MODEL_ARCHIVE_LIST
...@@ -55,6 +55,7 @@ class ResNetModelTester: ...@@ -55,6 +55,7 @@ class ResNetModelTester:
hidden_act="relu", hidden_act="relu",
num_labels=3, num_labels=3,
scope=None, scope=None,
out_features=["stage1", "stage2", "stage3", "stage4"],
): ):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
...@@ -69,6 +70,7 @@ class ResNetModelTester: ...@@ -69,6 +70,7 @@ class ResNetModelTester:
self.num_labels = num_labels self.num_labels = num_labels
self.scope = scope self.scope = scope
self.num_stages = len(hidden_sizes) self.num_stages = len(hidden_sizes)
self.out_features = out_features
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
...@@ -89,6 +91,7 @@ class ResNetModelTester: ...@@ -89,6 +91,7 @@ class ResNetModelTester:
depths=self.depths, depths=self.depths,
hidden_act=self.hidden_act, hidden_act=self.hidden_act,
num_labels=self.num_labels, num_labels=self.num_labels,
out_features=self.out_features,
) )
def create_and_check_model(self, config, pixel_values, labels): def create_and_check_model(self, config, pixel_values, labels):
...@@ -110,6 +113,19 @@ class ResNetModelTester: ...@@ -110,6 +113,19 @@ class ResNetModelTester:
result = model(pixel_values, labels=labels) result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_backbone(self, config, pixel_values, labels):
    """Build a `ResNetBackbone` from `config` and check its feature maps and channel counts."""
    model = ResNetBackbone(config=config)
    model.to(torch_device)
    model.eval()
    result = model(pixel_values)

    # verify hidden states: one feature map per entry in `config.out_features`
    self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
    # NOTE(review): the literals [3, 10, 8, 8] presumably mirror the tester's
    # batch_size / hidden_sizes / image_size defaults — confirm they stay in sync.
    self.parent.assertListEqual(list(result.feature_maps[0].shape), [3, 10, 8, 8])

    # verify channels: the tester requests every stage, so channels == hidden_sizes
    self.parent.assertListEqual(model.channels, config.hidden_sizes)
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs config, pixel_values, labels = config_and_inputs
...@@ -176,6 +192,10 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -176,6 +192,10 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
def test_backbone(self):
    """Smoke-test `ResNetBackbone` through the shared model tester."""
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_and_check_backbone(*config_and_inputs)
def test_initialization(self): def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
...@@ -47,6 +47,7 @@ PRIVATE_MODELS = [ ...@@ -47,6 +47,7 @@ PRIVATE_MODELS = [
# Being in this list is an exception and should **not** be the rule. # Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
# models to ignore for not tested # models to ignore for not tested
"ResNetBackbone", # Backbones have their own tests.
"CLIPSegDecoder", # Building part of bigger (tested) model. "CLIPSegDecoder", # Building part of bigger (tested) model.
"TableTransformerEncoder", # Building part of bigger (tested) model. "TableTransformerEncoder", # Building part of bigger (tested) model.
"TableTransformerDecoder", # Building part of bigger (tested) model. "TableTransformerDecoder", # Building part of bigger (tested) model.
...@@ -668,6 +669,8 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [ ...@@ -668,6 +669,8 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
"PyTorchBenchmarkArguments", "PyTorchBenchmarkArguments",
"TensorFlowBenchmark", "TensorFlowBenchmark",
"TensorFlowBenchmarkArguments", "TensorFlowBenchmarkArguments",
"ResNetBackbone",
"AutoBackbone",
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment