Unverified Commit da9754a3 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Flax] Align jax flax device name (#12987)

* [Flax] Align device name in docs

* make style

* fix import error
parent 07df5578
...@@ -224,7 +224,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel): ...@@ -224,7 +224,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
`What are input IDs? <../glossary.html#input-ids>`__ `What are input IDs? <../glossary.html#input-ids>`__
Returns: Returns:
text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The text embeddings text_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The text embeddings
obtained by applying the projection layer to the pooled output of text model. obtained by applying the projection layer to the pooled output of text model.
""" """
if position_ids is None: if position_ids is None:
...@@ -273,7 +273,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel): ...@@ -273,7 +273,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
:meth:`transformers.ImageFeatureExtractionMixin.__call__` for details. :meth:`transformers.ImageFeatureExtractionMixin.__call__` for details.
Returns: Returns:
image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The image embeddings image_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The image embeddings
obtained by applying the projection layer to the pooled output of vision model. obtained by applying the projection layer to the pooled output of vision model.
""" """
......
...@@ -19,7 +19,6 @@ from abc import ABC ...@@ -19,7 +19,6 @@ from abc import ABC
import jax import jax
import jax.lax as lax import jax.lax as lax
import jax.numpy as jnp import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from .file_utils import add_start_docstrings from .file_utils import add_start_docstrings
from .utils.logging import get_logger from .utils.logging import get_logger
...@@ -30,7 +29,7 @@ logger = get_logger(__name__) ...@@ -30,7 +29,7 @@ logger = get_logger(__name__)
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r""" LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
Args: Args:
input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
...@@ -38,14 +37,14 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r""" ...@@ -38,14 +37,14 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
details. details.
`What are input IDs? <../glossary.html#input-ids>`__ `What are input IDs? <../glossary.html#input-ids>`__
scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`): scores (:obj:`jnp.ndarray` of shape :obj:`(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
search or log softmax for each vocabulary token when using beam search search or log softmax for each vocabulary token when using beam search
kwargs: kwargs:
Additional logits processor specific kwargs. Additional logits processor specific kwargs.
Return: Return:
:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores. :obj:`jnp.ndarray` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
""" """
...@@ -54,7 +53,7 @@ class FlaxLogitsProcessor(ABC): ...@@ -54,7 +53,7 @@ class FlaxLogitsProcessor(ABC):
"""Abstract base class for all logit processors that can be applied during generation.""" """Abstract base class for all logit processors that can be applied during generation."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray: def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
"""Flax method for processing logits.""" """Flax method for processing logits."""
raise NotImplementedError( raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
...@@ -65,7 +64,7 @@ class FlaxLogitsWarper(ABC): ...@@ -65,7 +64,7 @@ class FlaxLogitsWarper(ABC):
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray: def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
"""Flax method for warping logits.""" """Flax method for warping logits."""
raise NotImplementedError( raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
...@@ -81,9 +80,7 @@ class FlaxLogitsProcessorList(list): ...@@ -81,9 +80,7 @@ class FlaxLogitsProcessorList(list):
""" """
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__( def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int, **kwargs) -> jnp.ndarray:
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int, **kwargs
) -> jax_xla.DeviceArray:
for processor in self: for processor in self:
function_args = inspect.signature(processor.__call__).parameters function_args = inspect.signature(processor.__call__).parameters
if len(function_args) > 3: if len(function_args) > 3:
...@@ -111,9 +108,7 @@ class FlaxTemperatureLogitsWarper(FlaxLogitsWarper): ...@@ -111,9 +108,7 @@ class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
self.temperature = temperature self.temperature = temperature
def __call__( def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
) -> jax_xla.DeviceArray:
scores = scores / self.temperature scores = scores / self.temperature
return scores return scores
...@@ -141,9 +136,7 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper): ...@@ -141,9 +136,7 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
self.filter_value = filter_value self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep self.min_tokens_to_keep = min_tokens_to_keep
def __call__( def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
) -> jax_xla.DeviceArray:
topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1]) topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1])
mask_scores = jnp.full_like(scores, self.filter_value) mask_scores = jnp.full_like(scores, self.filter_value)
...@@ -183,9 +176,7 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper): ...@@ -183,9 +176,7 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper):
self.filter_value = filter_value self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep self.min_tokens_to_keep = min_tokens_to_keep
def __call__( def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
) -> jax_xla.DeviceArray:
batch_size, vocab_size = scores.shape batch_size, vocab_size = scores.shape
next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value) next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)
...@@ -212,9 +203,7 @@ class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor): ...@@ -212,9 +203,7 @@ class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor):
def __init__(self, bos_token_id: int): def __init__(self, bos_token_id: int):
self.bos_token_id = bos_token_id self.bos_token_id = bos_token_id
def __call__( def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
) -> jax_xla.DeviceArray:
new_scores = jnp.full(scores.shape, -float("inf")) new_scores = jnp.full(scores.shape, -float("inf"))
apply_penalty = 1 - jnp.bool_(cur_len - 1) apply_penalty = 1 - jnp.bool_(cur_len - 1)
...@@ -242,9 +231,7 @@ class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor): ...@@ -242,9 +231,7 @@ class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor):
self.max_length = max_length self.max_length = max_length
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
def __call__( def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
) -> jax_xla.DeviceArray:
new_scores = jnp.full(scores.shape, -float("inf")) new_scores = jnp.full(scores.shape, -float("inf"))
apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1) apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)
...@@ -277,9 +264,7 @@ class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor): ...@@ -277,9 +264,7 @@ class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
self.min_length = min_length self.min_length = min_length
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
def __call__( def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
) -> jax_xla.DeviceArray:
# create boolean flag to decide if min length penalty should be applied # create boolean flag to decide if min length penalty should be applied
apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1) apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)
......
...@@ -23,7 +23,6 @@ import numpy as np ...@@ -23,7 +23,6 @@ import numpy as np
import flax import flax
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from jax import lax from jax import lax
from .file_utils import ModelOutput from .file_utils import ModelOutput
...@@ -49,11 +48,11 @@ class FlaxGreedySearchOutput(ModelOutput): ...@@ -49,11 +48,11 @@ class FlaxGreedySearchOutput(ModelOutput):
Args: Args:
sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`): sequences (:obj:`jnp.ndarray` of shape :obj:`(batch_size, max_length)`):
The generated sequences. The generated sequences.
""" """
sequences: jax_xla.DeviceArray = None sequences: jnp.ndarray = None
@flax.struct.dataclass @flax.struct.dataclass
...@@ -63,11 +62,11 @@ class FlaxSampleOutput(ModelOutput): ...@@ -63,11 +62,11 @@ class FlaxSampleOutput(ModelOutput):
Args: Args:
sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`): sequences (:obj:`jnp.ndarray` of shape :obj:`(batch_size, max_length)`):
The generated sequences. The generated sequences.
""" """
sequences: jax_xla.DeviceArray = None sequences: jnp.ndarray = None
@flax.struct.dataclass @flax.struct.dataclass
...@@ -77,44 +76,44 @@ class FlaxBeamSearchOutput(ModelOutput): ...@@ -77,44 +76,44 @@ class FlaxBeamSearchOutput(ModelOutput):
Args: Args:
sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`): sequences (:obj:`jnp.ndarray` of shape :obj:`(batch_size, max_length)`):
The generated sequences. The generated sequences.
scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size,)`): scores (:obj:`jnp.ndarray` of shape :obj:`(batch_size,)`):
The scores (log probabilites) of the generated sequences. The scores (log probabilites) of the generated sequences.
""" """
sequences: jax_xla.DeviceArray = None sequences: jnp.ndarray = None
scores: jax_xla.DeviceArray = None scores: jnp.ndarray = None
@flax.struct.dataclass @flax.struct.dataclass
class GreedyState: class GreedyState:
cur_len: jax_xla.DeviceArray cur_len: jnp.ndarray
sequences: jax_xla.DeviceArray sequences: jnp.ndarray
running_token: jax_xla.DeviceArray running_token: jnp.ndarray
is_sent_finished: jax_xla.DeviceArray is_sent_finished: jnp.ndarray
model_kwargs: Dict[str, jax_xla.DeviceArray] model_kwargs: Dict[str, jnp.ndarray]
@flax.struct.dataclass @flax.struct.dataclass
class SampleState: class SampleState:
cur_len: jax_xla.DeviceArray cur_len: jnp.ndarray
sequences: jax_xla.DeviceArray sequences: jnp.ndarray
running_token: jax_xla.DeviceArray running_token: jnp.ndarray
is_sent_finished: jax_xla.DeviceArray is_sent_finished: jnp.ndarray
prng_key: jax_xla.DeviceArray prng_key: jnp.ndarray
model_kwargs: Dict[str, jax_xla.DeviceArray] model_kwargs: Dict[str, jnp.ndarray]
@flax.struct.dataclass @flax.struct.dataclass
class BeamSearchState: class BeamSearchState:
cur_len: jax_xla.DeviceArray cur_len: jnp.ndarray
running_sequences: jax_xla.DeviceArray running_sequences: jnp.ndarray
running_scores: jax_xla.DeviceArray running_scores: jnp.ndarray
sequences: jax_xla.DeviceArray sequences: jnp.ndarray
scores: jax_xla.DeviceArray scores: jnp.ndarray
is_sent_finished: jax_xla.DeviceArray is_sent_finished: jnp.ndarray
model_kwargs: Dict[str, jax_xla.DeviceArray] model_kwargs: Dict[str, jnp.ndarray]
class FlaxGenerationMixin: class FlaxGenerationMixin:
...@@ -156,14 +155,14 @@ class FlaxGenerationMixin: ...@@ -156,14 +155,14 @@ class FlaxGenerationMixin:
def generate( def generate(
self, self,
input_ids: jax_xla.DeviceArray, input_ids: jnp.ndarray,
max_length: Optional[int] = None, max_length: Optional[int] = None,
pad_token_id: Optional[int] = None, pad_token_id: Optional[int] = None,
bos_token_id: Optional[int] = None, bos_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
decoder_start_token_id: Optional[int] = None, decoder_start_token_id: Optional[int] = None,
do_sample: Optional[bool] = None, do_sample: Optional[bool] = None,
prng_key: Optional[jax_xla.DeviceArray] = None, prng_key: Optional[jnp.ndarray] = None,
top_k: Optional[int] = None, top_k: Optional[int] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
...@@ -175,7 +174,7 @@ class FlaxGenerationMixin: ...@@ -175,7 +174,7 @@ class FlaxGenerationMixin:
length_penalty: Optional[float] = None, length_penalty: Optional[float] = None,
early_stopping: Optional[bool] = None, early_stopping: Optional[bool] = None,
trace: bool = True, trace: bool = True,
params: Optional[Dict[str, jax_xla.DeviceArray]] = None, params: Optional[Dict[str, jnp.ndarray]] = None,
**model_kwargs, **model_kwargs,
): ):
r""" r"""
...@@ -191,7 +190,7 @@ class FlaxGenerationMixin: ...@@ -191,7 +190,7 @@ class FlaxGenerationMixin:
Parameters: Parameters:
input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`, `optional`): input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
The sequence used as a prompt for the generation. The sequence used as a prompt for the generation.
max_length (:obj:`int`, `optional`, defaults to 20): max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated. The maximum length of the sequence to be generated.
...@@ -217,7 +216,7 @@ class FlaxGenerationMixin: ...@@ -217,7 +216,7 @@ class FlaxGenerationMixin:
trace (:obj:`bool`, `optional`, defaults to :obj:`True`): trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to
a considerably slower runtime. a considerably slower runtime.
params (:obj:`Dict[str, jax_xla.DeviceArray]`, `optional`): params (:obj:`Dict[str, jnp.ndarray]`, `optional`):
Optionally the model parameters can be passed. Can be useful for parallelized generation. Optionally the model parameters can be passed. Can be useful for parallelized generation.
model_kwargs: model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
...@@ -395,8 +394,8 @@ class FlaxGenerationMixin: ...@@ -395,8 +394,8 @@ class FlaxGenerationMixin:
eos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
logits_processor: Optional[FlaxLogitsProcessorList] = None, logits_processor: Optional[FlaxLogitsProcessorList] = None,
trace: bool = True, trace: bool = True,
params: Optional[Dict[str, jax_xla.DeviceArray]] = None, params: Optional[Dict[str, jnp.ndarray]] = None,
model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None, model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
): ):
# init values # init values
max_length = max_length if max_length is not None else self.config.max_length max_length = max_length if max_length is not None else self.config.max_length
...@@ -479,12 +478,12 @@ class FlaxGenerationMixin: ...@@ -479,12 +478,12 @@ class FlaxGenerationMixin:
max_length: Optional[int] = None, max_length: Optional[int] = None,
pad_token_id: Optional[int] = None, pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
prng_key: Optional[jax_xla.DeviceArray] = None, prng_key: Optional[jnp.ndarray] = None,
logits_processor: Optional[FlaxLogitsProcessorList] = None, logits_processor: Optional[FlaxLogitsProcessorList] = None,
logits_warper: Optional[FlaxLogitsProcessorList] = None, logits_warper: Optional[FlaxLogitsProcessorList] = None,
trace: bool = True, trace: bool = True,
params: Optional[Dict[str, jax_xla.DeviceArray]] = None, params: Optional[Dict[str, jnp.ndarray]] = None,
model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None, model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
): ):
# init values # init values
max_length = max_length if max_length is not None else self.config.max_length max_length = max_length if max_length is not None else self.config.max_length
...@@ -580,8 +579,8 @@ class FlaxGenerationMixin: ...@@ -580,8 +579,8 @@ class FlaxGenerationMixin:
early_stopping: Optional[bool] = None, early_stopping: Optional[bool] = None,
logits_processor: Optional[FlaxLogitsProcessorList] = None, logits_processor: Optional[FlaxLogitsProcessorList] = None,
trace: bool = True, trace: bool = True,
params: Optional[Dict[str, jax_xla.DeviceArray]] = None, params: Optional[Dict[str, jnp.ndarray]] = None,
model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None, model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
): ):
""" """
This beam search function is heavily inspired by Flax's official example: This beam search function is heavily inspired by Flax's official example:
......
This diff is collapsed.
...@@ -21,7 +21,6 @@ import flax ...@@ -21,7 +21,6 @@ import flax
import flax.linen as nn import flax.linen as nn
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from flax.linen.attention import dot_product_attention_weights from flax.linen.attention import dot_product_attention_weights
from jax import lax from jax import lax
...@@ -61,28 +60,28 @@ class FlaxBertForPreTrainingOutput(ModelOutput): ...@@ -61,28 +60,28 @@ class FlaxBertForPreTrainingOutput(ModelOutput):
Output type of :class:`~transformers.BertForPreTraining`. Output type of :class:`~transformers.BertForPreTraining`.
Args: Args:
prediction_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): prediction_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
seq_relationship_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`): seq_relationship_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
before SoftMax). before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads. heads.
""" """
prediction_logits: jax_xla.DeviceArray = None prediction_logits: jnp.ndarray = None
seq_relationship_logits: jax_xla.DeviceArray = None seq_relationship_logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None
BERT_START_DOCSTRING = r""" BERT_START_DOCSTRING = r"""
......
...@@ -21,7 +21,6 @@ import flax ...@@ -21,7 +21,6 @@ import flax
import flax.linen as nn import flax.linen as nn
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from flax.linen.attention import dot_product_attention_weights from flax.linen.attention import dot_product_attention_weights
from jax import lax from jax import lax
...@@ -59,28 +58,28 @@ class FlaxBigBirdForPreTrainingOutput(ModelOutput): ...@@ -59,28 +58,28 @@ class FlaxBigBirdForPreTrainingOutput(ModelOutput):
Output type of :class:`~transformers.BigBirdForPreTraining`. Output type of :class:`~transformers.BigBirdForPreTraining`.
Args: Args:
prediction_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): prediction_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
seq_relationship_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`): seq_relationship_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
before SoftMax). before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads. heads.
""" """
prediction_logits: jax_xla.DeviceArray = None prediction_logits: jnp.ndarray = None
seq_relationship_logits: jax_xla.DeviceArray = None seq_relationship_logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass @flax.struct.dataclass
...@@ -89,30 +88,30 @@ class FlaxBigBirdForQuestionAnsweringModelOutput(ModelOutput): ...@@ -89,30 +88,30 @@ class FlaxBigBirdForQuestionAnsweringModelOutput(ModelOutput):
Base class for outputs of question answering models. Base class for outputs of question answering models.
Args: Args:
start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): start_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Span-start scores (before SoftMax). Span-start scores (before SoftMax).
end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): end_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Span-end scores (before SoftMax). Span-end scores (before SoftMax).
pooled_output (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, hidden_size)`): pooled_output (:obj:`jnp.ndarray` of shape :obj:`(batch_size, hidden_size)`):
pooled_output returned by FlaxBigBirdModel. pooled_output returned by FlaxBigBirdModel.
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads. heads.
""" """
start_logits: jax_xla.DeviceArray = None start_logits: jnp.ndarray = None
end_logits: jax_xla.DeviceArray = None end_logits: jnp.ndarray = None
pooled_output: jax_xla.DeviceArray = None pooled_output: jnp.ndarray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None
BIG_BIRD_START_DOCSTRING = r""" BIG_BIRD_START_DOCSTRING = r"""
......
...@@ -19,7 +19,6 @@ import flax ...@@ -19,7 +19,6 @@ import flax
import flax.linen as nn import flax.linen as nn
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from flax.linen import combine_masks, make_causal_mask from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights from flax.linen.attention import dot_product_attention_weights
...@@ -156,16 +155,16 @@ CLIP_INPUTS_DOCSTRING = r""" ...@@ -156,16 +155,16 @@ CLIP_INPUTS_DOCSTRING = r"""
class FlaxCLIPOutput(ModelOutput): class FlaxCLIPOutput(ModelOutput):
""" """
Args: Args:
logits_per_image:(:obj:`jax_xla.DeviceArray` of shape :obj:`(image_batch_size, text_batch_size)`): logits_per_image:(:obj:`jnp.ndarray` of shape :obj:`(image_batch_size, text_batch_size)`):
The scaled dot product scores between :obj:`image_embeds` and :obj:`text_embeds`. This represents the The scaled dot product scores between :obj:`image_embeds` and :obj:`text_embeds`. This represents the
image-text similarity scores. image-text similarity scores.
logits_per_text:(:obj:`jax_xla.DeviceArray` of shape :obj:`(text_batch_size, image_batch_size)`): logits_per_text:(:obj:`jnp.ndarray` of shape :obj:`(text_batch_size, image_batch_size)`):
The scaled dot product scores between :obj:`text_embeds` and :obj:`image_embeds`. This represents the The scaled dot product scores between :obj:`text_embeds` and :obj:`image_embeds`. This represents the
text-image similarity scores. text-image similarity scores.
text_embeds(:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): text_embeds(:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`):
The text embeddings obtained by applying the projection layer to the pooled output of The text embeddings obtained by applying the projection layer to the pooled output of
:class:`~transformers.FlaxCLIPTextModel`. :class:`~transformers.FlaxCLIPTextModel`.
image_embeds(:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): image_embeds(:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`):
The image embeddings obtained by applying the projection layer to the pooled output of The image embeddings obtained by applying the projection layer to the pooled output of
:class:`~transformers.FlaxCLIPVisionModel`. :class:`~transformers.FlaxCLIPVisionModel`.
text_model_output(:obj:`FlaxBaseModelOutputWithPooling`): text_model_output(:obj:`FlaxBaseModelOutputWithPooling`):
...@@ -174,10 +173,10 @@ class FlaxCLIPOutput(ModelOutput): ...@@ -174,10 +173,10 @@ class FlaxCLIPOutput(ModelOutput):
The output of the :class:`~transformers.FlaxCLIPVisionModel`. The output of the :class:`~transformers.FlaxCLIPVisionModel`.
""" """
logits_per_image: jax_xla.DeviceArray = None logits_per_image: jnp.ndarray = None
logits_per_text: jax_xla.DeviceArray = None logits_per_text: jnp.ndarray = None
text_embeds: jax_xla.DeviceArray = None text_embeds: jnp.ndarray = None
image_embeds: jax_xla.DeviceArray = None image_embeds: jnp.ndarray = None
text_model_output: FlaxBaseModelOutputWithPooling = None text_model_output: FlaxBaseModelOutputWithPooling = None
vision_model_output: FlaxBaseModelOutputWithPooling = None vision_model_output: FlaxBaseModelOutputWithPooling = None
...@@ -801,8 +800,8 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel): ...@@ -801,8 +800,8 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
`What are input IDs? <../glossary.html#input-ids>`__ `What are input IDs? <../glossary.html#input-ids>`__
Returns: Returns:
text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The text embeddings text_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The text embeddings obtained by
obtained by applying the projection layer to the pooled output of :class:`~transformers.FlaxCLIPTextModel`. applying the projection layer to the pooled output of :class:`~transformers.FlaxCLIPTextModel`.
Examples:: Examples::
...@@ -855,9 +854,8 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel): ...@@ -855,9 +854,8 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
:meth:`transformers.CLIPFeatureExtractor.__call__` for details. :meth:`transformers.CLIPFeatureExtractor.__call__` for details.
Returns: Returns:
image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The image embeddings image_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The image embeddings obtained
obtained by applying the projection layer to the pooled output of by applying the projection layer to the pooled output of :class:`~transformers.FlaxCLIPVisionModel`
:class:`~transformers.FlaxCLIPVisionModel`
Examples:: Examples::
......
...@@ -21,7 +21,6 @@ import flax ...@@ -21,7 +21,6 @@ import flax
import flax.linen as nn import flax.linen as nn
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from flax.linen.attention import dot_product_attention_weights from flax.linen.attention import dot_product_attention_weights
from jax import lax from jax import lax
...@@ -60,24 +59,24 @@ class FlaxElectraForPreTrainingOutput(ModelOutput): ...@@ -60,24 +59,24 @@ class FlaxElectraForPreTrainingOutput(ModelOutput):
Output type of :class:`~transformers.ElectraForPreTraining`. Output type of :class:`~transformers.ElectraForPreTraining`.
Args: Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length, sequence_length)`. sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads. heads.
""" """
logits: jax_xla.DeviceArray = None logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None hidden_states: Optional[Tuple[jnp.ndarray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None
ELECTRA_START_DOCSTRING = r""" ELECTRA_START_DOCSTRING = r"""
......
...@@ -44,7 +44,6 @@ if is_flax_available(): ...@@ -44,7 +44,6 @@ if is_flax_available():
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import unfreeze from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict from flax.traverse_util import flatten_dict
from transformers import ( from transformers import (
...@@ -127,7 +126,7 @@ class FlaxModelTesterMixin: ...@@ -127,7 +126,7 @@ class FlaxModelTesterMixin:
if "ForMultipleChoice" in model_class.__name__: if "ForMultipleChoice" in model_class.__name__:
inputs_dict = { inputs_dict = {
k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1])) k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
if isinstance(v, (jax_xla.DeviceArray, np.ndarray)) if isinstance(v, (jnp.ndarray, np.ndarray))
else v else v
for k, v in inputs_dict.items() for k, v in inputs_dict.items()
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment