Unverified commit f748bd42 authored by Patrick von Platen, committed by GitHub

[Flax] Add docstrings & model outputs (#11498)



* add attentions & hidden states

* add model outputs + docs

* finish docs

* finish tests

* finish impl

* del @

* finish

* finish

* correct test

* apply sylvains suggestions

* Update src/transformers/models/bert/modeling_flax_bert.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* simplify more
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 3f6add8b
@@ -794,6 +794,17 @@ PT_CAUSAL_LM_SAMPLE = r"""
>>> logits = outputs.logits
"""
PT_SAMPLE_DOCSTRINGS = {
"SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
"QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
"TokenClassification": PT_TOKEN_CLASSIFICATION_SAMPLE,
"MultipleChoice": PT_MULTIPLE_CHOICE_SAMPLE,
"MaskedLM": PT_MASKED_LM_SAMPLE,
"LMHead": PT_CAUSAL_LM_SAMPLE,
"BaseModel": PT_BASE_MODEL_SAMPLE,
}
TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
Example::
@@ -915,30 +926,148 @@ TF_CAUSAL_LM_SAMPLE = r"""
>>> logits = outputs.logits
"""
TF_SAMPLE_DOCSTRINGS = {
"SequenceClassification": TF_SEQUENCE_CLASSIFICATION_SAMPLE,
"QuestionAnswering": TF_QUESTION_ANSWERING_SAMPLE,
"TokenClassification": TF_TOKEN_CLASSIFICATION_SAMPLE,
"MultipleChoice": TF_MULTIPLE_CHOICE_SAMPLE,
"MaskedLM": TF_MASKED_LM_SAMPLE,
"LMHead": TF_CAUSAL_LM_SAMPLE,
"BaseModel": TF_BASE_MODEL_SAMPLE,
}
FLAX_TOKEN_CLASSIFICATION_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
>>> outputs = model(**inputs)
>>> logits = outputs.logits
"""
FLAX_QUESTION_ANSWERING_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
>>> inputs = tokenizer(question, text, return_tensors='jax')
>>> outputs = model(**inputs)
>>> start_scores = outputs.start_logits
>>> end_scores = outputs.end_logits
"""
FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
>>> outputs = model(**inputs)
>>> logits = outputs.logits
"""
FLAX_MASKED_LM_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("The capital of France is {mask}.", return_tensors='jax')
>>> outputs = model(**inputs)
>>> logits = outputs.logits
"""
FLAX_BASE_MODEL_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
"""
FLAX_MULTIPLE_CHOICE_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> choice0 = "It is eaten with a fork and a knife."
>>> choice1 = "It is eaten while held in the hand."
>>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='jax', padding=True)
>>> outputs = model(**{{k: v[None, :] for k,v in encoding.items()}})
>>> logits = outputs.logits
"""
FLAX_SAMPLE_DOCSTRINGS = {
"SequenceClassification": FLAX_SEQUENCE_CLASSIFICATION_SAMPLE,
"QuestionAnswering": FLAX_QUESTION_ANSWERING_SAMPLE,
"TokenClassification": FLAX_TOKEN_CLASSIFICATION_SAMPLE,
"MultipleChoice": FLAX_MULTIPLE_CHOICE_SAMPLE,
"MaskedLM": FLAX_MASKED_LM_SAMPLE,
"BaseModel": FLAX_BASE_MODEL_SAMPLE,
}
def add_code_sample_docstrings(
*docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None, mask=None
*docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None, mask=None, model_cls=None
):
def docstring_decorator(fn):
model_class = fn.__qualname__.split(".")[0]
is_tf_class = model_class[:2] == "TF"
# model_class defaults to function's class if not specified otherwise
model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls
if model_class[:2] == "TF":
sample_docstrings = TF_SAMPLE_DOCSTRINGS
elif model_class[:4] == "Flax":
sample_docstrings = FLAX_SAMPLE_DOCSTRINGS
else:
sample_docstrings = PT_SAMPLE_DOCSTRINGS
doc_kwargs = dict(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
if "SequenceClassification" in model_class:
code_sample = TF_SEQUENCE_CLASSIFICATION_SAMPLE if is_tf_class else PT_SEQUENCE_CLASSIFICATION_SAMPLE
code_sample = sample_docstrings["SequenceClassification"]
elif "QuestionAnswering" in model_class:
code_sample = TF_QUESTION_ANSWERING_SAMPLE if is_tf_class else PT_QUESTION_ANSWERING_SAMPLE
code_sample = sample_docstrings["QuestionAnswering"]
elif "TokenClassification" in model_class:
code_sample = TF_TOKEN_CLASSIFICATION_SAMPLE if is_tf_class else PT_TOKEN_CLASSIFICATION_SAMPLE
code_sample = sample_docstrings["TokenClassification"]
elif "MultipleChoice" in model_class:
code_sample = TF_MULTIPLE_CHOICE_SAMPLE if is_tf_class else PT_MULTIPLE_CHOICE_SAMPLE
code_sample = sample_docstrings["MultipleChoice"]
elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
doc_kwargs["mask"] = "[MASK]" if mask is None else mask
code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE
code_sample = sample_docstrings["MaskedLM"]
elif "LMHead" in model_class or "CausalLM" in model_class:
code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE
code_sample = sample_docstrings["LMHead"]
elif "Model" in model_class or "Encoder" in model_class:
code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE
code_sample = sample_docstrings["BaseModel"]
else:
raise ValueError(f"Docstring can't be built for model {model_class}")
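Not part of this diff — a minimal sketch of how the dispatch above resolves a sample for a Flax class. The class below is a hypothetical stand-in used only to trigger the name-based dispatch (the real FlaxBertForMaskedLM lives in modeling_flax_bert.py): the "Flax" prefix selects FLAX_SAMPLE_DOCSTRINGS, the "MaskedLM" substring selects the task template, and {mask} falls back to the default "[MASK]".

from transformers.file_utils import add_code_sample_docstrings
from transformers.modeling_flax_outputs import FlaxMaskedLMOutput

class FlaxBertForMaskedLM:  # hypothetical stand-in; fn.__qualname__ yields "FlaxBertForMaskedLM"
    @add_code_sample_docstrings(
        tokenizer_class="BertTokenizer",
        checkpoint="bert-base-uncased",
        output_type=FlaxMaskedLMOutput,
        config_class="BertConfig",
    )
    def __call__(self, input_ids, attention_mask=None):
        ...

print(FlaxBertForMaskedLM.__call__.__doc__)  # ends with the formatted FLAX_MASKED_LM_SAMPLE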
@@ -1462,7 +1591,10 @@ def tf_required(func):
def is_tensor(x):
"""Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor` or :obj:`np.ndarray`."""
"""
Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor`, :obj:`jaxlib.xla_extension.DeviceArray` or
:obj:`np.ndarray`.
"""
if is_torch_available():
import torch
@@ -1473,6 +1605,14 @@ def is_tensor(x):
if isinstance(x, tf.Tensor):
return True
if is_flax_available():
import jaxlib.xla_extension as jax_xla
from jax.interpreters.partial_eval import DynamicJaxprTracer
if isinstance(x, (jax_xla.DeviceArray, DynamicJaxprTracer)):
return True
return isinstance(x, np.ndarray)
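A quick smoke test of the extended helper, assuming a jax-enabled install of transformers (`is_tensor` lives in `file_utils` at this point in the tree):

import numpy as np
import jax.numpy as jnp
from transformers.file_utils import is_tensor

assert is_tensor(np.zeros((2, 3)))    # np.ndarray, as before
assert is_tensor(jnp.zeros((2, 3)))   # jaxlib DeviceArray, newly accepted
assert not is_tensor([0.0, 1.0])      # plain Python sequences are not tensors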
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple
import jaxlib.xla_extension as jax_xla
from .file_utils import ModelOutput
@dataclass
class FlaxBaseModelOutput(ModelOutput):
"""
Base class for model's outputs, with potential hidden states and attentions.
Args:
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
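Like their PyTorch and TF counterparts, these classes inherit from ModelOutput, so they support attribute access, integer indexing, and tuple conversion alike; a small sketch of the behavior (fields left at None are skipped by to_tuple()):

import jax.numpy as jnp
from transformers.modeling_flax_outputs import FlaxBaseModelOutput

out = FlaxBaseModelOutput(last_hidden_state=jnp.ones((1, 4, 8)))
assert out.last_hidden_state is out[0]  # attribute and index access hit the same array
assert len(out.to_tuple()) == 1         # hidden_states/attentions are None and dropped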
@dataclass
class FlaxBaseModelOutputWithPooling(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
Args:
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token) further processed by a
Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
prediction (classification) objective during pretraining.
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: jax_xla.DeviceArray = None
pooler_output: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
class FlaxMaskedLMOutput(ModelOutput):
"""
Base class for masked language models outputs.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
class FlaxNextSentencePredictorOutput(ModelOutput):
"""
Base class for outputs of models predicting if two sentences are consecutive or not.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
class FlaxSequenceClassifierOutput(ModelOutput):
"""
Base class for outputs of sentence classification models.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
class FlaxMultipleChoiceModelOutput(ModelOutput):
"""
Base class for outputs of multiple choice models.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
class FlaxTokenClassifierOutput(ModelOutput):
"""
Base class for outputs of token classification models.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
class FlaxQuestionAnsweringModelOutput(ModelOutput):
"""
Base class for outputs of question answering models.
Args:
start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
Span-start scores (before SoftMax).
end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
start_logits: jax_xla.DeviceArray = None
end_logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@@ -32,12 +32,14 @@ from .file_utils import (
FLAX_WEIGHTS_NAME,
WEIGHTS_NAME,
PushToHubMixin,
add_code_sample_docstrings,
add_start_docstrings_to_model_forward,
cached_path,
copy_func,
hf_bucket_url,
is_offline_mode,
is_remote_url,
replace_return_docstrings,
)
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import logging
@@ -432,3 +434,22 @@ def overwrite_call_docstring(model_class, docstring):
model_class.__call__.__doc__ = None
# set correct docstring
model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
def append_call_sample_docstring(model_class, tokenizer_class, checkpoint, output_type, config_class, mask=None):
model_class.__call__ = copy_func(model_class.__call__)
model_class.__call__ = add_code_sample_docstrings(
tokenizer_class=tokenizer_class,
checkpoint=checkpoint,
output_type=output_type,
config_class=config_class,
model_cls=model_class.__name__,
)(model_class.__call__)
def append_replace_return_docstrings(model_class, output_type, config_class):
model_class.__call__ = copy_func(model_class.__call__)
model_class.__call__ = replace_return_docstrings(
output_type=output_type,
config_class=config_class,
)(model_class.__call__)
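Both helpers copy __call__ with copy_func before decorating it, so the docstring patch stays local to the given class instead of mutating a __call__ shared through inheritance. A usage sketch (not part of this diff), assuming FlaxBertModel as shipped in this same PR:

from transformers.models.bert.modeling_flax_bert import FlaxBertModel
from transformers.modeling_flax_outputs import FlaxBaseModelOutputWithPooling
from transformers.modeling_flax_utils import append_call_sample_docstring

append_call_sample_docstring(
    FlaxBertModel, "BertTokenizer", "bert-base-uncased", FlaxBaseModelOutputWithPooling, "BertConfig"
)
print(FlaxBertModel.__call__.__doc__)  # now ends with the formatted FLAX_BASE_MODEL_SAMPLE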
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
from typing import Optional, Tuple
import flax.linen as nn
import jax
@@ -23,13 +23,15 @@ from jax import lax
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
from ...utils import logging
from .configuration_roberta import RobertaConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "roberta-base"
_CONFIG_FOR_DOC = "RobertaConfig"
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
@@ -181,7 +183,7 @@ class FlaxRobertaSelfAttention(nn.Module):
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
def __call__(self, hidden_states, attention_mask, deterministic=True):
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
head_dim = self.config.hidden_size // self.config.num_attention_heads
query_states = self.query(hidden_states).reshape(
@@ -223,7 +225,12 @@
precision=None,
)
return attn_output.reshape(attn_output.shape[:2] + (-1,))
outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),)
# TODO: at the moment it is not possible to retrieve the attention weights from
# `dot_product_attention`; return them from here once that functionality is available
return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta
@@ -256,13 +263,22 @@
self.self = FlaxRobertaSelfAttention(self.config, dtype=self.dtype)
self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic=True):
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic)
attn_outputs = self.self(
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
)
attn_output = attn_outputs[0]
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
return hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_outputs[1],)
return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
@@ -315,11 +331,20 @@ class FlaxRobertaLayer(nn.Module):
self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype)
self.output = FlaxRobertaOutput(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False):
attention_outputs = self.attention(
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
)
attention_output = attention_outputs[0]
hidden_states = self.intermediate(attention_output)
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
return hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attention_outputs[1],)
return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
@@ -332,10 +357,40 @@
FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
]
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
return hidden_states
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = layer(
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions += (layer_outputs[1],)
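# after the loop, the final layer's output is appended below, so `all_hidden_states`
# ends up with num_hidden_layers + 1 entries: the embedding output plus one per layer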
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = (hidden_states, all_hidden_states, all_attentions)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
@@ -346,8 +401,23 @@
def setup(self):
self.layer = FlaxRobertaLayerCollection(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
return self.layer(hidden_states, attention_mask, deterministic=deterministic)
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return self.layer(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
@@ -412,7 +482,21 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
if output_attentions:
raise NotImplementedError(
"Currently attention scores cannot be returned." "Please set `output_attentions` to False for now."
)
# init input tensors if not passed
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
@@ -435,6 +519,9 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
)
@@ -450,17 +537,43 @@ class FlaxRobertaModule(nn.Module):
self.encoder = FlaxRobertaEncoder(self.config, dtype=self.dtype)
self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
hidden_states = self.embeddings(
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
)
hidden_states = self.encoder(hidden_states, attention_mask, deterministic=deterministic)
if not self.add_pooling_layer:
return hidden_states
pooled = self.pooler(hidden_states)
return hidden_states, pooled
outputs = self.encoder(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
if not return_dict:
# if pooled is None, don't return it
if pooled is None:
return (hidden_states,) + outputs[1:]
return (hidden_states, pooled) + outputs[1:]
return FlaxBaseModelOutputWithPooling(
last_hidden_state=hidden_states,
pooler_output=pooled,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
@@ -469,3 +582,8 @@
)
class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
module_class = FlaxRobertaModule
append_call_sample_docstring(
FlaxRobertaModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC
)
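An end-to-end sketch of the new output handling (not part of the diff; downloads the roberta-base checkpoint and requires a jax/flax install):

from transformers import RobertaTokenizer, FlaxRobertaModel

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = FlaxRobertaModel.from_pretrained("roberta-base")
inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")

outputs = model(**inputs, output_hidden_states=True)  # FlaxBaseModelOutputWithPooling
print(outputs.last_hidden_state.shape)  # (1, sequence_length, 768)
print(len(outputs.hidden_states))       # 13 for roberta-base: embeddings + 12 layers

# return_dict=False falls back to a plain tuple, with None fields dropped
hidden, pooled, all_hidden = model(**inputs, output_hidden_states=True, return_dict=False)

# output_attentions is gated off for now and fails fast:
# model(**inputs, output_attentions=True)  # raises NotImplementedError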
@@ -998,7 +998,6 @@ class ModelTesterMixin:
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def set_nan_tensor_to_zero(t):
@@ -13,8 +13,10 @@
# limitations under the License.
import copy
import inspect
import random
import tempfile
from typing import List, Tuple
import numpy as np
@@ -28,6 +30,7 @@ if is_flax_available():
import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
@@ -77,6 +80,7 @@ class FlaxModelTesterMixin:
inputs_dict = {
k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
for k, v in inputs_dict.items()
if isinstance(v, (jax_xla.DeviceArray, np.ndarray))
}
return inputs_dict
@@ -85,6 +89,41 @@ class FlaxModelTesterMixin:
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def set_nan_tensor_to_zero(t):
# Flax arrays are immutable, so NaNs are replaced functionally rather than in place
return jnp.nan_to_num(t)
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assert_almost_equals(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), 1e-5
)
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
model = model_class(config)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -108,7 +147,7 @@ class FlaxModelTesterMixin:
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict)
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
@@ -117,7 +156,7 @@ class FlaxModelTesterMixin:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict)
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
@@ -149,7 +188,7 @@ class FlaxModelTesterMixin:
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict)
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
@@ -171,17 +210,20 @@ class FlaxModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if model_class.__name__ != "FlaxBertModel":
continue
with self.subTest(model_class.__name__):
model = model_class(config)
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
outputs = model(**prepared_inputs_dict)
outputs = model(**prepared_inputs_dict).to_tuple()
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_loaded = model_class.from_pretrained(tmpdirname)
outputs_loaded = model_loaded(**prepared_inputs_dict)
outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 1e-3)
@@ -195,19 +237,47 @@
@jax.jit
def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
return model(input_ids, attention_mask, token_type_ids)
return model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
).to_tuple()
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict)
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = model_jitted(**prepared_inputs_dict)
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict)
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
@jax.jit
def model_jitted_return_dict(input_ids, attention_mask=None, token_type_ids=None):
return model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
# a jitted function cannot return a ModelOutput: it is an OrderedDict subclass, not a registered pytree
with self.assertRaises(TypeError):
model_jitted_return_dict(**prepared_inputs_dict)
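What that assertRaises is checking, reduced to a standalone sketch (not from this PR): jax.jit can only return values it can flatten as a pytree, and a ModelOutput-style OrderedDict subclass is not registered as one, so the traced output is rejected (the exact message varies across jax versions):

import collections
import jax
import jax.numpy as jnp

class DictOutput(collections.OrderedDict):  # stand-in for ModelOutput
    pass

@jax.jit
def f(x):
    return DictOutput(logits=x * 2)

try:
    f(jnp.ones(3))
except Exception as e:  # the test above expects a TypeError here
    print(type(e).__name__)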
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.__call__)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["input_ids", "attention_mask"]
self.assertListEqual(arg_names[:2], expected_arg_names)
def test_naming_convention(self):
for model_class in self.all_model_classes:
model_class_name = model_class.__name__
@@ -218,3 +288,30 @@
module_cls = getattr(bert_modeling_flax_module, module_class_name)
self.assertIsNotNone(module_cls)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
seq_length = self.model_tester.seq_length
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)