Unverified Commit 441658dd authored by Alara Dirik, committed by GitHub

Add focalnet backbone (#23104)

Adds FocalNet backbone to return features from all stages
parent ca7eb27e
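
The diff below wires `FocalNetBackbone` through the public API, the `AutoBackbone` mapping, `FocalNetConfig`, the checkpoint conversion script, and the tests. As a quick orientation, here is a minimal usage sketch of the new class, assuming an untrained toy config (checkpoint-based usage appears in the docstring example further down):

```python
import torch

from transformers import FocalNetBackbone, FocalNetConfig

# Toy, untrained config: out_indices is resolved against stage_names
# (["stem", "stage1", ..., "stage4"]) by the new validation logic below.
config = FocalNetConfig(out_indices=[2, 4])
model = FocalNetBackbone(config)

pixel_values = torch.randn(1, 3, 224, 224)
outputs = model(pixel_values)

# One feature map per requested stage; channel counts come from model.channels.
for feature_map, channels in zip(outputs.feature_maps, model.channels):
    assert feature_map.shape[1] == channels
```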
@@ -1623,6 +1623,7 @@ else:
    _import_structure["models.focalnet"].extend(
        [
            "FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST",
            "FocalNetBackbone",
            "FocalNetForImageClassification",
            "FocalNetForMaskedImageModeling",
            "FocalNetModel",
@@ -5178,6 +5179,7 @@ if TYPE_CHECKING:
        )
        from .models.focalnet import (
            FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST,
            FocalNetBackbone,
            FocalNetForImageClassification,
            FocalNetForMaskedImageModeling,
            FocalNetModel,
...
@@ -980,6 +980,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
        ("convnext", "ConvNextBackbone"),
        ("convnextv2", "ConvNextV2Backbone"),
        ("dinat", "DinatBackbone"),
        ("focalnet", "FocalNetBackbone"),
        ("maskformer-swin", "MaskFormerSwinBackbone"),
        ("nat", "NatBackbone"),
        ("resnet", "ResNetBackbone"),
...
@@ -30,6 +30,7 @@ else:
        "FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST",
        "FocalNetForImageClassification",
        "FocalNetForMaskedImageModeling",
        "FocalNetBackbone",
        "FocalNetModel",
        "FocalNetPreTrainedModel",
    ]
@@ -45,6 +46,7 @@ if TYPE_CHECKING:
    else:
        from .modeling_focalnet import (
            FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST,
            FocalNetBackbone,
            FocalNetForImageClassification,
            FocalNetForMaskedImageModeling,
            FocalNetModel,
...
@@ -47,6 +47,8 @@ class FocalNetConfig(PretrainedConfig):
        use_conv_embed (`bool`, *optional*, defaults to `False`):
            Whether to use convolutional embedding. The authors noted that using convolutional embedding usually
            improves the performance, but it's not used by default.
        hidden_sizes (`List[int]`, *optional*, defaults to `[192, 384, 768, 768]`):
            Dimensionality (hidden size) at each stage.
        depths (`list(int)`, *optional*, defaults to `[2, 2, 6, 2]`):
            Depth (number of layers) of each stage in the encoder.
        focal_levels (`list(int)`, *optional*, defaults to `[2, 2, 2, 2]`):
@@ -78,6 +80,14 @@ class FocalNetConfig(PretrainedConfig):
            The epsilon used by the layer normalization layers.
        encoder_stride (`int`, *optional*, defaults to 32):
            Factor to increase the spatial resolution by in the decoder head for masked image modeling.
        out_features (`List[str]`, *optional*):
            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
            corresponding stages. If unset and `out_indices` is unset, will default to the last stage.
        out_indices (`List[int]`, *optional*):
            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
            If unset and `out_features` is unset, will default to the last stage.
    Example:
@@ -102,6 +112,7 @@ class FocalNetConfig(PretrainedConfig):
        num_channels=3,
        embed_dim=96,
        use_conv_embed=False,
        hidden_sizes=[192, 384, 768, 768],
        depths=[2, 2, 6, 2],
        focal_levels=[2, 2, 2, 2],
        focal_windows=[3, 3, 3, 3],
@@ -117,6 +128,8 @@ class FocalNetConfig(PretrainedConfig):
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        encoder_stride=32,
        out_features=None,
        out_indices=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
@@ -126,6 +139,7 @@ class FocalNetConfig(PretrainedConfig):
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        self.use_conv_embed = use_conv_embed
        self.hidden_sizes = hidden_sizes
        self.depths = depths
        self.focal_levels = focal_levels
        self.focal_windows = focal_windows
@@ -141,3 +155,36 @@ class FocalNetConfig(PretrainedConfig):
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.encoder_stride = encoder_stride
        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]

        if out_features is not None and out_indices is not None:
            if len(out_features) != len(out_indices):
                raise ValueError("out_features and out_indices should have the same length if both are set")
            elif out_features != [self.stage_names[idx] for idx in out_indices]:
                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")

        if out_features is None and out_indices is not None:
            out_features = [self.stage_names[idx] for idx in out_indices]
        elif out_features is not None and out_indices is None:
            out_indices = [self.stage_names.index(feature) for feature in out_features]
        elif out_features is None and out_indices is None:
            out_features = [self.stage_names[-1]]
            out_indices = [len(self.stage_names) - 1]

        if out_features is not None:
            if not isinstance(out_features, list):
                raise ValueError("out_features should be a list")
            for feature in out_features:
                if feature not in self.stage_names:
                    raise ValueError(
                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
                    )
        if out_indices is not None:
            if not isinstance(out_indices, (list, tuple)):
                raise ValueError("out_indices should be a list or tuple")
            for idx in out_indices:
                if idx >= len(self.stage_names):
                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")

        self.out_features = out_features
        self.out_indices = out_indices
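
To make the resolution rules above concrete, a minimal sketch; the printed values follow directly from the validation code just added, assuming the default four-stage `depths`:

```python
from transformers import FocalNetConfig

# Only out_indices given: out_features is derived from stage_names.
config = FocalNetConfig(out_indices=[1, 3])
print(config.stage_names)   # ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
print(config.out_features)  # ['stage1', 'stage3']

# Neither given: both default to the last stage.
config = FocalNetConfig()
print(config.out_features, config.out_indices)  # ['stage4'] [4]
```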
@@ -56,7 +56,6 @@ def get_focalnet_config(model_name):
        embed_dim = 128
    elif "large" in model_name:
        embed_dim = 192
        focal_windows = [5, 5, 5, 5]
    elif "xlarge" in model_name:
        embed_dim = 256
    elif "huge" in model_name:
@@ -130,7 +129,10 @@ def convert_focalnet_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub):
        "focalnet-small-lrf": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_small_lrf.pth",
        "focalnet-base": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_srf.pth",
        "focalnet-base-lrf": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_lrf.pth",
        "focalnet-large-lrf-fl3": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384.pth",
        "focalnet-large-lrf-fl4": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384_fl4.pth",
        "focalnet-xlarge-lrf-fl3": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384.pth",
        "focalnet-xlarge-lrf-fl4": "https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384_fl4.pth",
    }
    # fmt: on
...
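
For completeness, a hedged sketch of driving the conversion for one of the newly mapped checkpoints. The module path is an assumption about the repository layout; the function name and parameters are taken from the hunk header above:

```python
# Module path assumed (hypothetical); the signature comes from the diff:
# convert_focalnet_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub)
from transformers.models.focalnet.convert_focalnet_to_hf_format import convert_focalnet_checkpoint

convert_focalnet_checkpoint(
    model_name="focalnet-xlarge-lrf-fl4",  # one of the new fl3/fl4 keys above
    pytorch_dump_folder_path="focalnet-xlarge-lrf-fl4",
    push_to_hub=False,
)
```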
@@ -26,7 +26,8 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
@@ -209,7 +210,6 @@ class FocalNetEmbeddings(nn.Module):
        embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
        embeddings = self.dropout(embeddings)
        return embeddings, output_dimensions
@@ -971,3 +971,81 @@ class FocalNetForImageClassification(FocalNetPreTrainedModel):
            hidden_states=outputs.hidden_states,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )


@add_start_docstrings(
    """
    FocalNet backbone, to be used with frameworks like X-Decoder.
    """,
    FOCALNET_START_DOCSTRING,
)
class FocalNetBackbone(FocalNetPreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)

        self.stage_names = config.stage_names
        self.focalnet = FocalNetModel(config)

        self.num_features = [config.embed_dim] + config.hidden_sizes
        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]

        if config.out_indices is not None:
            self.out_indices = config.out_indices
        else:
            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)

        # initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FOCALNET_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-tiny-lrf")
        >>> model = AutoBackbone.from_pretrained("microsoft/focalnet-tiny-lrf")

        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.focalnet(pixel_values, output_hidden_states=True, return_dict=True)

        hidden_states = outputs.reshaped_hidden_states

        feature_maps = ()
        for idx, stage in enumerate(self.stage_names):
            if stage in self.out_features:
                feature_maps += (hidden_states[idx],)

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output += (outputs.hidden_states,)
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=None,
        )
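
Continuing the docstring example above, the returned `BackboneOutput` can be inspected like this (a sketch; the spatial sizes depend on the input resolution and the checkpoint's config, so the shape in the comment is illustrative only):

```python
# With the default config, out_features resolves to the last stage, so a single
# feature map is returned, e.g. roughly (1, 768, 7, 7) for a 224x224 input to
# focalnet-tiny-lrf (illustrative, not guaranteed).
for name, feature_map in zip(model.out_features, outputs.feature_maps):
    print(name, tuple(feature_map.shape))
```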
@@ -3002,6 +3002,13 @@ class FNetPreTrainedModel(metaclass=DummyObject):
FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST = None


class FocalNetBackbone(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class FocalNetForImageClassification(metaclass=DummyObject):
    _backends = ["torch"]
...
@@ -22,6 +22,7 @@ from transformers import FocalNetConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available

from ...test_backbone_common import BackboneTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
@@ -30,7 +31,12 @@ if is_torch_available():
    import torch
    from torch import nn
    from transformers import (
        FocalNetBackbone,
        FocalNetForImageClassification,
        FocalNetForMaskedImageModeling,
        FocalNetModel,
    )
    from transformers.models.focalnet.modeling_focalnet import FOCALNET_PRETRAINED_MODEL_ARCHIVE_LIST

if is_vision_available():
@@ -48,6 +54,7 @@ class FocalNetModelTester:
        patch_size=2,
        num_channels=3,
        embed_dim=16,
        hidden_sizes=[32, 64, 128],
        depths=[1, 2, 1],
        num_heads=[2, 2, 4],
        window_size=2,
@@ -67,6 +74,7 @@ class FocalNetModelTester:
        type_sequence_label_size=10,
        encoder_stride=8,
        out_features=["stage1", "stage2"],
        out_indices=[1, 2],
    ):
        self.parent = parent
        self.batch_size = batch_size
@@ -74,6 +82,7 @@ class FocalNetModelTester:
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        self.hidden_sizes = hidden_sizes
        self.depths = depths
        self.num_heads = num_heads
        self.window_size = window_size
@@ -93,6 +102,7 @@ class FocalNetModelTester:
        self.type_sequence_label_size = type_sequence_label_size
        self.encoder_stride = encoder_stride
        self.out_features = out_features
        self.out_indices = out_indices

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -111,6 +121,7 @@ class FocalNetModelTester:
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            embed_dim=self.embed_dim,
            hidden_sizes=self.hidden_sizes,
            depths=self.depths,
            num_heads=self.num_heads,
            window_size=self.window_size,
@@ -126,6 +137,7 @@ class FocalNetModelTester:
            initializer_range=self.initializer_range,
            encoder_stride=self.encoder_stride,
            out_features=self.out_features,
            out_indices=self.out_indices,
        )

    def create_and_check_model(self, config, pixel_values, labels):
@@ -139,6 +151,35 @@ class FocalNetModelTester:
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))

    def create_and_check_backbone(self, config, pixel_values, labels):
        model = FocalNetBackbone(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)

        # verify feature maps
        self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
        self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.image_size, 8, 8])

        # verify channels
        self.parent.assertEqual(len(model.channels), len(config.out_features))
        self.parent.assertListEqual(model.channels, config.hidden_sizes[:-1])

        # verify backbone works with out_features=None
        config.out_features = None
        model = FocalNetBackbone(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)

        # verify feature maps
        self.parent.assertEqual(len(result.feature_maps), 1)
        self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.image_size * 2, 4, 4])

        # verify channels
        self.parent.assertEqual(len(model.channels), 1)
        self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])

    def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
        model = FocalNetForMaskedImageModeling(config=config)
        model.to(torch_device)
@@ -191,6 +232,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
            FocalNetModel,
            FocalNetForImageClassification,
            FocalNetForMaskedImageModeling,
            FocalNetBackbone,
        )
        if is_torch_available()
        else ()
@@ -204,7 +246,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
    def setUp(self):
        self.model_tester = FocalNetModelTester(self)
        self.config_tester = ConfigTester(self, config_class=FocalNetConfig, embed_dim=37, has_text_modality=False)

    def test_config(self):
        self.create_and_test_config_common_properties()
@@ -222,6 +264,10 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_backbone(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_backbone(*config_and_inputs)

    def test_for_masked_image_modeling(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
@@ -234,14 +280,14 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="FocalNet does not use feedforward chunking")
    def test_feed_forward_chunking(self):
        pass

    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes[:-1]:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
            x = model.get_output_embeddings()
@@ -250,7 +296,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes[:-1]:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
@@ -309,7 +355,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
            else (self.model_tester.image_size, self.model_tester.image_size)
        )

        for model_class in self.all_model_classes[:-1]:
            inputs_dict["output_hidden_states"] = True
            self.check_hidden_states_output(inputs_dict, config, model_class, image_size)
@@ -337,7 +383,7 @@ class FocalNetModelTest(ModelTesterMixin, unittest.TestCase):
        padded_height = image_size[0] + patch_size[0] - (image_size[0] % patch_size[0])
        padded_width = image_size[1] + patch_size[1] - (image_size[1] % patch_size[1])

        for model_class in self.all_model_classes[:-1]:
            inputs_dict["output_hidden_states"] = True
            self.check_hidden_states_output(inputs_dict, config, model_class, (padded_height, padded_width))
@@ -393,3 +439,14 @@ class FocalNetModelIntegrationTest(unittest.TestCase):
        expected_slice = torch.tensor([0.2166, -0.4368, 0.2191]).to(torch_device)
        self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
        self.assertEqual(outputs.logits.argmax(dim=-1).item(), 281)


@require_torch
class FocalNetBackboneTest(BackboneTesterMixin, unittest.TestCase):
    all_model_classes = (FocalNetBackbone,) if is_torch_available() else ()
    config_class = FocalNetConfig

    has_attentions = False

    def setUp(self):
        self.model_tester = FocalNetModelTester(self)
@@ -135,6 +135,8 @@ class BackboneTesterMixin:
        # Verify num_features has been initialized in the backbone init
        self.assertIsNotNone(backbone.num_features)
        self.assertTrue(len(backbone.channels) == len(backbone.out_indices))
        print(backbone.stage_names)
        print(backbone.num_features)
        self.assertTrue(len(backbone.stage_names) == len(backbone.num_features))
        self.assertTrue(len(backbone.channels) <= len(backbone.num_features))
        self.assertTrue(len(backbone.out_feature_channels) == len(backbone.stage_names))
...
@@ -836,6 +836,7 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
    "ConvNextBackbone",
    "ConvNextV2Backbone",
    "DinatBackbone",
    "FocalNetBackbone",
    "MaskFormerSwinBackbone",
    "MaskFormerSwinConfig",
    "MaskFormerSwinModel",
...