Unverified Commit 519a677e authored by lumliolum, committed by GitHub

Added Beit model output class (#14133)



* add Beit model output class

* inheriting from BaseModelOutputWithPooling

* updated docs for the case when use_mean_pooling is False

* added beit specific outputs in model docs

* changed the import path

* Fix docs
Co-authored-by: Niels Rogge <niels.rogge1@gmail.com>
parent bbaa3eff
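
In short: `BeitModel` and `FlaxBeitModel` now return BEiT-specific output classes instead of the generic pooling outputs. Because the new classes inherit from the generic ones (see the diff below), existing downstream code is unaffected; a minimal sketch of that contract, using only import paths that appear in this diff:

```python
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.beit.modeling_beit import BeitModelOutputWithPooling

# The BEiT-specific class subclasses the generic one, so isinstance checks
# and attribute access (last_hidden_state, pooler_output, ...) keep working.
assert issubclass(BeitModelOutputWithPooling, BaseModelOutputWithPooling)
```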
docs/source/model_doc/beit.rst
@@ -63,6 +63,17 @@
 This model was contributed by `nielsr <https://huggingface.co/nielsr>`__. The JAX/FLAX version of this model was
 contributed by `kamalkraj <https://huggingface.co/kamalkraj>`__. The original code can be found `here
 <https://github.com/microsoft/unilm/tree/master/beit>`__.
+
+BEiT specific outputs
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.models.beit.modeling_beit.BeitModelOutputWithPooling
+    :members:
+
+.. autoclass:: transformers.models.beit.modeling_flax_beit.FlaxBeitModelOutputWithPooling
+    :members:
+
+
 BeitConfig
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
src/transformers/models/beit/modeling_beit.py
@@ -17,6 +17,7 @@
 import collections.abc
 import math
+from dataclasses import dataclass
 
 import torch
 import torch.utils.checkpoint
@@ -42,6 +43,32 @@ BEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]
 
 
+@dataclass
+class BeitModelOutputWithPooling(BaseModelOutputWithPooling):
+    """
+    Class for outputs of :class:`~transformers.BeitModel`.
+
+    Args:
+        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
+            Average of the last layer hidden states of the patch tokens (excluding the `[CLS]` token) if
+            `config.use_mean_pooling` is set to True. If set to False, the final hidden state of the `[CLS]` token
+            will be returned.
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each
+            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+
 # Inspired by
 # https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
 # From PyTorch internals
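
For reference, a minimal usage sketch showing where the new output class surfaces to users. The checkpoint name and the COCO test image are illustrative (they are the ones commonly used in the BEiT docs); any BEiT checkpoint works:

```python
import requests
from PIL import Image

from transformers import BeitFeatureExtractor, BeitModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224")
model = BeitModel.from_pretrained("microsoft/beit-base-patch16-224")

inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)  # now a BeitModelOutputWithPooling

# mean of the patch-token hidden states if config.use_mean_pooling is True,
# else the final hidden state of the [CLS] token; shape (batch_size, hidden_size)
pooled = outputs.pooler_output
```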
@@ -585,7 +612,7 @@ class BeitModel(BeitPreTrainedModel):
             self.encoder.layer[layer].attention.prune_heads(heads)
 
     @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=BeitModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         pixel_values=None,
@@ -646,7 +673,7 @@
         if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]
 
-        return BaseModelOutputWithPooling(
+        return BeitModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
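
The `pooler_output` documented above is computed roughly as follows; a simplified sketch with a hypothetical helper name (the actual `BeitPooler` also applies a LayerNorm to the mean-pooled patch tokens):

```python
import torch


def beit_pool(last_hidden_state: torch.Tensor, use_mean_pooling: bool) -> torch.Tensor:
    # last_hidden_state: (batch_size, sequence_length, hidden_size);
    # position 0 is the [CLS] token, positions 1: are the patch tokens
    if use_mean_pooling:
        # average over the patch tokens, excluding [CLS]
        return last_hidden_state[:, 1:, :].mean(dim=1)
    # otherwise: the final hidden state of the [CLS] token
    return last_hidden_state[:, 0]
```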
src/transformers/models/beit/modeling_flax_beit.py
@@ -18,6 +18,7 @@ from typing import Callable, List, Optional, Tuple
 
 import numpy as np
 
+import flax
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -40,6 +41,29 @@ from ...modeling_flax_utils import (
 from .configuration_beit import BeitConfig
 
 
+@flax.struct.dataclass
+class FlaxBeitModelOutputWithPooling(FlaxBaseModelOutputWithPooling):
+    """
+    Class for outputs of :class:`~transformers.FlaxBeitModel`.
+
+    Args:
+        last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (:obj:`jnp.ndarray` of shape :obj:`(batch_size, hidden_size)`):
+            Average of the last layer hidden states of the patch tokens (excluding the `[CLS]` token) if
+            `config.use_mean_pooling` is set to True. If set to False, the final hidden state of the `[CLS]` token
+            will be returned.
+        hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
+            shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each
+            layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+            sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
+            the self-attention heads.
+    """
+
+
 BEIT_START_DOCSTRING = r"""
     This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
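
The Flax output mirrors the PyTorch one; a minimal usage sketch under the same assumptions as above (illustrative checkpoint, `return_tensors="np"` to feed NumPy arrays to the Flax model):

```python
import requests
from PIL import Image

from transformers import BeitFeatureExtractor, FlaxBeitModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224")
model = FlaxBeitModel.from_pretrained("microsoft/beit-base-patch16-224")

inputs = feature_extractor(images=image, return_tensors="np")
outputs = model(**inputs)  # now a FlaxBeitModelOutputWithPooling
pooled = outputs.pooler_output  # jnp.ndarray of shape (batch_size, hidden_size)
```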
@@ -674,7 +698,7 @@ class FlaxBeitModule(nn.Module):
                 return (hidden_states,) + outputs[1:]
             return (hidden_states, pooled) + outputs[1:]
 
-        return FlaxBaseModelOutputWithPooling(
+        return FlaxBeitModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
@@ -711,7 +735,7 @@ FLAX_BEIT_MODEL_DOCSTRING = """
 """
 
 overwrite_call_docstring(FlaxBeitModel, FLAX_BEIT_MODEL_DOCSTRING)
-append_replace_return_docstrings(FlaxBeitModel, output_type=FlaxBaseModelOutputWithPooling, config_class=BeitConfig)
+append_replace_return_docstrings(FlaxBeitModel, output_type=FlaxBeitModelOutputWithPooling, config_class=BeitConfig)
 
 
 class FlaxBeitForMaskedImageModelingModule(nn.Module):