Unverified commit f748bd42, authored by Patrick von Platen, committed by GitHub

[Flax] Add docstrings & model outputs (#11498)



* add attentions & hidden states

* add model outputs + docs

* finish docs

* finish tests

* finish impl

* del @

* finish

* finish

* correct test

* apply sylvains suggestions

* Update src/transformers/models/bert/modeling_flax_bert.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* simplify more
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 3f6add8b
...@@ -794,6 +794,17 @@ PT_CAUSAL_LM_SAMPLE = r"""
>>> logits = outputs.logits
"""
PT_SAMPLE_DOCSTRINGS = {
"SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
"QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
"TokenClassification": PT_TOKEN_CLASSIFICATION_SAMPLE,
"MultipleChoice": PT_MULTIPLE_CHOICE_SAMPLE,
"MaskedLM": PT_MASKED_LM_SAMPLE,
"LMHead": PT_CAUSAL_LM_SAMPLE,
"BaseModel": PT_BASE_MODEL_SAMPLE,
}
TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
Example::
...@@ -915,30 +926,148 @@ TF_CAUSAL_LM_SAMPLE = r"""
>>> logits = outputs.logits
"""
TF_SAMPLE_DOCSTRINGS = {
"SequenceClassification": TF_SEQUENCE_CLASSIFICATION_SAMPLE,
"QuestionAnswering": TF_QUESTION_ANSWERING_SAMPLE,
"TokenClassification": TF_TOKEN_CLASSIFICATION_SAMPLE,
"MultipleChoice": TF_MULTIPLE_CHOICE_SAMPLE,
"MaskedLM": TF_MASKED_LM_SAMPLE,
"LMHead": TF_CAUSAL_LM_SAMPLE,
"BaseModel": TF_BASE_MODEL_SAMPLE,
}
FLAX_TOKEN_CLASSIFICATION_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
>>> outputs = model(**inputs)
>>> logits = outputs.logits
"""
FLAX_QUESTION_ANSWERING_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
>>> inputs = tokenizer(question, text, return_tensors='jax')
>>> outputs = model(**inputs)
>>> start_scores = outputs.start_logits
>>> end_scores = outputs.end_logits
"""
FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
>>> outputs = model(**inputs)
>>> logits = outputs.logits
"""
FLAX_MASKED_LM_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("The capital of France is {mask}.", return_tensors='jax')
>>> outputs = model(**inputs)
>>> logits = outputs.logits
"""
FLAX_BASE_MODEL_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
"""
FLAX_MULTIPLE_CHOICE_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> choice0 = "It is eaten with a fork and a knife."
>>> choice1 = "It is eaten while held in the hand."
>>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='jax', padding=True)
>>> outputs = model(**{{k: v[None, :] for k,v in encoding.items()}})
>>> logits = outputs.logits
"""
FLAX_SAMPLE_DOCSTRINGS = {
"SequenceClassification": FLAX_SEQUENCE_CLASSIFICATION_SAMPLE,
"QuestionAnswering": FLAX_QUESTION_ANSWERING_SAMPLE,
"TokenClassification": FLAX_TOKEN_CLASSIFICATION_SAMPLE,
"MultipleChoice": FLAX_MULTIPLE_CHOICE_SAMPLE,
"MaskedLM": FLAX_MASKED_LM_SAMPLE,
"BaseModel": FLAX_BASE_MODEL_SAMPLE,
}
def add_code_sample_docstrings(
*docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None, mask=None, model_cls=None
):
def docstring_decorator(fn):
# model_class defaults to function's class if not specified otherwise
model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls
if model_class[:2] == "TF":
sample_docstrings = TF_SAMPLE_DOCSTRINGS
elif model_class[:4] == "Flax":
sample_docstrings = FLAX_SAMPLE_DOCSTRINGS
else:
sample_docstrings = PT_SAMPLE_DOCSTRINGS
doc_kwargs = dict(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
if "SequenceClassification" in model_class:
code_sample = sample_docstrings["SequenceClassification"]
elif "QuestionAnswering" in model_class:
code_sample = sample_docstrings["QuestionAnswering"]
elif "TokenClassification" in model_class:
code_sample = sample_docstrings["TokenClassification"]
elif "MultipleChoice" in model_class:
code_sample = sample_docstrings["MultipleChoice"]
elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
doc_kwargs["mask"] = "[MASK]" if mask is None else mask
code_sample = sample_docstrings["MaskedLM"]
elif "LMHead" in model_class or "CausalLM" in model_class:
code_sample = sample_docstrings["LMHead"]
elif "Model" in model_class or "Encoder" in model_class:
code_sample = sample_docstrings["BaseModel"]
else:
raise ValueError(f"Docstring can't be built for model {model_class}")
...@@ -1462,7 +1591,10 @@ def tf_required(func):
def is_tensor(x):
"""
Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor`, :obj:`jaxlib.xla_extension.DeviceArray` or
:obj:`np.ndarray`.
"""
if is_torch_available():
import torch
...@@ -1473,6 +1605,14 @@ def is_tensor(x):
if isinstance(x, tf.Tensor):
return True
if is_flax_available():
import jaxlib.xla_extension as jax_xla
from jax.interpreters.partial_eval import DynamicJaxprTracer
if isinstance(x, (jax_xla.DeviceArray, DynamicJaxprTracer)):
return True
return isinstance(x, np.ndarray)
...
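A quick illustration of the extended check (a sketch, assuming ``jax`` and ``numpy`` are installed; the JAX branch is the part this commit adds):

    import numpy as np
    import jax.numpy as jnp

    from transformers.file_utils import is_tensor

    # NumPy arrays (and torch/tf tensors, when available) were already recognized;
    # JAX device arrays are now recognized as well.
    print(is_tensor(np.zeros((2, 3))))   # True
    print(is_tensor(jnp.zeros((2, 3))))  # True: a jaxlib DeviceArray
    print(is_tensor([1, 2, 3]))          # False: a plain Python list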
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple
import jaxlib.xla_extension as jax_xla
from .file_utils import ModelOutput
@dataclass
class FlaxBaseModelOutput(ModelOutput):
"""
Base class for model's outputs, with potential hidden states and attentions.
Args:
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
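As a small usage sketch (dummy values, not part of the diff): like other ``ModelOutput`` subclasses, these Flax output classes support attribute access, integer indexing, and ``to_tuple()``, and fields left as ``None`` are simply omitted:

    import jax.numpy as jnp

    from transformers.modeling_flax_outputs import FlaxBaseModelOutput

    dummy = jnp.zeros((1, 5, 8))  # (batch_size, sequence_length, hidden_size)
    output = FlaxBaseModelOutput(last_hidden_state=dummy)

    print(output.last_hidden_state.shape)  # (1, 5, 8)
    print(output[0].shape)                 # same array, accessed by position
    print(len(output.to_tuple()))          # 1: hidden_states/attentions were None and are dropped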
@dataclass
class FlaxBaseModelOutputWithPooling(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
Args:
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token) further processed by a
Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
prediction (classification) objective during pretraining.
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: jax_xla.DeviceArray = None
pooler_output: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
class FlaxMaskedLMOutput(ModelOutput):
"""
Base class for masked language models outputs.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
class FlaxNextSentencePredictorOutput(ModelOutput):
"""
Base class for outputs of models predicting if two sentences are consecutive or not.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
class FlaxSequenceClassifierOutput(ModelOutput):
"""
Base class for outputs of sentence classification models.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
class FlaxMultipleChoiceModelOutput(ModelOutput):
"""
Base class for outputs of multiple choice models.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
class FlaxTokenClassifierOutput(ModelOutput):
"""
Base class for outputs of token classification models.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
class FlaxQuestionAnsweringModelOutput(ModelOutput):
"""
Base class for outputs of question answering models.
Args:
start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
Span-start scores (before SoftMax).
end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
start_logits: jax_xla.DeviceArray = None
end_logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
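To make the question-answering fields concrete, a short post-processing sketch with dummy logits (a real model and tokenizer would supply ``outputs.start_logits`` / ``outputs.end_logits`` and the input ids):

    import jax.numpy as jnp

    # Dummy scores standing in for outputs.start_logits / outputs.end_logits
    start_logits = jnp.array([[0.1, 2.5, 0.3, 0.2]])  # (batch_size, sequence_length)
    end_logits = jnp.array([[0.0, 0.4, 3.1, 0.2]])

    start_index = int(jnp.argmax(start_logits, axis=-1)[0])
    end_index = int(jnp.argmax(end_logits, axis=-1)[0])
    print(start_index, end_index)  # the predicted answer spans tokens start_index..end_index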
...@@ -32,12 +32,14 @@ from .file_utils import (
FLAX_WEIGHTS_NAME,
WEIGHTS_NAME,
PushToHubMixin,
add_code_sample_docstrings,
add_start_docstrings_to_model_forward,
cached_path,
copy_func,
hf_bucket_url,
is_offline_mode,
is_remote_url,
replace_return_docstrings,
)
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import logging
...@@ -432,3 +434,22 @@ def overwrite_call_docstring(model_class, docstring):
model_class.__call__.__doc__ = None
# set correct docstring
model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
def append_call_sample_docstring(model_class, tokenizer_class, checkpoint, output_type, config_class, mask=None):
model_class.__call__ = copy_func(model_class.__call__)
model_class.__call__ = add_code_sample_docstrings(
tokenizer_class=tokenizer_class,
checkpoint=checkpoint,
output_type=output_type,
config_class=config_class,
model_cls=model_class.__name__,
)(model_class.__call__)
def append_replace_return_docstrings(model_class, output_type, config_class):
model_class.__call__ = copy_func(model_class.__call__)
model_class.__call__ = replace_return_docstrings(
output_type=output_type,
config_class=config_class,
)(model_class.__call__)
...@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import flax.linen as nn
import jax
...@@ -23,13 +23,15 @@ from jax import lax
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
from ...utils import logging
from .configuration_roberta import RobertaConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "roberta-base"
_CONFIG_FOR_DOC = "RobertaConfig"
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
...@@ -181,7 +183,7 @@ class FlaxRobertaSelfAttention(nn.Module):
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
head_dim = self.config.hidden_size // self.config.num_attention_heads
query_states = self.query(hidden_states).reshape(
...@@ -223,7 +225,12 @@ class FlaxRobertaSelfAttention(nn.Module):
precision=None,
)
outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),)
# TODO: at the moment it's not possible to retrieve attn_weights from
# dot_product_attention, but should be in the future -> add functionality then
return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta
...@@ -256,13 +263,22 @@ class FlaxRobertaAttention(nn.Module):
self.self = FlaxRobertaSelfAttention(self.config, dtype=self.dtype)
self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
attn_outputs = self.self(
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
)
attn_output = attn_outputs[0]
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
outputs = (hidden_states,)
if output_attentions:
outputs += attn_outputs[1]
return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
...@@ -315,11 +331,20 @@ class FlaxRobertaLayer(nn.Module):
self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype)
self.output = FlaxRobertaOutput(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False):
attention_outputs = self.attention(
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
)
attention_output = attention_outputs[0]
hidden_states = self.intermediate(attention_output)
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
outputs = (hidden_states,)
if output_attentions:
outputs += (attention_outputs[1],)
return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
...@@ -332,10 +357,40 @@ class FlaxRobertaLayerCollection(nn.Module):
FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
]
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = layer(hidden_states, attention_mask, deterministic=deterministic)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions += (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = (hidden_states,)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
...@@ -346,8 +401,23 @@ class FlaxRobertaEncoder(nn.Module):
def setup(self):
self.layer = FlaxRobertaLayerCollection(self.config, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return self.layer(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
...@@ -412,7 +482,21 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
params: dict = None,
dropout_rng: PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
if output_attentions:
raise NotImplementedError(
"Currently attention scores cannot be returned." "Please set `output_attentions` to False for now."
)
# init input tensors if not passed
if token_type_ids is None:
token_type_ids = jnp.ones_like(input_ids)
...@@ -435,6 +519,9 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
jnp.array(token_type_ids, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
)
...@@ -450,17 +537,43 @@ class FlaxRobertaModule(nn.Module):
self.encoder = FlaxRobertaEncoder(self.config, dtype=self.dtype)
self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
hidden_states = self.embeddings(
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
)
outputs = self.encoder(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
if not return_dict:
# if pooled is None, don't return it
if pooled is None:
return (hidden_states,) + outputs[1:]
return (hidden_states, pooled) + outputs[1:]
return FlaxBaseModelOutputWithPooling(
last_hidden_state=hidden_states,
pooler_output=pooled,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
...@@ -469,3 +582,8 @@ class FlaxRobertaModule(nn.Module):
)
class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
module_class = FlaxRobertaModule
append_call_sample_docstring(
FlaxRobertaModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC
)
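A hedged end-to-end sketch of the new output handling for FlaxRobertaModel (it downloads the roberta-base checkpoint; the commented shapes are the expected ones, not verified here):

    from transformers import FlaxRobertaModel, RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    model = FlaxRobertaModel.from_pretrained("roberta-base")  # add from_pt=True if only PyTorch weights are hosted

    inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")

    # Default: a FlaxBaseModelOutputWithPooling with named fields.
    outputs = model(**inputs, output_hidden_states=True)
    print(outputs.last_hidden_state.shape)  # (1, sequence_length, hidden_size)
    print(outputs.pooler_output.shape)      # (1, hidden_size)
    print(len(outputs.hidden_states))       # num_hidden_layers + 1 (embeddings output included)

    # return_dict=False: the same values as a plain tuple.
    tuple_outputs = model(**inputs, return_dict=False)
    print(isinstance(tuple_outputs, tuple))  # True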
...@@ -998,7 +998,6 @@ class ModelTesterMixin:
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def set_nan_tensor_to_zero(t):
...
...@@ -13,8 +13,10 @@
# limitations under the License.
import copy
import inspect
import random
import tempfile
from typing import List, Tuple
import numpy as np
...@@ -28,6 +30,7 @@ if is_flax_available():
import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
...@@ -77,6 +80,7 @@ class FlaxModelTesterMixin:
inputs_dict = {
k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
for k, v in inputs_dict.items()
if isinstance(v, (jax_xla.DeviceArray, np.ndarray))
}
return inputs_dict
...@@ -85,6 +89,41 @@ class FlaxModelTesterMixin:
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assert_almost_equals(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), 1e-5
)
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
model = model_class(config)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...@@ -108,7 +147,7 @@ class FlaxModelTesterMixin:
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
...@@ -117,7 +156,7 @@ class FlaxModelTesterMixin:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
...@@ -149,7 +188,7 @@ class FlaxModelTesterMixin:
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
...@@ -171,17 +210,20 @@ class FlaxModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if model_class.__name__ != "FlaxBertModel":
continue
with self.subTest(model_class.__name__):
model = model_class(config)
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
outputs = model(**prepared_inputs_dict).to_tuple()
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_loaded = model_class.from_pretrained(tmpdirname)
outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 1e-3)
...@@ -195,19 +237,47 @@ class FlaxModelTesterMixin:
@jax.jit
def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
return model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
).to_tuple()
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict)
with self.subTest("JIT Disabled"): with self.subTest("JIT Disabled"):
with jax.disable_jit(): with jax.disable_jit():
outputs = model_jitted(**prepared_inputs_dict) outputs = model_jitted(**prepared_inputs_dict)
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict)
self.assertEqual(len(outputs), len(jitted_outputs)) self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs): for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape) self.assertEqual(jitted_output.shape, output.shape)
@jax.jit
def model_jitted_return_dict(input_ids, attention_mask=None, token_type_ids=None):
return model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
)
# jitted function cannot return OrderedDict
with self.assertRaises(TypeError):
model_jitted_return_dict(**prepared_inputs_dict)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.__call__)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["input_ids", "attention_mask"]
self.assertListEqual(arg_names[:2], expected_arg_names)
def test_naming_convention(self):
for model_class in self.all_model_classes:
model_class_name = model_class.__name__
...@@ -218,3 +288,30 @@ class FlaxModelTesterMixin:
module_cls = getattr(bert_modeling_flax_module, module_class_name)
self.assertIsNotNone(module_cls)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
seq_length = self.model_tester.seq_length
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)