Unverified commit 519a677e, authored by lumliolum and committed by GitHub

Added Beit model output class (#14133)



* add Beit model output class

* inheriting from BaseModelOutputWithPooling

* updated docs for the case where use_mean_pooling is False

* added beit specific outputs in model docs

* changed the import path

* Fix docs
Co-authored-by: Niels Rogge <niels.rogge1@gmail.com>
parent bbaa3eff
--- a/docs/source/model_doc/beit.rst
+++ b/docs/source/model_doc/beit.rst
@@ -63,6 +63,17 @@ This model was contributed by `nielsr <https://huggingface.co/nielsr>`__. The JA
 contributed by `kamalkraj <https://huggingface.co/kamalkraj>`__. The original code can be found `here
 <https://github.com/microsoft/unilm/tree/master/beit>`__.
 
+BEiT specific outputs
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.models.beit.modeling_beit.BeitModelOutputWithPooling
+    :members:
+
+
+.. autoclass:: transformers.models.beit.modeling_flax_beit.FlaxBeitModelOutputWithPooling
+    :members:
+
+
 BeitConfig
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
--- a/src/transformers/models/beit/modeling_beit.py
+++ b/src/transformers/models/beit/modeling_beit.py
@@ -17,6 +17,7 @@
 import collections.abc
 import math
+from dataclasses import dataclass
 
 import torch
 import torch.utils.checkpoint
@@ -42,6 +43,32 @@ BEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]
 
+@dataclass
+class BeitModelOutputWithPooling(BaseModelOutputWithPooling):
+    """
+    Class for outputs of :class:`~transformers.BeitModel`.
+
+    Args:
+        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
+            Average of the last layer hidden states of the patch tokens (excluding the `[CLS]` token) if
+            `config.use_mean_pooling` is set to True. If set to False, the final hidden state of the `[CLS]` token
+            is returned instead.
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each
+            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the
+            output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`. Attention weights after the attention softmax, used to compute the
+            weighted average in the self-attention heads.
+    """
 
 
 # Inspired by
 # https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
 # From PyTorch internals
@@ -585,7 +612,7 @@ class BeitModel(BeitPreTrainedModel):
         self.encoder.layer[layer].attention.prune_heads(heads)
 
     @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=BeitModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         pixel_values=None,
@@ -646,7 +673,7 @@ class BeitModel(BeitPreTrainedModel):
         if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]
 
-        return BaseModelOutputWithPooling(
+        return BeitModelOutputWithPooling(
             last_hidden_state=sequence_output,
             pooler_output=pooled_output,
             hidden_states=encoder_outputs.hidden_states,
......
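To make the documented pooling behaviour concrete, here is a minimal usage sketch (not part of this commit). It assumes a transformers version that includes this change; the tiny config values are arbitrary choices so a randomly initialized model runs instantly, not the settings of any released checkpoint:

```python
import torch
from transformers import BeitConfig, BeitModel

# Tiny random model: 32x32 images with 16x16 patches give
# 4 patch tokens + 1 [CLS] token = 5 positions.
config = BeitConfig(
    image_size=32,
    patch_size=16,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    use_mean_pooling=True,
)
model = BeitModel(config).eval()

pixel_values = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    outputs = model(pixel_values)

print(type(outputs).__name__)           # BeitModelOutputWithPooling
print(outputs.last_hidden_state.shape)  # torch.Size([1, 5, 32])
print(outputs.pooler_output.shape)      # torch.Size([1, 32])

# With use_mean_pooling=False the pooler returns the [CLS] hidden state,
# so pooler_output matches the first position of last_hidden_state.
config.use_mean_pooling = False
model = BeitModel(config).eval()
with torch.no_grad():
    outputs = model(pixel_values)
print(torch.equal(outputs.pooler_output, outputs.last_hidden_state[:, 0]))  # True
```

With `return_dict=False`, the same call returns a plain tuple `(sequence_output, pooled_output, ...)` instead, as the `forward` diff above shows.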
--- a/src/transformers/models/beit/modeling_flax_beit.py
+++ b/src/transformers/models/beit/modeling_flax_beit.py
@@ -18,6 +18,7 @@ from typing import Callable, List, Optional, Tuple
 
 import numpy as np
 
+import flax
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -40,6 +41,29 @@ from ...modeling_flax_utils import (
 from .configuration_beit import BeitConfig
 
+@flax.struct.dataclass
+class FlaxBeitModelOutputWithPooling(FlaxBaseModelOutputWithPooling):
+    """
+    Class for outputs of :class:`~transformers.FlaxBeitModel`.
+
+    Args:
+        last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (:obj:`jnp.ndarray` of shape :obj:`(batch_size, hidden_size)`):
+            Average of the last layer hidden states of the patch tokens (excluding the `[CLS]` token) if
+            `config.use_mean_pooling` is set to True. If set to False, the final hidden state of the `[CLS]` token
+            is returned instead.
+        hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
+            shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each
+            layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+            sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
+            the self-attention heads.
+    """
 
 
 BEIT_START_DOCSTRING = r"""
 
     This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
@@ -674,7 +698,7 @@ class FlaxBeitModule(nn.Module):
                 return (hidden_states,) + outputs[1:]
             return (hidden_states, pooled) + outputs[1:]
 
-        return FlaxBaseModelOutputWithPooling(
+        return FlaxBeitModelOutputWithPooling(
             last_hidden_state=hidden_states,
             pooler_output=pooled,
             hidden_states=outputs.hidden_states,
@@ -711,7 +735,7 @@ FLAX_BEIT_MODEL_DOCSTRING = """
 """
 
 overwrite_call_docstring(FlaxBeitModel, FLAX_BEIT_MODEL_DOCSTRING)
-append_replace_return_docstrings(FlaxBeitModel, output_type=FlaxBaseModelOutputWithPooling, config_class=BeitConfig)
+append_replace_return_docstrings(FlaxBeitModel, output_type=FlaxBeitModelOutputWithPooling, config_class=BeitConfig)
 
 
 class FlaxBeitForMaskedImageModelingModule(nn.Module):
......
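The Flax side mirrors this. Here is a minimal sketch under the same caveats, additionally assuming `FlaxBeitModel` accepts NCHW `pixel_values` like its PyTorch counterpart (the model transposes to NHWC internally):

```python
import numpy as np
from transformers import BeitConfig, FlaxBeitModel

# Same tiny, illustrative config as the PyTorch example above.
config = BeitConfig(
    image_size=32,
    patch_size=16,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    use_mean_pooling=True,
)
# FlaxBeitModel initializes random parameters from the config.
model = FlaxBeitModel(config)

pixel_values = np.random.randn(1, 3, 32, 32).astype(np.float32)
outputs = model(pixel_values)

print(type(outputs).__name__)           # FlaxBeitModelOutputWithPooling
print(outputs.last_hidden_state.shape)  # (1, 5, 32)
print(outputs.pooler_output.shape)      # (1, 32)
```

Declaring the output as a `flax.struct.dataclass` (hence the new `import flax`) registers it as an immutable JAX pytree, so instances can flow through `jax.jit` and other transformations unchanged.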