Unverified Commit 9d99489f authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

Add TFData2VecVision for semantic segmentation (#17271)



* feat: initial implementation of data2vec segmentation model in TF.

* chore: minor corrections to make the segmenter work.

* chore: removed unncessary files.

* chore: add tests and other modifications.

* fix: loss computation for segmentation.

* chore: remove unused variable.

* chore: formatting.

* added a dummy adaptive pooling layer.

* removed unnecessary file.

* potentially add identifiers to layer names.

* fix: layer naming.

* chore: removed unnecessary print.

* Skipping unneeded test

* chore: add logging to debug tolerance.

* fix: segmentation tests for tfdata2vecvision

* chore: make style.

* fix: layer names, assertion to be resolved.

* Bumping test tolerance a bit

* chore: bump the tol in PT test.
Co-authored-by: default avatarmatt <rocketknight1@gmail.com>
parent 78c695eb
......@@ -42,7 +42,7 @@ Tips:
This model was contributed by [edugp](https://huggingface.co/edugp) and [patrickvonplaten](https://huggingface.co/patrickvonplaten).
[sayakpaul](https://github.com/sayakpaul) contributed Data2Vec for vision in TensorFlow.
[sayakpaul](https://github.com/sayakpaul) and [Rocketknight1](https://github.com/Rocketknight1) contributed Data2Vec for vision in TensorFlow.
The original code (for NLP and Speech) can be found [here](https://github.com/pytorch/fairseq/tree/main/examples/data2vec).
The original code for vision can be found [here](https://github.com/facebookresearch/data2vec_vision/tree/main/beit).
......@@ -145,3 +145,8 @@ The original code for vision can be found [here](https://github.com/facebookrese
[[autodoc]] TFData2VecVisionForImageClassification
- call
## TFData2VecVisionForSemanticSegmentation
[[autodoc]] TFData2VecVisionForSemanticSegmentation
- call
......@@ -2036,6 +2036,7 @@ else:
_import_structure["models.data2vec"].extend(
[
"TFData2VecVisionForImageClassification",
"TFData2VecVisionForSemanticSegmentation",
"TFData2VecVisionModel",
"TFData2VecVisionPreTrainedModel",
]
......@@ -4342,6 +4343,7 @@ if TYPE_CHECKING:
)
from .models.data2vec import (
TFData2VecVisionForImageClassification,
TFData2VecVisionForSemanticSegmentation,
TFData2VecVisionModel,
TFData2VecVisionPreTrainedModel,
)
......
......@@ -607,6 +607,43 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFSemanticSegmenterOutput(ModelOutput):
"""
Base class for outputs of semantic segmentation models.
Args:
loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
Classification scores for each pixel.
<Tip warning={true}>
The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the
original image size as post-processing. You should always check your logits shape and resize as needed.
</Tip>
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
the output of each layer) of shape `(batch_size, patch_size, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFMultipleChoiceModelOutput(ModelOutput):
"""
......
......@@ -163,6 +163,10 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
new_key = key.replace("gamma", "weight")
if "beta" in key:
new_key = key.replace("beta", "bias")
if "running_var" in key:
new_key = key.replace("running_var", "moving_variance")
if "running_mean" in key:
new_key = key.replace("running_mean", "moving_mean")
if new_key:
old_keys.append(key)
new_keys.append(new_key)
......
......@@ -178,6 +178,13 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
]
)
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
[
# Model for Semantic Segmentation mapping
("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"),
]
)
TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
......@@ -365,6 +372,9 @@ TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
)
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
......@@ -440,6 +450,15 @@ TFAutoModelForImageClassification = auto_class_update(
)
class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
TF_AutoModelForSemanticSegmentation = auto_class_update(
TFAutoModelForSemanticSegmentation, head_doc="semantic segmentation"
)
class TFAutoModelForVision2Seq(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING
......
......@@ -73,6 +73,7 @@ else:
if is_tf_available():
_import_structure["modeling_tf_data2vec_vision"] = [
"TFData2VecVisionForImageClassification",
"TFData2VecVisionForSemanticSegmentation",
"TFData2VecVisionModel",
"TFData2VecVisionPreTrainedModel",
]
......@@ -127,6 +128,7 @@ if TYPE_CHECKING:
if is_tf_available():
from .modeling_tf_data2vec_vision import (
TFData2VecVisionForImageClassification,
TFData2VecVisionForSemanticSegmentation,
TFData2VecVisionModel,
TFData2VecVisionPreTrainedModel,
)
......
......@@ -17,7 +17,7 @@
import collections.abc
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
......@@ -25,7 +25,12 @@ import tensorflow as tf
from transformers.tf_utils import shape_list, stable_softmax
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPooling,
TFSemanticSegmenterOutput,
TFSequenceClassifierOutput,
)
from ...modeling_tf_utils import (
TFModelInputType,
TFPreTrainedModel,
......@@ -34,7 +39,13 @@ from ...modeling_tf_utils import (
keras_serializable,
unpack_inputs,
)
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_data2vec_vision import Data2VecVisionConfig
......@@ -978,3 +989,464 @@ class TFData2VecVisionForImageClassification(TFData2VecVisionPreTrainedModel, TF
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class TFData2VecVisionConvModule(tf.keras.layers.Layer):
"""
A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
"""
def __init__(
self,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
padding: str = "valid",
bias: bool = False,
dilation: Union[int, Tuple[int, int]] = 1,
**kwargs
) -> None:
super().__init__(**kwargs)
self.conv = tf.keras.layers.Conv2D(
filters=out_channels,
kernel_size=kernel_size,
padding=padding,
use_bias=bias,
dilation_rate=dilation,
name="conv",
)
self.bn = tf.keras.layers.BatchNormalization(name="bn")
self.activation = tf.nn.relu
def call(self, input: tf.Tensor) -> tf.Tensor:
output = self.conv(input)
output = self.bn(output)
output = self.activation(output)
return output
# Copied from:
# https://gist.github.com/Rocketknight1/43abbe6e73f1008e6e459486e01e0ceb
class TFAdaptiveAvgPool1D(tf.keras.layers.Layer):
def __init__(self, output_dim, mode="dense", **kwargs):
super().__init__(**kwargs)
self.output_dim = output_dim
self.mode = mode
self.map = None
def build(self, input_shape):
super().build(input_shape)
"""We pre-compute the sparse matrix for the build() step once. The below code comes
from https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work/63603993#63603993."""
def get_kernels(ind, outd) -> List:
"""Returns a List [(kernel_offset_start,kernel_length)] defining all the pooling kernels for a 1-D adaptive
pooling layer that takes an input of dimension `ind` and yields an output of dimension `outd`"""
def start_index(a, b, c):
return math.floor((float(a) * float(c)) / b)
def end_index(a, b, c):
return math.ceil((float(a + 1) * float(c)) / b)
results = []
for ow in range(outd):
start = start_index(ow, outd, ind)
end = end_index(ow, outd, ind)
sz = end - start
results.append((start, sz))
return results
in_dim = int(input_shape[-1])
kernels = get_kernels(in_dim, self.output_dim)
sparse_map = np.zeros((in_dim, self.output_dim), dtype=np.float32)
for i, kernel in enumerate(kernels):
sparse_map[kernel[0] : kernel[0] + kernel[1], i] = 1 / kernel[1]
if self.mode == "dense":
self.map = tf.constant(sparse_map)
else:
self.map = tf.sparse.from_dense(sparse_map)
def call(self, inputs):
if self.mode == "dense":
return inputs @ self.map
else:
input_dims = inputs.shape
input_matrix = tf.reshape(inputs, (-1, input_dims[-1]))
out = tf.sparse.sparse_dense_matmul(input_matrix, self.map)
return tf.reshape(out, input_dims[:-1].as_list() + [-1])
def get_config(self):
config = super().get_config()
config.update({"output_dim": self.output_dim, "mode": self.mode})
return config
class TFAdaptiveAvgPool2D(tf.keras.layers.Layer):
def __init__(self, output_shape, mode="dense", **kwargs):
super().__init__(**kwargs)
self.mode = mode
self.h_pool = TFAdaptiveAvgPool1D(output_shape[0], mode=mode, name="h_pool")
self.w_pool = TFAdaptiveAvgPool1D(output_shape[1], mode=mode, name="w_pool")
def call(self, inputs):
# Rearrange from NHWC -> NCHW
inputs = tf.transpose(inputs, perm=[0, 3, 1, 2])
# Perform W-pooling
inputs = self.w_pool(inputs)
# Rearrange NCHW -> NCWH
inputs = tf.transpose(inputs, perm=[0, 1, 3, 2])
# Perform H-pooling
inputs = self.h_pool(inputs)
# Rearrange from NCWH -> NHWC
inputs = tf.transpose(inputs, perm=[0, 3, 2, 1])
return inputs
def get_config(self):
config = super().get_config()
config.update({"mode": self.mode})
return config
class TFData2VecVisionPyramidPoolingModule(tf.keras.layers.Layer):
"""
Pyramid Pooling Module (PPM) used in PSPNet.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module.
channels (int): Channels after modules, before conv_seg.
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
"""
def __init__(self, pool_scales: Tuple[int, ...], channels: int, **kwargs) -> None:
super().__init__(**kwargs)
self.pool_scales = pool_scales
self.channels = channels
self.layer_list = []
for idx, pool_scale in enumerate(pool_scales):
pool_scale = pool_scale if isinstance(pool_scale, collections.abc.Iterable) else (pool_scale, pool_scale)
self.layer_list.append(
[
TFAdaptiveAvgPool2D(output_shape=pool_scale),
TFData2VecVisionConvModule(out_channels=self.channels, kernel_size=1, name=f"{idx}.1"),
]
)
def call(self, x: tf.Tensor) -> List[tf.Tensor]:
ppm_outs = []
inputs = x
for ppm in self.layer_list:
for layer_module in ppm:
ppm_out = layer_module(x)
x = ppm_out
upsampled_ppm_out = tf.image.resize(ppm_out, size=shape_list(inputs)[1:-1], method="bilinear")
ppm_outs.append(upsampled_ppm_out)
return ppm_outs
class TFData2VecVisionUperHead(tf.keras.layers.Layer):
"""
Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
[UPerNet](https://arxiv.org/abs/1807.10221).
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
"""
def __init__(self, config: Data2VecVisionConfig, **kwargs) -> None:
super().__init__(**kwargs)
self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6)
self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768]
self.channels = config.hidden_size
self.classifier = tf.keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier")
# PSP Module
self.psp_modules = TFData2VecVisionPyramidPoolingModule(self.pool_scales, self.channels, name="psp_modules")
self.bottleneck = TFData2VecVisionConvModule(self.channels, kernel_size=3, padding="same", name="bottleneck")
# FPN Module
self.lateral_convs = []
self.fpn_convs = []
for idx, _ in enumerate(self.in_channels[:-1]): # skip the top layer
l_conv = TFData2VecVisionConvModule(out_channels=self.channels, kernel_size=1, name=f"lateral_convs.{idx}")
fpn_conv = TFData2VecVisionConvModule(
out_channels=self.channels, kernel_size=3, padding="same", name=f"fpn_convs.{idx}"
)
self.lateral_convs.append(l_conv)
self.fpn_convs.append(fpn_conv)
self.fpn_bottleneck = TFData2VecVisionConvModule(
out_channels=self.channels, kernel_size=3, padding="same", name="fpn_bottleneck"
)
def psp_forward(self, inputs):
x = inputs[-1]
psp_outs = [x]
psp_outs.extend(self.psp_modules(x))
psp_outs = tf.concat(psp_outs, axis=-1)
output = self.bottleneck(psp_outs)
return output
def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor:
# build laterals
laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
laterals.append(self.psp_forward(encoder_hidden_states))
# build top-down path
used_backbone_levels = len(laterals)
for i in range(used_backbone_levels - 1, 0, -1):
prev_shape = shape_list(laterals[i - 1])[1:-1]
laterals[i - 1] = laterals[i - 1] + tf.image.resize(laterals[i], size=prev_shape, method="bilinear")
# build outputs
fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
# append psp feature
fpn_outs.append(laterals[-1])
for i in range(used_backbone_levels - 1, 0, -1):
fpn_outs[i] = tf.image.resize(fpn_outs[i], size=shape_list(fpn_outs[0])[1:-1], method="bilinear")
fpn_outs = tf.concat(fpn_outs, axis=-1)
output = self.fpn_bottleneck(fpn_outs)
output = self.classifier(output)
return output
class TFData2VecVisionFCNHead(tf.keras.layers.Layer):
"""
Fully Convolution Networks for Semantic Segmentation. This head is implemented from
[FCNNet](https://arxiv.org/abs/1411.4038).
Args:
config (Data2VecVisionConfig): Configuration.
kernel_size (int): The kernel size for convs in the head. Default: 3.
dilation (int): The dilation rate for convs in the head. Default: 1.
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
"""
def __init__(
self,
config: Data2VecVisionConfig,
in_index: int = 2,
kernel_size: int = 3,
dilation: Union[int, Tuple[int, int]] = 1,
**kwargs
) -> None:
super().__init__(**kwargs)
self.in_channels = config.hidden_size
self.channels = config.auxiliary_channels
self.num_convs = config.auxiliary_num_convs
self.concat_input = config.auxiliary_concat_input
self.in_index = in_index
convs = []
convs.append(
TFData2VecVisionConvModule(
out_channels=self.channels,
kernel_size=kernel_size,
padding="same",
dilation=dilation,
name="convs.0",
)
)
for i in range(self.num_convs - 1):
convs.append(
TFData2VecVisionConvModule(
out_channels=self.channels,
kernel_size=kernel_size,
padding="same",
dilation=dilation,
name=f"conv_module_{i+2}",
)
)
if self.num_convs == 0:
self.convs = [tf.identity]
else:
self.convs = convs
if self.concat_input:
self.conv_cat = TFData2VecVisionConvModule(
out_channels=self.channels, kernel_size=kernel_size, padding="same", name="conv_cat"
)
self.classifier = tf.keras.layers.Conv2D(config.num_labels, kernel_size=1, name="classifier")
def call(self, encoder_hidden_states: tf.Tensor) -> tf.Tensor:
# just take the relevant feature maps
hidden_states = encoder_hidden_states[self.in_index]
output = hidden_states
for layer_module in self.convs:
output = layer_module(output)
if self.concat_input:
output = self.conv_cat(tf.concat([hidden_states, output], axis=-1))
output = self.classifier(output)
return output
@add_start_docstrings(
"""
Data2VecVision Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
""",
DATA2VEC_VISION_START_DOCSTRING,
)
class TFData2VecVisionForSemanticSegmentation(TFData2VecVisionPreTrainedModel):
def __init__(self, config: Data2VecVisionConfig, *inputs, **kwargs) -> None:
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.data2vec_vision = TFData2VecVisionMainLayer(config, add_pooling_layer=False, name="data2vec_vision")
# FPNs
self.fpn1 = [
tf.keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.0"),
tf.keras.layers.BatchNormalization(name="fpn1.1"),
tf.keras.layers.Activation("gelu"),
tf.keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn1.3"),
]
self.fpn2 = [tf.keras.layers.Conv2DTranspose(config.hidden_size, kernel_size=2, strides=2, name="fpn2.0")]
self.fpn3 = tf.identity
self.fpn4 = tf.keras.layers.MaxPool2D(pool_size=2, strides=2)
# Semantic segmentation head(s)
self.decode_head = TFData2VecVisionUperHead(config, name="decode_head")
self.auxiliary_head = (
TFData2VecVisionFCNHead(config, name="auxiliary_head") if config.use_auxiliary_head else None
)
def compute_loss(self, logits, auxiliary_logits, labels):
# upsample logits to the images' original size
if len(shape_list(labels)) > 3:
label_interp_shape = shape_list(labels)[1:-1]
else:
label_interp_shape = shape_list(labels)[-2:]
upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
if auxiliary_logits is not None:
upsampled_auxiliary_logits = tf.image.resize(auxiliary_logits, size=label_interp_shape, method="bilinear")
# compute weighted loss
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
# Copied from https://www.tensorflow.org/text/tutorials/transformer#loss_and_metrics.
# Utility to mask the index to ignore during computing the loss.
def masked_loss(real, pred):
mask = tf.math.logical_not(tf.math.equal(real, self.config.semantic_loss_ignore_index))
loss_ = loss_fct(real, pred)
mask = tf.cast(mask, dtype=loss_.dtype)
loss_ *= mask
return tf.reduce_sum(loss_) / tf.reduce_sum(mask)
main_loss = masked_loss(labels, upsampled_logits)
auxiliary_loss = masked_loss(labels, upsampled_auxiliary_logits)
loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss
return loss
@unpack_inputs
@add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
pixel_values: Optional[tf.Tensor] = None,
head_mask: Optional[tf.Tensor] = None,
labels: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, TFSemanticSegmenterOutput]:
r"""
labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*):
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
Returns:
Examples:
```python
>>> from transformers import AutoFeatureExtractor, TFData2VecVisionForSemanticSegmentation
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/data2vec-vision-base")
>>> model = TFData2VecVisionForSemanticSegmentation.from_pretrained("facebook/data2vec-vision-base")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> # logits are of shape (batch_size, num_labels, height, width)
>>> logits = outputs.logits
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
outputs = self.data2vec_vision(
pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=True, # we need the intermediate hidden states
return_dict=return_dict,
)
encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
# only keep certain features, and reshape
# note that we do +1 as the encoder_hidden_states also includes the initial embeddings
features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
batch_size = shape_list(pixel_values)[0]
patch_resolution = self.config.image_size // self.config.patch_size
def reshape_features(x):
x = tf.reshape(x, (batch_size, patch_resolution, patch_resolution, -1))
return x
features = [reshape_features(x[:, 1:, :]) for x in features]
# apply FPNs
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
for module in ops[0]:
features[0] = module(features[0])
features[1] = ops[1][0](features[1])
for i in range(len(features[2:])):
features[i + 2] = ops[i + 2](features[i + 2])
logits = self.decode_head(features)
# Tranpose the logits to maintain consistency in the output formats.
transposed_logits = tf.transpose(logits, perm=[0, 3, 1, 2])
auxiliary_logits = None
if self.auxiliary_head is not None:
auxiliary_logits = self.auxiliary_head(features)
loss = None
if labels is not None:
if self.config.num_labels == 1:
raise ValueError("The number of labels should be greater than one")
else:
loss = self.compute_loss(logits, auxiliary_logits, labels)
if not return_dict:
if output_hidden_states:
output = (logits,) + outputs[1:]
else:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TFSemanticSegmenterOutput(
loss=loss,
logits=transposed_logits,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)
......@@ -759,6 +759,13 @@ class TFData2VecVisionForImageClassification(metaclass=DummyObject):
requires_backends(self, ["tf"])
class TFData2VecVisionForSemanticSegmentation(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFData2VecVisionModel(metaclass=DummyObject):
_backends = ["tf"]
......
......@@ -389,6 +389,10 @@ class Data2VecVisionModelTest(ModelTesterMixin, unittest.TestCase):
check_hidden_states_output(inputs_dict, config, model_class)
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None):
# We override with a slightly higher tol value, as semseg models tend to diverge a bit more
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
......
......@@ -31,7 +31,11 @@ from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_te
if is_tf_available():
import tensorflow as tf
from transformers import TFData2VecVisionForImageClassification, TFData2VecVisionModel
from transformers import (
TFData2VecVisionForImageClassification,
TFData2VecVisionForSemanticSegmentation,
TFData2VecVisionModel,
)
from transformers.models.data2vec.modeling_tf_data2vec_vision import (
TF_DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST,
)
......@@ -142,6 +146,18 @@ class TFData2VecVisionModelTester:
result = model(pixel_values, labels=labels, training=False)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def create_and_check_for_image_segmentation(self, config, pixel_values, labels, pixel_labels):
config.num_labels = self.num_labels
model = TFData2VecVisionForSemanticSegmentation(config)
result = model(pixel_values, training=False)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2)
)
result = model(pixel_values, labels=pixel_labels)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels, pixel_labels = config_and_inputs
......@@ -162,7 +178,11 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, unittest.TestCase):
attention_mask and seq_length.
"""
all_model_classes = (TFData2VecVisionModel, TFData2VecVisionForImageClassification) if is_tf_available() else ()
all_model_classes = (
(TFData2VecVisionModel, TFData2VecVisionForImageClassification, TFData2VecVisionForSemanticSegmentation)
if is_tf_available()
else ()
)
test_pruning = False
test_onnx = False
......@@ -208,6 +228,14 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_image_segmentation(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_segmentation(*config_and_inputs)
@unittest.skip("Test was written for TF 1.x and isn't really relevant here")
def test_compile_tf_model(self):
pass
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
......@@ -354,6 +382,10 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, unittest.TestCase):
val_loss2 = history2.history["val_loss"][0]
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None):
# We override with a slightly higher tol value, as semseg models tend to diverge a bit more
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
# Overriding this method since the base method won't be compatible with Data2VecVision.
def test_loss_computation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment