"vscode:/vscode.git/clone" did not exist on "72983303c59bafecd4a7204850f275ca25170df3"
Unverified Commit ca33278f authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

FlaxGPT2 (#11556)



* flax gpt2

* combine masks

* handle shared embeds

* add causal LM sample

* style

* add tests

* style

* fix imports, docs, quality

* don't use cache

* add cache

* add cache 1st version

* make use cache work

* start adding test for generation

* finish generation loop compilation

* rewrite test

* finish

* update

* update

* apply sylvains suggestions

* update

* refactor

* fix typo
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent eb3e072a
...@@ -355,7 +355,7 @@ Flax), PyTorch, and/or TensorFlow. ...@@ -355,7 +355,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| OpenAI GPT | ✅ | ✅ | ✅ | ✅ | ❌ | | OpenAI GPT | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | | | OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Pegasus | ✅ | ✅ | ✅ | ✅ | ❌ | | Pegasus | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......
...@@ -205,6 +205,13 @@ FlaxAutoModel ...@@ -205,6 +205,13 @@ FlaxAutoModel
:members: :members:
FlaxAutoModelForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAutoModelForCausalLM
:members:
FlaxAutoModelForPreTraining FlaxAutoModelForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
...@@ -139,3 +139,17 @@ TFSequenceClassifierOutputWithPast ...@@ -139,3 +139,17 @@ TFSequenceClassifierOutputWithPast
.. autoclass:: transformers.modeling_tf_outputs.TFSequenceClassifierOutputWithPast .. autoclass:: transformers.modeling_tf_outputs.TFSequenceClassifierOutputWithPast
:members: :members:
FlaxGPT2Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxGPT2Model
:members: __call__
FlaxGPT2LMHeadModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxGPT2LMHeadModel
:members: __call__
...@@ -1409,6 +1409,7 @@ if is_flax_available(): ...@@ -1409,6 +1409,7 @@ if is_flax_available():
_import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"] _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
_import_structure["models.auto"].extend( _import_structure["models.auto"].extend(
[ [
"FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_MASKED_LM_MAPPING", "FLAX_MODEL_FOR_MASKED_LM_MAPPING",
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
...@@ -1418,6 +1419,7 @@ if is_flax_available(): ...@@ -1418,6 +1419,7 @@ if is_flax_available():
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_MAPPING", "FLAX_MODEL_MAPPING",
"FlaxAutoModel", "FlaxAutoModel",
"FlaxAutoModelForCausalLM",
"FlaxAutoModelForMaskedLM", "FlaxAutoModelForMaskedLM",
"FlaxAutoModelForMultipleChoice", "FlaxAutoModelForMultipleChoice",
"FlaxAutoModelForNextSentencePrediction", "FlaxAutoModelForNextSentencePrediction",
...@@ -1452,6 +1454,7 @@ if is_flax_available(): ...@@ -1452,6 +1454,7 @@ if is_flax_available():
"FlaxElectraPreTrainedModel", "FlaxElectraPreTrainedModel",
] ]
) )
_import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model"])
_import_structure["models.roberta"].extend( _import_structure["models.roberta"].extend(
[ [
"FlaxRobertaForMaskedLM", "FlaxRobertaForMaskedLM",
...@@ -2634,6 +2637,7 @@ if TYPE_CHECKING: ...@@ -2634,6 +2637,7 @@ if TYPE_CHECKING:
if is_flax_available(): if is_flax_available():
from .modeling_flax_utils import FlaxPreTrainedModel from .modeling_flax_utils import FlaxPreTrainedModel
from .models.auto import ( from .models.auto import (
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING, FLAX_MODEL_FOR_MASKED_LM_MAPPING,
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
...@@ -2643,6 +2647,7 @@ if TYPE_CHECKING: ...@@ -2643,6 +2647,7 @@ if TYPE_CHECKING:
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_MAPPING, FLAX_MODEL_MAPPING,
FlaxAutoModel, FlaxAutoModel,
FlaxAutoModelForCausalLM,
FlaxAutoModelForMaskedLM, FlaxAutoModelForMaskedLM,
FlaxAutoModelForMultipleChoice, FlaxAutoModelForMultipleChoice,
FlaxAutoModelForNextSentencePrediction, FlaxAutoModelForNextSentencePrediction,
...@@ -2672,6 +2677,7 @@ if TYPE_CHECKING: ...@@ -2672,6 +2677,7 @@ if TYPE_CHECKING:
FlaxElectraModel, FlaxElectraModel,
FlaxElectraPreTrainedModel, FlaxElectraPreTrainedModel,
) )
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
from .models.roberta import ( from .models.roberta import (
FlaxRobertaForMaskedLM, FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice, FlaxRobertaForMultipleChoice,
......
...@@ -1038,6 +1038,20 @@ FLAX_MULTIPLE_CHOICE_SAMPLE = r""" ...@@ -1038,6 +1038,20 @@ FLAX_MULTIPLE_CHOICE_SAMPLE = r"""
>>> logits = outputs.logits >>> logits = outputs.logits
""" """
FLAX_CAUSAL_LM_SAMPLE = r"""
Example::
>>> from transformers import {tokenizer_class}, {model_class}
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
>>> model = {model_class}.from_pretrained('{checkpoint}')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
>>> outputs = model(**inputs, labels=inputs["input_ids"])
>>> logits = outputs.logits
"""
FLAX_SAMPLE_DOCSTRINGS = { FLAX_SAMPLE_DOCSTRINGS = {
"SequenceClassification": FLAX_SEQUENCE_CLASSIFICATION_SAMPLE, "SequenceClassification": FLAX_SEQUENCE_CLASSIFICATION_SAMPLE,
"QuestionAnswering": FLAX_QUESTION_ANSWERING_SAMPLE, "QuestionAnswering": FLAX_QUESTION_ANSWERING_SAMPLE,
...@@ -1045,6 +1059,7 @@ FLAX_SAMPLE_DOCSTRINGS = { ...@@ -1045,6 +1059,7 @@ FLAX_SAMPLE_DOCSTRINGS = {
"MultipleChoice": FLAX_MULTIPLE_CHOICE_SAMPLE, "MultipleChoice": FLAX_MULTIPLE_CHOICE_SAMPLE,
"MaskedLM": FLAX_MASKED_LM_SAMPLE, "MaskedLM": FLAX_MASKED_LM_SAMPLE,
"BaseModel": FLAX_BASE_MODEL_SAMPLE, "BaseModel": FLAX_BASE_MODEL_SAMPLE,
"LMHead": FLAX_CAUSAL_LM_SAMPLE,
} }
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Dict, Optional, Tuple
import jaxlib.xla_extension as jax_xla import jaxlib.xla_extension as jax_xla
...@@ -46,6 +46,36 @@ class FlaxBaseModelOutput(ModelOutput): ...@@ -46,6 +46,36 @@ class FlaxBaseModelOutput(ModelOutput):
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass
class FlaxBaseModelOutputWithPast(ModelOutput):
"""
Base class for model's outputs, with potential hidden states and attentions.
Args:
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (:obj:`Dict[str, jax_xla.DeviceArray]`):
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`.
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: jax_xla.DeviceArray = None
past_key_values: Optional[Dict[str, jax_xla.DeviceArray]] = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@dataclass @dataclass
class FlaxBaseModelOutputWithPooling(ModelOutput): class FlaxBaseModelOutputWithPooling(ModelOutput):
""" """
...@@ -103,6 +133,9 @@ class FlaxMaskedLMOutput(ModelOutput): ...@@ -103,6 +133,9 @@ class FlaxMaskedLMOutput(ModelOutput):
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
FlaxCausalLMOutput = FlaxMaskedLMOutput
@dataclass @dataclass
class FlaxNextSentencePredictorOutput(ModelOutput): class FlaxNextSentencePredictorOutput(ModelOutput):
""" """
......
...@@ -85,6 +85,7 @@ if is_tf_available(): ...@@ -85,6 +85,7 @@ if is_tf_available():
if is_flax_available(): if is_flax_available():
_import_structure["modeling_flax_auto"] = [ _import_structure["modeling_flax_auto"] = [
"FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_MASKED_LM_MAPPING", "FLAX_MODEL_FOR_MASKED_LM_MAPPING",
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
...@@ -94,6 +95,7 @@ if is_flax_available(): ...@@ -94,6 +95,7 @@ if is_flax_available():
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_MAPPING", "FLAX_MODEL_MAPPING",
"FlaxAutoModel", "FlaxAutoModel",
"FlaxAutoModelForCausalLM",
"FlaxAutoModelForMaskedLM", "FlaxAutoModelForMaskedLM",
"FlaxAutoModelForMultipleChoice", "FlaxAutoModelForMultipleChoice",
"FlaxAutoModelForNextSentencePrediction", "FlaxAutoModelForNextSentencePrediction",
...@@ -167,6 +169,7 @@ if TYPE_CHECKING: ...@@ -167,6 +169,7 @@ if TYPE_CHECKING:
if is_flax_available(): if is_flax_available():
from .modeling_flax_auto import ( from .modeling_flax_auto import (
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING, FLAX_MODEL_FOR_MASKED_LM_MAPPING,
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
...@@ -176,6 +179,7 @@ if TYPE_CHECKING: ...@@ -176,6 +179,7 @@ if TYPE_CHECKING:
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_MAPPING, FLAX_MODEL_MAPPING,
FlaxAutoModel, FlaxAutoModel,
FlaxAutoModelForCausalLM,
FlaxAutoModelForMaskedLM, FlaxAutoModelForMaskedLM,
FlaxAutoModelForMultipleChoice, FlaxAutoModelForMultipleChoice,
FlaxAutoModelForNextSentencePrediction, FlaxAutoModelForNextSentencePrediction,
......
...@@ -37,6 +37,7 @@ from ..electra.modeling_flax_electra import ( ...@@ -37,6 +37,7 @@ from ..electra.modeling_flax_electra import (
FlaxElectraForTokenClassification, FlaxElectraForTokenClassification,
FlaxElectraModel, FlaxElectraModel,
) )
from ..gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
from ..roberta.modeling_flax_roberta import ( from ..roberta.modeling_flax_roberta import (
FlaxRobertaForMaskedLM, FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice, FlaxRobertaForMultipleChoice,
...@@ -46,7 +47,7 @@ from ..roberta.modeling_flax_roberta import ( ...@@ -46,7 +47,7 @@ from ..roberta.modeling_flax_roberta import (
FlaxRobertaModel, FlaxRobertaModel,
) )
from .auto_factory import auto_class_factory from .auto_factory import auto_class_factory
from .configuration_auto import BertConfig, ElectraConfig, RobertaConfig from .configuration_auto import BertConfig, ElectraConfig, GPT2Config, RobertaConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -57,6 +58,7 @@ FLAX_MODEL_MAPPING = OrderedDict( ...@@ -57,6 +58,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
# Base model mapping # Base model mapping
(RobertaConfig, FlaxRobertaModel), (RobertaConfig, FlaxRobertaModel),
(BertConfig, FlaxBertModel), (BertConfig, FlaxBertModel),
(GPT2Config, FlaxGPT2Model),
(ElectraConfig, FlaxElectraModel), (ElectraConfig, FlaxElectraModel),
] ]
) )
...@@ -79,6 +81,13 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict( ...@@ -79,6 +81,13 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
] ]
) )
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
[
# Model for Causal LM mapping
(GPT2Config, FlaxGPT2LMHeadModel)
]
)
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
[ [
# Model for Sequence Classification mapping # Model for Sequence Classification mapping
...@@ -123,6 +132,10 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict( ...@@ -123,6 +132,10 @@ FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
FlaxAutoModel = auto_class_factory("FlaxAutoModel", FLAX_MODEL_MAPPING) FlaxAutoModel = auto_class_factory("FlaxAutoModel", FLAX_MODEL_MAPPING)
FlaxAutoModelForCausalLM = auto_class_factory(
"FlaxAutoModelForCausalLM", FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, head_doc="causal language modeling"
)
FlaxAutoModelForPreTraining = auto_class_factory( FlaxAutoModelForPreTraining = auto_class_factory(
"FlaxAutoModelForPreTraining", FLAX_MODEL_FOR_PRETRAINING_MAPPING, head_doc="pretraining" "FlaxAutoModelForPreTraining", FLAX_MODEL_FOR_PRETRAINING_MAPPING, head_doc="pretraining"
) )
......
...@@ -18,7 +18,13 @@ ...@@ -18,7 +18,13 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_tf_available, is_tokenizers_available, is_torch_available from ...file_utils import (
_BaseLazyModule,
is_flax_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)
_import_structure = { _import_structure = {
...@@ -51,6 +57,8 @@ if is_tf_available(): ...@@ -51,6 +57,8 @@ if is_tf_available():
"TFGPT2PreTrainedModel", "TFGPT2PreTrainedModel",
] ]
if is_flax_available():
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model"]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
...@@ -81,6 +89,9 @@ if TYPE_CHECKING: ...@@ -81,6 +89,9 @@ if TYPE_CHECKING:
TFGPT2PreTrainedModel, TFGPT2PreTrainedModel,
) )
if is_flax_available():
from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
else: else:
import importlib import importlib
import os import os
......
This diff is collapsed.
...@@ -11,6 +11,9 @@ class FlaxPreTrainedModel: ...@@ -11,6 +11,9 @@ class FlaxPreTrainedModel:
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = None
FLAX_MODEL_FOR_MASKED_LM_MAPPING = None FLAX_MODEL_FOR_MASKED_LM_MAPPING = None
...@@ -44,6 +47,15 @@ class FlaxAutoModel: ...@@ -44,6 +47,15 @@ class FlaxAutoModel:
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
class FlaxAutoModelForCausalLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxAutoModelForMaskedLM: class FlaxAutoModelForMaskedLM:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
...@@ -248,6 +260,24 @@ class FlaxElectraPreTrainedModel: ...@@ -248,6 +260,24 @@ class FlaxElectraPreTrainedModel:
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
class FlaxGPT2LMHeadModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxGPT2Model:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxRobertaForMaskedLM: class FlaxRobertaForMaskedLM:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
......
...@@ -247,12 +247,8 @@ class FlaxModelTesterMixin: ...@@ -247,12 +247,8 @@ class FlaxModelTesterMixin:
model = model_class(config) model = model_class(config)
@jax.jit @jax.jit
def model_jitted(input_ids, attention_mask=None, token_type_ids=None): def model_jitted(input_ids, attention_mask=None, **kwargs):
return model( return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs).to_tuple()
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
).to_tuple()
with self.subTest("JIT Enabled"): with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict) jitted_outputs = model_jitted(**prepared_inputs_dict)
...@@ -266,11 +262,11 @@ class FlaxModelTesterMixin: ...@@ -266,11 +262,11 @@ class FlaxModelTesterMixin:
self.assertEqual(jitted_output.shape, output.shape) self.assertEqual(jitted_output.shape, output.shape)
@jax.jit @jax.jit
def model_jitted_return_dict(input_ids, attention_mask=None, token_type_ids=None): def model_jitted_return_dict(input_ids, attention_mask=None, **kwargs):
return model( return model(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, **kwargs,
) )
# jitted function cannot return OrderedDict # jitted function cannot return OrderedDict
......
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import numpy as np
import transformers
from transformers import GPT2Config, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
if is_flax_available():
import jax
import jax.numpy as jnp
from jax import lax
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
if is_torch_available():
import torch
class FlaxGPT2ModelTester:
def __init__(
self,
parent,
batch_size=14,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
initializer_range=0.02,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.scope = None
self.bos_token_id = vocab_size - 1
self.eos_token_id = vocab_size - 1
self.pad_token_id = vocab_size - 1
def prepare_config_and_inputs(self, gradient_checkpointing=False):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
config = GPT2Config(
vocab_size=self.vocab_size,
n_embd=self.hidden_size,
n_layer=self.num_hidden_layers,
n_head=self.num_attention_heads,
n_positions=self.max_position_embeddings,
n_ctx=self.max_position_embeddings,
use_cache=False,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
gradient_checkpointing=gradient_checkpointing,
)
return (config, input_ids, input_mask)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, attention_mask = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
return config, inputs_dict
def check_use_cache_forward(self, model_class_name, config, input_ids, attention_mask):
max_decoder_length = 20
model = model_class_name(config)
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
outputs_cache = model(input_ids[:, :-1], past_key_values=past_key_values)
outputs_cache_next = model(input_ids[:, -1:], past_key_values=outputs_cache.past_key_values)
outputs = model(input_ids)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input_ids, attention_mask):
max_decoder_length = 20
model = model_class_name(config)
attention_mask_cache = jnp.concatenate(
[attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))],
axis=-1,
)
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
outputs_cache = model(input_ids[:, :-1], attention_mask=attention_mask_cache, past_key_values=past_key_values)
outputs_cache_next = model(
input_ids[:, -1:], past_key_values=outputs_cache.past_key_values, attention_mask=attention_mask_cache
)
outputs = model(input_ids, attention_mask=attention_mask)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
def check_use_cache_generation(self, config, input_ids):
prompt_length = 3
model = FlaxGPT2LMHeadModel(config)
max_length = 10
batch_size = 1
prompt_ids = input_ids[:1, :prompt_length]
# put all generation logic into one function
def generate(prompt_ids):
def first_pass(prompt_ids):
logits, cache = model(prompt_ids, past_key_values=past_key_values)[:2]
next_token = jnp.argmax(logits[:, -1:], axis=-1)
return next_token, cache
def greedy_search_cond_fn(state):
cur_len, _, _, _ = state
return ~(cur_len == max_length - 1)
def greedy_search_body_fn(state):
cur_len, sequences, current_token, cache = state
next_sequences = lax.dynamic_update_slice(sequences, current_token, (0, cur_len))
next_logits, next_cache = model(current_token, past_key_values=cache)[:2]
next_token = jnp.argmax(next_logits, axis=-1)
return cur_len + 1, next_sequences, next_token, next_cache
# init tensor to be filled with generation result
init_sequences = jnp.zeros((batch_size, max_length), dtype="i4")
init_sequences = lax.dynamic_update_slice(init_sequences, prompt_ids, (0, 0))
# init past key values for cache
past_key_values = model.init_cache(batch_size, max_length)
# first pass with long prompt
next_token, cache = first_pass(prompt_ids)
# prepare state for generation loop
init_state = (jnp.array(prompt_length), init_sequences, next_token, cache)
# fast generation
_, output_sequences, final_token, _ = lax.while_loop(
greedy_search_cond_fn, greedy_search_body_fn, init_state
)
# append last token
output_sequences = lax.dynamic_update_slice(output_sequences, final_token, (0, max_length - 1))
return output_sequences
jit_generate = jax.jit(generate)
output_sequences = jit_generate(prompt_ids)
self.parent.assertEqual(output_sequences.shape, (1, max_length))
@require_flax
class FlaxGPT2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxGPT2Model, FlaxGPT2LMHeadModel) if is_flax_available() else ()
def setUp(self):
self.model_tester = FlaxGPT2ModelTester(self)
def test_use_cache_forward(self):
for model_class_name in self.all_model_classes:
config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_use_cache_forward(model_class_name, config, input_ids, attention_mask)
def test_use_cache_forward_with_attn_mask(self):
for model_class_name in self.all_model_classes:
config, input_ids, attention_mask = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_use_cache_forward_with_attn_mask(
model_class_name, config, input_ids, attention_mask
)
def test_use_cache_generation(self):
config, input_ids, _ = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_use_cache_generation(config, input_ids)
# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slighly differently
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
batch_size, seq_length = pt_inputs["input_ids"].shape
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded[:, -1], pt_output[:, -1].numpy(), 4e-2)
# overwrite from common since `attention_mask` in combination
# with `causal_mask` behaves slighly differently
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
batch_size, seq_length = pt_inputs["input_ids"].shape
rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
pt_inputs["attention_mask"][batch_idx, :start_index] = 0
pt_inputs["attention_mask"][batch_idx, start_index:] = 1
prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0
prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1
# make sure weights are tied in PyTorch
pt_model.tie_weights()
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output[:, -1], pt_output[:, -1].numpy(), 4e-2)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("gpt2", from_pt=True)
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment