"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "fe3c8ab1af558b95f67f5fafc0c55f09fd2b09db"
Unverified commit da9754a3, authored by Patrick von Platen, committed by GitHub

[Flax] Align jax flax device name (#12987)

* [Flax] Align device name in docs

* make style

* fix import error
parent 07df5578
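
The change is mechanical: every type annotation and docstring reference to the private jaxlib.xla_extension.DeviceArray class is replaced by the public jax.numpy.ndarray alias, which in turn lets the `import jaxlib.xla_extension as jax_xla` lines be dropped. The diff touches the hybrid CLIP example model, the Flax logits processors, the Flax generation mixin, and the Flax model output dataclasses. A minimal sketch of the aliasing the rename relies on (behaviour as of JAX releases contemporary with this commit; the concrete array class is an implementation detail):

import jax.numpy as jnp

# Arrays produced by jax.numpy operations satisfy the public jnp.ndarray
# type, so annotations can use it instead of importing the private
# jaxlib.xla_extension module to name DeviceArray directly.
x = jnp.ones((2, 3))
print(isinstance(x, jnp.ndarray))  # True
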
@@ -224,7 +224,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
                 `What are input IDs? <../glossary.html#input-ids>`__
         Returns:
-            text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The text embeddings
+            text_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The text embeddings
             obtained by applying the projection layer to the pooled output of text model.
         """
         if position_ids is None:
@@ -273,7 +273,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
             :meth:`transformers.ImageFeatureExtractionMixin.__call__` for details.
         Returns:
-            image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The image embeddings
+            image_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The image embeddings
             obtained by applying the projection layer to the pooled output of vision model.
         """
...
@@ -19,7 +19,6 @@ from abc import ABC
 import jax
 import jax.lax as lax
 import jax.numpy as jnp
-import jaxlib.xla_extension as jax_xla
 from .file_utils import add_start_docstrings
 from .utils.logging import get_logger
@@ -30,7 +29,7 @@ logger = get_logger(__name__)
 LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
     Args:
-        input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
+        input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary.
             Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
@@ -38,14 +37,14 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
             details.
             `What are input IDs? <../glossary.html#input-ids>`__
-        scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`):
+        scores (:obj:`jnp.ndarray` of shape :obj:`(batch_size, config.vocab_size)`):
             Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
             search or log softmax for each vocabulary token when using beam search
         kwargs:
             Additional logits processor specific kwargs.
     Return:
-        :obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
+        :obj:`jnp.ndarray` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
 """
@@ -54,7 +53,7 @@ class FlaxLogitsProcessor(ABC):
     """Abstract base class for all logit processors that can be applied during generation."""
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
+    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
         """Flax method for processing logits."""
         raise NotImplementedError(
             f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
@@ -65,7 +64,7 @@ class FlaxLogitsWarper(ABC):
     """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
+    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
         """Flax method for warping logits."""
         raise NotImplementedError(
             f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
@@ -81,9 +80,7 @@ class FlaxLogitsProcessorList(list):
     """
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(
-        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int, **kwargs
-    ) -> jax_xla.DeviceArray:
+    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int, **kwargs) -> jnp.ndarray:
         for processor in self:
             function_args = inspect.signature(processor.__call__).parameters
             if len(function_args) > 3:
@@ -111,9 +108,7 @@ class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
         self.temperature = temperature
-    def __call__(
-        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
-    ) -> jax_xla.DeviceArray:
+    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
         scores = scores / self.temperature
         return scores
@@ -141,9 +136,7 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep
-    def __call__(
-        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
-    ) -> jax_xla.DeviceArray:
+    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
         topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1])
         mask_scores = jnp.full_like(scores, self.filter_value)
@@ -183,9 +176,7 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper):
         self.filter_value = filter_value
         self.min_tokens_to_keep = min_tokens_to_keep
-    def __call__(
-        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
-    ) -> jax_xla.DeviceArray:
+    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
         batch_size, vocab_size = scores.shape
         next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)
@@ -212,9 +203,7 @@ class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor):
     def __init__(self, bos_token_id: int):
         self.bos_token_id = bos_token_id
-    def __call__(
-        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
-    ) -> jax_xla.DeviceArray:
+    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
         new_scores = jnp.full(scores.shape, -float("inf"))
         apply_penalty = 1 - jnp.bool_(cur_len - 1)
@@ -242,9 +231,7 @@ class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor):
         self.max_length = max_length
         self.eos_token_id = eos_token_id
-    def __call__(
-        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
-    ) -> jax_xla.DeviceArray:
+    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
         new_scores = jnp.full(scores.shape, -float("inf"))
         apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)
@@ -277,9 +264,7 @@ class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
         self.min_length = min_length
         self.eos_token_id = eos_token_id
-    def __call__(
-        self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
-    ) -> jax_xla.DeviceArray:
+    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
         # create boolean flag to decide if min length penalty should be applied
         apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)
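
The shortened signatures also make custom processors compact to write. A standalone sketch mirroring FlaxTemperatureLogitsWarper from the diff above, kept independent of the transformers base classes so it runs on its own (the name TemperatureWarper and the toy inputs are illustrative, not part of the library):

import jax.numpy as jnp

class TemperatureWarper:
    # Same three-argument signature as the processors in this file:
    # (input_ids, scores, cur_len) -> processed scores.
    def __init__(self, temperature: float):
        self.temperature = temperature

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        return scores / self.temperature

scores = jnp.array([[1.0, 2.0, 3.0]])
input_ids = jnp.zeros((1, 1), dtype=jnp.int32)
print(TemperatureWarper(2.0)(input_ids, scores, cur_len=1))  # [[0.5 1.  1.5]]
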
...
@@ -23,7 +23,6 @@ import numpy as np
 import flax
 import jax
 import jax.numpy as jnp
-import jaxlib.xla_extension as jax_xla
 from jax import lax
 from .file_utils import ModelOutput
@@ -49,11 +48,11 @@ class FlaxGreedySearchOutput(ModelOutput):
     Args:
-        sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
+        sequences (:obj:`jnp.ndarray` of shape :obj:`(batch_size, max_length)`):
             The generated sequences.
     """
-    sequences: jax_xla.DeviceArray = None
+    sequences: jnp.ndarray = None
 @flax.struct.dataclass
@@ -63,11 +62,11 @@ class FlaxSampleOutput(ModelOutput):
     Args:
-        sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
+        sequences (:obj:`jnp.ndarray` of shape :obj:`(batch_size, max_length)`):
             The generated sequences.
     """
-    sequences: jax_xla.DeviceArray = None
+    sequences: jnp.ndarray = None
 @flax.struct.dataclass
@@ -77,44 +76,44 @@ class FlaxBeamSearchOutput(ModelOutput):
     Args:
-        sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
+        sequences (:obj:`jnp.ndarray` of shape :obj:`(batch_size, max_length)`):
             The generated sequences.
-        scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size,)`):
+        scores (:obj:`jnp.ndarray` of shape :obj:`(batch_size,)`):
             The scores (log probabilites) of the generated sequences.
     """
-    sequences: jax_xla.DeviceArray = None
-    scores: jax_xla.DeviceArray = None
+    sequences: jnp.ndarray = None
+    scores: jnp.ndarray = None
 @flax.struct.dataclass
 class GreedyState:
-    cur_len: jax_xla.DeviceArray
-    sequences: jax_xla.DeviceArray
-    running_token: jax_xla.DeviceArray
-    is_sent_finished: jax_xla.DeviceArray
-    model_kwargs: Dict[str, jax_xla.DeviceArray]
+    cur_len: jnp.ndarray
+    sequences: jnp.ndarray
+    running_token: jnp.ndarray
+    is_sent_finished: jnp.ndarray
+    model_kwargs: Dict[str, jnp.ndarray]
 @flax.struct.dataclass
 class SampleState:
-    cur_len: jax_xla.DeviceArray
-    sequences: jax_xla.DeviceArray
-    running_token: jax_xla.DeviceArray
-    is_sent_finished: jax_xla.DeviceArray
-    prng_key: jax_xla.DeviceArray
-    model_kwargs: Dict[str, jax_xla.DeviceArray]
+    cur_len: jnp.ndarray
+    sequences: jnp.ndarray
+    running_token: jnp.ndarray
+    is_sent_finished: jnp.ndarray
+    prng_key: jnp.ndarray
+    model_kwargs: Dict[str, jnp.ndarray]
 @flax.struct.dataclass
 class BeamSearchState:
-    cur_len: jax_xla.DeviceArray
-    running_sequences: jax_xla.DeviceArray
-    running_scores: jax_xla.DeviceArray
-    sequences: jax_xla.DeviceArray
-    scores: jax_xla.DeviceArray
-    is_sent_finished: jax_xla.DeviceArray
-    model_kwargs: Dict[str, jax_xla.DeviceArray]
+    cur_len: jnp.ndarray
+    running_sequences: jnp.ndarray
+    running_scores: jnp.ndarray
+    sequences: jnp.ndarray
+    scores: jnp.ndarray
+    is_sent_finished: jnp.ndarray
+    model_kwargs: Dict[str, jnp.ndarray]
 class FlaxGenerationMixin:
@@ -156,14 +155,14 @@ class FlaxGenerationMixin:
     def generate(
         self,
-        input_ids: jax_xla.DeviceArray,
+        input_ids: jnp.ndarray,
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
         bos_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         decoder_start_token_id: Optional[int] = None,
         do_sample: Optional[bool] = None,
-        prng_key: Optional[jax_xla.DeviceArray] = None,
+        prng_key: Optional[jnp.ndarray] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         temperature: Optional[float] = None,
@@ -175,7 +174,7 @@ class FlaxGenerationMixin:
         length_penalty: Optional[float] = None,
         early_stopping: Optional[bool] = None,
         trace: bool = True,
-        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
+        params: Optional[Dict[str, jnp.ndarray]] = None,
         **model_kwargs,
     ):
         r"""
@@ -191,7 +190,7 @@ class FlaxGenerationMixin:
         Parameters:
-            input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                 The sequence used as a prompt for the generation.
             max_length (:obj:`int`, `optional`, defaults to 20):
                 The maximum length of the sequence to be generated.
@@ -217,7 +216,7 @@ class FlaxGenerationMixin:
             trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
                 Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to
                 a considerably slower runtime.
-            params (:obj:`Dict[str, jax_xla.DeviceArray]`, `optional`):
+            params (:obj:`Dict[str, jnp.ndarray]`, `optional`):
                 Optionally the model parameters can be passed. Can be useful for parallelized generation.
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
@@ -395,8 +394,8 @@ class FlaxGenerationMixin:
         eos_token_id: Optional[int] = None,
         logits_processor: Optional[FlaxLogitsProcessorList] = None,
         trace: bool = True,
-        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
-        model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
+        params: Optional[Dict[str, jnp.ndarray]] = None,
+        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
     ):
         # init values
         max_length = max_length if max_length is not None else self.config.max_length
@@ -479,12 +478,12 @@ class FlaxGenerationMixin:
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
-        prng_key: Optional[jax_xla.DeviceArray] = None,
+        prng_key: Optional[jnp.ndarray] = None,
         logits_processor: Optional[FlaxLogitsProcessorList] = None,
         logits_warper: Optional[FlaxLogitsProcessorList] = None,
         trace: bool = True,
-        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
-        model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
+        params: Optional[Dict[str, jnp.ndarray]] = None,
+        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
     ):
         # init values
         max_length = max_length if max_length is not None else self.config.max_length
@@ -580,8 +579,8 @@ class FlaxGenerationMixin:
         early_stopping: Optional[bool] = None,
         logits_processor: Optional[FlaxLogitsProcessorList] = None,
         trace: bool = True,
-        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
-        model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
+        params: Optional[Dict[str, jnp.ndarray]] = None,
+        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
     ):
         """
         This beam search function is heavily inspired by Flax's official example:
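
The prng_key parameter shows why jnp.ndarray is the right annotation here too: a JAX PRNG key is itself a small uint32 array rather than a special object. A quick check, assuming the raw-array key representation that jax.random.PRNGKey returned in JAX versions of this era:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)          # the kind of value passed as prng_key
print(key.shape, key.dtype)          # (2,) uint32
print(isinstance(key, jnp.ndarray))  # True
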
...
@@ -14,7 +14,7 @@
 from typing import Dict, Optional, Tuple
 import flax
-import jaxlib.xla_extension as jax_xla
+import jax.numpy as jnp
 from .file_utils import ModelOutput
@@ -25,24 +25,24 @@ class FlaxBaseModelOutput(ModelOutput):
     Base class for model's outputs, with potential hidden states and attentions.
     Args:
-        last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+        last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
             Sequence of hidden-states at the output of the last layer of the model.
-        hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
-            Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
-            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+        hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
+            shape :obj:`(batch_size, sequence_length, hidden_size)`.
             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
-        attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
-            Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
-            sequence_length, sequence_length)`.
+        attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+            sequence_length)`.
             Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
             heads.
     """
-    last_hidden_state: jax_xla.DeviceArray = None
-    hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
-    attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
+    last_hidden_state: jnp.ndarray = None
+    hidden_states: Optional[Tuple[jnp.ndarray]] = None
+    attentions: Optional[Tuple[jnp.ndarray]] = None
 @flax.struct.dataclass
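
The remaining hunks repeat this pattern across every output class in the file. A self-contained sketch of the pattern itself; TinyModelOutput is a simplified stand-in, not the transformers implementation (the real classes derive from ModelOutput rather than being bare dataclasses):

from typing import Optional, Tuple

import flax
import jax.numpy as jnp

@flax.struct.dataclass
class TinyModelOutput:
    # Fields annotated with the public jnp.ndarray type, as in the diff above.
    last_hidden_state: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None

out = TinyModelOutput(last_hidden_state=jnp.zeros((1, 3, 4)))
print(out.last_hidden_state.shape)  # (1, 3, 4)
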
...@@ -51,28 +51,28 @@ class FlaxBaseModelOutputWithPast(ModelOutput): ...@@ -51,28 +51,28 @@ class FlaxBaseModelOutputWithPast(ModelOutput):
Base class for model's outputs, with potential hidden states and attentions. Base class for model's outputs, with potential hidden states and attentions.
Args: Args:
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`): last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model. Sequence of hidden-states at the output of the last layer of the model.
past_key_values (:obj:`Dict[str, jax_xla.DeviceArray]`): past_key_values (:obj:`Dict[str, jnp.ndarray]`):
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`. auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`.
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads. heads.
""" """
last_hidden_state: jax_xla.DeviceArray = None last_hidden_state: jnp.ndarray = None
past_key_values: Optional[Dict[str, jax_xla.DeviceArray]] = None past_key_values: Optional[Dict[str, jnp.ndarray]] = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass @flax.struct.dataclass
...@@ -81,29 +81,29 @@ class FlaxBaseModelOutputWithPooling(ModelOutput): ...@@ -81,29 +81,29 @@ class FlaxBaseModelOutputWithPooling(ModelOutput):
Base class for model's outputs that also contains a pooling of the last hidden states. Base class for model's outputs that also contains a pooling of the last hidden states.
Args: Args:
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`): last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model. Sequence of hidden-states at the output of the last layer of the model.
pooler_output (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, hidden_size)`): pooler_output (:obj:`jnp.ndarray` of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token) further processed by a Last layer hidden-state of the first token of the sequence (classification token) further processed by a
Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
prediction (classification) objective during pretraining. prediction (classification) objective during pretraining.
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads. heads.
""" """
last_hidden_state: jax_xla.DeviceArray = None last_hidden_state: jnp.ndarray = None
pooler_output: jax_xla.DeviceArray = None pooler_output: jnp.ndarray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass @flax.struct.dataclass
...@@ -112,44 +112,44 @@ class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput): ...@@ -112,44 +112,44 @@ class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
Args: Args:
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`): last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model. Sequence of hidden-states at the output of the last layer of the model.
If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
1, hidden_size)` is output. 1, hidden_size)` is output.
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2 Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
``config.is_encoder_decoder=True`` 2 additional tensors of shape :obj:`(batch_size, num_heads, ``config.is_encoder_decoder=True`` 2 additional tensors of shape :obj:`(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`. encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
``config.is_encoder_decoder=True`` in the cross-attention blocks) that can be used (see ``config.is_encoder_decoder=True`` in the cross-attention blocks) that can be used (see
:obj:`past_key_values` input) to speed up sequential decoding. :obj:`past_key_values` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads. heads.
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``): cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads. weighted average in the cross-attention heads.
""" """
last_hidden_state: jax_xla.DeviceArray = None last_hidden_state: jnp.ndarray = None
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None cross_attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass @flax.struct.dataclass
...@@ -159,58 +159,58 @@ class FlaxSeq2SeqModelOutput(ModelOutput): ...@@ -159,58 +159,58 @@ class FlaxSeq2SeqModelOutput(ModelOutput):
decoding. decoding.
Args: Args:
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`): last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the decoder of the model. Sequence of hidden-states at the output of the last layer of the decoder of the model.
If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
1, hidden_size)` is output. 1, hidden_size)` is output.
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2 Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding. blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): decoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): decoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads. self-attention heads.
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads. weighted average in the cross-attention heads.
encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): encoder_last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model. Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): encoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): encoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads. self-attention heads.
""" """
last_hidden_state: jax_xla.DeviceArray = None last_hidden_state: jnp.ndarray = None
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None cross_attentions: Optional[Tuple[jnp.ndarray]] = None
encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None encoder_last_hidden_state: Optional[jnp.ndarray] = None
encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass @flax.struct.dataclass
...@@ -219,39 +219,39 @@ class FlaxCausalLMOutputWithCrossAttentions(ModelOutput): ...@@ -219,39 +219,39 @@ class FlaxCausalLMOutputWithCrossAttentions(ModelOutput):
Base class for causal language model (or autoregressive) outputs. Base class for causal language model (or autoregressive) outputs.
Args: Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads. heads.
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Cross attentions weights after the attention softmax, used to compute the weighted average in the Cross attentions weights after the attention softmax, used to compute the weighted average in the
cross-attention heads. cross-attention heads.
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
Tuple of :obj:`jax_xla.DeviceArray` tuples of length :obj:`config.n_layers`, with each tuple containing the Tuple of :obj:`jnp.ndarray` tuples of length :obj:`config.n_layers`, with each tuple containing the cached
cached key, value states of the self-attention and the cross-attention layers if model is used in key, value states of the self-attention and the cross-attention layers if model is used in encoder-decoder
encoder-decoder setting. Only relevant if ``config.is_decoder = True``. setting. Only relevant if ``config.is_decoder = True``.
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
:obj:`past_key_values` input) to speed up sequential decoding. :obj:`past_key_values` input) to speed up sequential decoding.
""" """
logits: jax_xla.DeviceArray = None logits: jnp.ndarray = None
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None cross_attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass @flax.struct.dataclass
...@@ -260,24 +260,24 @@ class FlaxMaskedLMOutput(ModelOutput): ...@@ -260,24 +260,24 @@ class FlaxMaskedLMOutput(ModelOutput):
Base class for masked language models outputs. Base class for masked language models outputs.
Args: Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads. heads.
""" """
logits: jax_xla.DeviceArray = None logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None
FlaxCausalLMOutput = FlaxMaskedLMOutput FlaxCausalLMOutput = FlaxMaskedLMOutput
...@@ -289,55 +289,55 @@ class FlaxSeq2SeqLMOutput(ModelOutput): ...@@ -289,55 +289,55 @@ class FlaxSeq2SeqLMOutput(ModelOutput):
Base class for sequence-to-sequence language models outputs. Base class for sequence-to-sequence language models outputs.
Args: Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2 Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding. blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): decoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): decoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads. self-attention heads.
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads. weighted average in the cross-attention heads.
encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): encoder_last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model. Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): encoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): encoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads. self-attention heads.
""" """
logits: jax_xla.DeviceArray = None logits: jnp.ndarray = None
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None cross_attentions: Optional[Tuple[jnp.ndarray]] = None
encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None encoder_last_hidden_state: Optional[jnp.ndarray] = None
encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
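The `past_key_values` layout documented above can be made concrete with plain arrays. A minimal sketch with made-up sizes (nothing below comes from a real checkpoint)::

    import jax.numpy as jnp

    # Hypothetical sizes, for illustration only.
    n_layers, batch_size, num_heads, head_dim = 2, 1, 4, 8
    dec_len, enc_len = 5, 7

    # Per the docstring: each layer carries 2 self-attention tensors plus
    # 2 cross-attention tensors keyed on the encoder sequence length.
    past_key_values = tuple(
        (
            jnp.zeros((batch_size, num_heads, dec_len, head_dim)),  # self-attn key
            jnp.zeros((batch_size, num_heads, dec_len, head_dim)),  # self-attn value
            jnp.zeros((batch_size, num_heads, enc_len, head_dim)),  # cross-attn key
            jnp.zeros((batch_size, num_heads, enc_len, head_dim)),  # cross-attn value
        )
        for _ in range(n_layers)
    )
    assert len(past_key_values) == n_layers and len(past_key_values[0]) == 4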
@flax.struct.dataclass @flax.struct.dataclass
@@ -346,25 +346,25 @@ class FlaxNextSentencePredictorOutput(ModelOutput):
Base class for outputs of models predicting if two sentences are consecutive or not. Base class for outputs of models predicting if two sentences are consecutive or not.
Args: Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`): logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
before SoftMax). before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads. heads.
""" """
logits: jax_xla.DeviceArray = None logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None
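A short sketch of how the `logits` of such an output might be consumed, with dummy values in place of a real model's head::

    import jax
    import jax.numpy as jnp

    logits = jnp.array([[2.0, -1.0], [-0.5, 1.5]])  # (batch_size, 2)
    probs = jax.nn.softmax(logits, axis=-1)          # True/False probabilities
    # In BERT's convention, index 0 scores "sentence B follows sentence A".
    is_continuation = jnp.argmax(logits, axis=-1) == 0  # -> [True, False]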
@flax.struct.dataclass @flax.struct.dataclass
@@ -373,24 +373,24 @@ class FlaxSequenceClassifierOutput(ModelOutput):
Base class for outputs of sentence classification models. Base class for outputs of sentence classification models.
Args: Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.num_labels)`): logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax). Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads. heads.
""" """
logits: jax_xla.DeviceArray = None logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None
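The classification-versus-regression distinction in the `logits` description above can be made concrete; `postprocess` below is a hypothetical helper, not a library function::

    import jax.numpy as jnp

    def postprocess(logits, num_labels):
        # Mirrors the docstring: regression when num_labels == 1,
        # classification otherwise.
        if num_labels == 1:
            return logits.squeeze(-1)       # raw regression values
        return jnp.argmax(logits, axis=-1)  # predicted class ids

    postprocess(jnp.array([[0.3]]), num_labels=1)             # -> [0.3]
    postprocess(jnp.array([[0.1, 2.0, -1.0]]), num_labels=3)  # -> [1]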
@flax.struct.dataclass @flax.struct.dataclass
@@ -399,55 +399,55 @@ class FlaxSeq2SeqSequenceClassifierOutput(ModelOutput):
Base class for outputs of sequence-to-sequence sentence classification models. Base class for outputs of sequence-to-sequence sentence classification models.
Args: Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.num_labels)`): logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax). Classification (or regression if config.num_labels==1) scores (before SoftMax).
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2 Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding. blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): decoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): decoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads. self-attention heads.
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads. weighted average in the cross-attention heads.
encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): encoder_last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model. Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): encoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): encoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads. self-attention heads.
""" """
logits: jax_xla.DeviceArray = None logits: jnp.ndarray = None
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None cross_attentions: Optional[Tuple[jnp.ndarray]] = None
encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None encoder_last_hidden_state: Optional[jnp.ndarray] = None
encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass @flax.struct.dataclass
@@ -456,26 +456,26 @@ class FlaxMultipleChoiceModelOutput(ModelOutput):
Base class for outputs of multiple choice models. Base class for outputs of multiple choice models.
Args: Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, num_choices)`): logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors (see `input_ids` above). `num_choices` is the second dimension of the input tensors (see `input_ids` above).
Classification scores (before SoftMax). Classification scores (before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads. heads.
""" """
logits: jax_xla.DeviceArray = None logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None
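To make the `(batch_size, num_choices)` shape above concrete, a small sketch with dummy scores::

    import jax.numpy as jnp

    # Scores for a batch of 2 questions with 4 answer choices each.
    logits = jnp.array([[0.2, 1.7, -0.3, 0.0],
                        [1.1, 0.4, 0.9, 2.2]])
    best_choice = jnp.argmax(logits, axis=-1)  # -> [1, 3]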
@flax.struct.dataclass @flax.struct.dataclass
@@ -484,24 +484,24 @@ class FlaxTokenClassifierOutput(ModelOutput):
Base class for outputs of token classification models. Base class for outputs of token classification models.
Args: Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.num_labels)`): logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax). Classification scores (before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads. heads.
""" """
logits: jax_xla.DeviceArray = None logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None
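A minimal sketch of reducing the per-token `logits` described above to label ids, again with dummy values::

    import jax.numpy as jnp

    # (batch_size=1, sequence_length=4, num_labels=3)
    logits = jnp.array([[[2.0, 0.1, 0.0],
                         [0.0, 1.5, 0.2],
                         [0.1, 0.0, 3.0],
                         [1.0, 0.9, 0.8]]])
    tag_ids = jnp.argmax(logits, axis=-1)  # one label id per token -> [[0, 1, 2, 0]]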
@flax.struct.dataclass @flax.struct.dataclass
@@ -510,27 +510,27 @@ class FlaxQuestionAnsweringModelOutput(ModelOutput):
Base class for outputs of question answering models. Base class for outputs of question answering models.
Args: Args:
start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): start_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Span-start scores (before SoftMax). Span-start scores (before SoftMax).
end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): end_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Span-end scores (before SoftMax). Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads. heads.
""" """
start_logits: jax_xla.DeviceArray = None start_logits: jnp.ndarray = None
end_logits: jax_xla.DeviceArray = None end_logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None
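A sketch of how `start_logits` and `end_logits` are typically reduced to an answer span; real pipelines also mask invalid spans (end before start, positions inside the question), which is omitted here::

    import jax.numpy as jnp

    # Dummy span logits for one example of sequence_length 6.
    start_logits = jnp.array([[0.1, 2.5, 0.3, 0.0, -1.0, 0.2]])
    end_logits = jnp.array([[0.0, 0.1, 0.4, 3.0, 0.5, -0.2]])

    start = jnp.argmax(start_logits, axis=-1)  # answer starts at token 1
    end = jnp.argmax(end_logits, axis=-1)      # answer ends at token 3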
@flax.struct.dataclass @flax.struct.dataclass
@@ -539,55 +539,55 @@ class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
Base class for outputs of sequence-to-sequence question answering models. Base class for outputs of sequence-to-sequence question answering models.
Args: Args:
start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): start_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Span-start scores (before SoftMax). Span-start scores (before SoftMax).
end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): end_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Span-end scores (before SoftMax). Span-end scores (before SoftMax).
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2 Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding. blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): decoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): decoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads. self-attention heads.
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads. weighted average in the cross-attention heads.
encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): encoder_last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model. Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): encoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): encoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads. self-attention heads.
""" """
start_logits: jax_xla.DeviceArray = None start_logits: jnp.ndarray = None
end_logits: jax_xla.DeviceArray = None end_logits: jnp.ndarray = None
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None cross_attentions: Optional[Tuple[jnp.ndarray]] = None
encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None encoder_last_hidden_state: Optional[jnp.ndarray] = None
encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
@@ -21,7 +21,6 @@ import flax
import flax.linen as nn import flax.linen as nn
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from flax.linen.attention import dot_product_attention_weights from flax.linen.attention import dot_product_attention_weights
from jax import lax from jax import lax
@@ -61,28 +60,28 @@ class FlaxBertForPreTrainingOutput(ModelOutput):
Output type of :class:`~transformers.BertForPreTraining`. Output type of :class:`~transformers.BertForPreTraining`.
Args: Args:
prediction_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): prediction_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
seq_relationship_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`): seq_relationship_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
before SoftMax). before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads. heads.
""" """
prediction_logits: jax_xla.DeviceArray = None prediction_logits: jnp.ndarray = None
seq_relationship_logits: jax_xla.DeviceArray = None seq_relationship_logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None
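A small sketch tying the two heads above together, with dummy arrays in the documented shapes; the mask position is hypothetical::

    import jax.numpy as jnp

    vocab_size, batch_size, seq_len = 10, 1, 4
    prediction_logits = jnp.zeros((batch_size, seq_len, vocab_size))
    seq_relationship_logits = jnp.zeros((batch_size, 2))

    masked_position = 2  # hypothetical [MASK] index in the input
    predicted_token = jnp.argmax(prediction_logits[0, masked_position])
    is_continuation = jnp.argmax(seq_relationship_logits[0]) == 0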
BERT_START_DOCSTRING = r""" BERT_START_DOCSTRING = r"""
...
@@ -21,7 +21,6 @@ import flax
import flax.linen as nn import flax.linen as nn
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from flax.linen.attention import dot_product_attention_weights from flax.linen.attention import dot_product_attention_weights
from jax import lax from jax import lax
@@ -59,28 +58,28 @@ class FlaxBigBirdForPreTrainingOutput(ModelOutput):
Output type of :class:`~transformers.BigBirdForPreTraining`. Output type of :class:`~transformers.BigBirdForPreTraining`.
Args: Args:
prediction_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): prediction_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
seq_relationship_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`): seq_relationship_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
before SoftMax). before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads. heads.
""" """
prediction_logits: jax_xla.DeviceArray = None prediction_logits: jnp.ndarray = None
seq_relationship_logits: jax_xla.DeviceArray = None seq_relationship_logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass @flax.struct.dataclass
@@ -89,30 +88,30 @@ class FlaxBigBirdForQuestionAnsweringModelOutput(ModelOutput):
Base class for outputs of question answering models. Base class for outputs of question answering models.
Args: Args:
start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): start_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Span-start scores (before SoftMax). Span-start scores (before SoftMax).
end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): end_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Span-end scores (before SoftMax). Span-end scores (before SoftMax).
pooled_output (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, hidden_size)`): pooled_output (:obj:`jnp.ndarray` of shape :obj:`(batch_size, hidden_size)`):
The pooled output returned by FlaxBigBirdModel. The pooled output returned by FlaxBigBirdModel.
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads. heads.
""" """
start_logits: jax_xla.DeviceArray = None start_logits: jnp.ndarray = None
end_logits: jax_xla.DeviceArray = None end_logits: jnp.ndarray = None
pooled_output: jax_xla.DeviceArray = None pooled_output: jnp.ndarray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None
BIG_BIRD_START_DOCSTRING = r""" BIG_BIRD_START_DOCSTRING = r"""
...
@@ -19,7 +19,6 @@ import flax
import flax.linen as nn import flax.linen as nn
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from flax.linen import combine_masks, make_causal_mask from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights from flax.linen.attention import dot_product_attention_weights
@@ -156,16 +155,16 @@ CLIP_INPUTS_DOCSTRING = r"""
class FlaxCLIPOutput(ModelOutput): class FlaxCLIPOutput(ModelOutput):
""" """
Args: Args:
logits_per_image (:obj:`jax_xla.DeviceArray` of shape :obj:`(image_batch_size, text_batch_size)`): logits_per_image (:obj:`jnp.ndarray` of shape :obj:`(image_batch_size, text_batch_size)`):
The scaled dot product scores between :obj:`image_embeds` and :obj:`text_embeds`. This represents the The scaled dot product scores between :obj:`image_embeds` and :obj:`text_embeds`. This represents the
image-text similarity scores. image-text similarity scores.
logits_per_text (:obj:`jax_xla.DeviceArray` of shape :obj:`(text_batch_size, image_batch_size)`): logits_per_text (:obj:`jnp.ndarray` of shape :obj:`(text_batch_size, image_batch_size)`):
The scaled dot product scores between :obj:`text_embeds` and :obj:`image_embeds`. This represents the The scaled dot product scores between :obj:`text_embeds` and :obj:`image_embeds`. This represents the
text-image similarity scores. text-image similarity scores.
text_embeds (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`): text_embeds (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim)`):
The text embeddings obtained by applying the projection layer to the pooled output of The text embeddings obtained by applying the projection layer to the pooled output of
:class:`~transformers.FlaxCLIPTextModel`. :class:`~transformers.FlaxCLIPTextModel`.
image_embeds (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`): image_embeds (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim)`):
The image embeddings obtained by applying the projection layer to the pooled output of The image embeddings obtained by applying the projection layer to the pooled output of
:class:`~transformers.FlaxCLIPVisionModel`. :class:`~transformers.FlaxCLIPVisionModel`.
text_model_output (:obj:`FlaxBaseModelOutputWithPooling`): text_model_output (:obj:`FlaxBaseModelOutputWithPooling`):
@@ -174,10 +173,10 @@ class FlaxCLIPOutput(ModelOutput):
The output of the :class:`~transformers.FlaxCLIPVisionModel`. The output of the :class:`~transformers.FlaxCLIPVisionModel`.
""" """
logits_per_image: jax_xla.DeviceArray = None logits_per_image: jnp.ndarray = None
logits_per_text: jax_xla.DeviceArray = None logits_per_text: jnp.ndarray = None
text_embeds: jax_xla.DeviceArray = None text_embeds: jnp.ndarray = None
image_embeds: jax_xla.DeviceArray = None image_embeds: jnp.ndarray = None
text_model_output: FlaxBaseModelOutputWithPooling = None text_model_output: FlaxBaseModelOutputWithPooling = None
vision_model_output: FlaxBaseModelOutputWithPooling = None vision_model_output: FlaxBaseModelOutputWithPooling = None
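The four similarity fields above are related by a scaled dot product, as the docstring says. A minimal sketch assuming L2-normalized embeddings and a made-up `logit_scale`::

    import jax.numpy as jnp

    batch, dim = 2, 8
    text_embeds = jnp.ones((batch, dim))
    image_embeds = jnp.ones((batch, dim))

    # Normalize, then take the scaled dot product of every text/image pair.
    text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
    image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)

    logit_scale = jnp.exp(2.6592)  # hypothetical learned temperature
    logits_per_text = logit_scale * text_embeds @ image_embeds.T
    logits_per_image = logits_per_text.T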
@@ -801,8 +800,8 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
`What are input IDs? <../glossary.html#input-ids>`__ `What are input IDs? <../glossary.html#input-ids>`__
Returns: Returns:
text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`): The text embeddings text_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim)`): The text embeddings obtained by
obtained by applying the projection layer to the pooled output of :class:`~transformers.FlaxCLIPTextModel`. applying the projection layer to the pooled output of :class:`~transformers.FlaxCLIPTextModel`.
Examples:: Examples::
@@ -855,9 +854,8 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
:meth:`transformers.CLIPFeatureExtractor.__call__` for details. :meth:`transformers.CLIPFeatureExtractor.__call__` for details.
Returns: Returns:
image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`): The image embeddings image_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim)`): The image embeddings obtained
obtained by applying the projection layer to the pooled output of by applying the projection layer to the pooled output of :class:`~transformers.FlaxCLIPVisionModel`
:class:`~transformers.FlaxCLIPVisionModel`
Examples:: Examples::
...
@@ -21,7 +21,6 @@ import flax
import flax.linen as nn import flax.linen as nn
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from flax.linen.attention import dot_product_attention_weights from flax.linen.attention import dot_product_attention_weights
from jax import lax from jax import lax
@@ -60,24 +59,24 @@ class FlaxElectraForPreTrainingOutput(ModelOutput):
Output type of :class:`~transformers.ElectraForPreTraining`. Output type of :class:`~transformers.ElectraForPreTraining`.
Args: Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads. heads.
""" """
logits: jax_xla.DeviceArray = None logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None
ELECTRA_START_DOCSTRING = r""" ELECTRA_START_DOCSTRING = r"""
...
@@ -44,7 +44,6 @@ if is_flax_available():
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import unfreeze from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict from flax.traverse_util import flatten_dict
from transformers import ( from transformers import (
@@ -127,7 +126,7 @@ class FlaxModelTesterMixin:
if "ForMultipleChoice" in model_class.__name__: if "ForMultipleChoice" in model_class.__name__:
inputs_dict = { inputs_dict = {
k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1])) k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
if isinstance(v, (jax_xla.DeviceArray, np.ndarray)) if isinstance(v, (jnp.ndarray, np.ndarray))
else v else v
for k, v in inputs_dict.items() for k, v in inputs_dict.items()
} }
...
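The point of this commit is that `jnp.ndarray` is the public type for JAX arrays, so the updated `isinstance` check above works without importing `jaxlib.xla_extension`. A small sketch of what the comprehension does for one tensor::

    import jax.numpy as jnp

    num_choices = 4
    input_ids = jnp.array([[5, 6, 7], [8, 9, 10]])  # (batch_size, seq_len)
    assert isinstance(input_ids, jnp.ndarray)       # public JAX array type

    # Duplicate each example across a num_choices axis, as the test does.
    expanded = jnp.broadcast_to(
        input_ids[:, None], (input_ids.shape[0], num_choices, input_ids.shape[-1])
    )
    assert expanded.shape == (2, 4, 3)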