Unverified commit faacd747 authored by Daniel Stancl, committed by GitHub

[Flax] Add FlaxBlenderbot (#13633)



* Init Flax implementation for Blenderbot

* Add a majority of stuff except for tests

* make style quality

* Add tests and fix some bugs

* Add tests

* Clean source code and fix some bugs

* Fix copies and docs

* Fix jax device condition for tests

* Fix layer norm in the encoder

* Fix a few typos in the test file

* make fix-copies

* make fix-copies

* fix layer norm

* Fix Flax params dtype (#13090)

* Fix PR reference (#13098)

* make fix-copies

* Update tests/test_modeling_flax_blenderbot.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent 254fef67
......@@ -385,7 +385,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BigBirdPegasus | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Blenderbot | ✅ | ✅ | ✅ | ✅ | ❌ |
| Blenderbot | ✅ | ✅ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BlenderbotSmall | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......
......@@ -125,3 +125,17 @@ TFBlenderbotForConditionalGeneration
.. autoclass:: transformers.TFBlenderbotForConditionalGeneration
:members: call
FlaxBlenderbotModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBlenderbotModel
:members: __call__, encode, decode
FlaxBlenderbotForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBlenderbotForConditionalGeneration
:members: __call__, encode, decode
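
A minimal usage sketch for the classes documented above, mirroring the slow generation test added later in this diff. The from_pt=True flag is an assumption for the case where no Flax weights are published for the checkpoint:

    import jax.numpy as jnp
    from transformers import BlenderbotTokenizer, FlaxBlenderbotForConditionalGeneration

    model = FlaxBlenderbotForConditionalGeneration.from_pretrained(
        "facebook/blenderbot-400M-distill", from_pt=True
    )
    tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-400M-distill")

    inputs = tokenizer(["Sam"], return_tensors="jax")
    # greedy decoding (num_beams=1); generation kwargs taken from the slow test below
    generated = model.generate(**inputs, num_beams=1, min_length=15, max_length=25)
    print(tokenizer.batch_decode(generated.sequences, skip_special_tokens=True))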
......@@ -1950,6 +1950,9 @@ if is_flax_available():
"FlaxBigBirdPreTrainedModel",
]
)
_import_structure["models.blenderbot"].extend(
["FlaxBlenderbotForConditionalGeneration", "FlaxBlenderbotModel", "FlaxBlenderbotPreTrainedModel"]
)
_import_structure["models.clip"].extend(
[
"FlaxCLIPModel",
......@@ -3647,6 +3650,11 @@ if TYPE_CHECKING:
FlaxBigBirdModel,
FlaxBigBirdPreTrainedModel,
)
from .models.blenderbot import (
FlaxBlenderbotForConditionalGeneration,
FlaxBlenderbotModel,
FlaxBlenderbotPreTrainedModel,
)
from .models.clip import (
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
......
......@@ -46,6 +46,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
("mt5", "FlaxMT5Model"),
("wav2vec2", "FlaxWav2Vec2Model"),
("marian", "FlaxMarianModel"),
("blenderbot", "FlaxBlenderbotModel"),
]
)
......@@ -89,6 +90,7 @@ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("mt5", "FlaxMT5ForConditionalGeneration"),
("marian", "FlaxMarianMTModel"),
("encoder-decoder", "FlaxEncoderDecoderModel"),
("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
]
)
......
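
With the two mappings above in place, the auto classes resolve a Blenderbot config to the new Flax classes. A hedged sketch (again, from_pt=True is assumed only for checkpoints without published Flax weights):

    from transformers import FlaxAutoModel, FlaxAutoModelForSeq2SeqLM

    # "blenderbot" -> FlaxBlenderbotModel via FLAX_MODEL_MAPPING_NAMES
    base_model = FlaxAutoModel.from_pretrained("facebook/blenderbot-400M-distill", from_pt=True)

    # "blenderbot" -> FlaxBlenderbotForConditionalGeneration via the seq2seq LM mapping
    lm_model = FlaxAutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill", from_pt=True)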
......@@ -418,7 +418,7 @@ class FlaxBartEncoderLayer(nn.Module):
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
......@@ -430,7 +430,7 @@ class FlaxBartEncoderLayer(nn.Module):
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
......@@ -533,7 +533,7 @@ class FlaxBartDecoderLayer(nn.Module):
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.encoder_attn = FlaxBartAttention(
config=self.config,
embed_dim=self.embed_dim,
......@@ -541,7 +541,7 @@ class FlaxBartDecoderLayer(nn.Module):
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
......@@ -550,7 +550,7 @@ class FlaxBartDecoderLayer(nn.Module):
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
......@@ -730,7 +730,7 @@ class FlaxBartEncoder(nn.Module):
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
......@@ -802,7 +802,7 @@ class FlaxBartDecoder(nn.Module):
)
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
......
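
The epsilon=1e-05 arguments added throughout these ports pin the layer-norm epsilon to the value the PyTorch checkpoints were trained with: Flax's nn.LayerNorm defaults to 1e-6, while torch.nn.LayerNorm, used by the original BART-family implementations, defaults to 1e-5. A quick check of the two defaults (assumes both flax and torch are installed):

    import flax.linen as flax_nn
    import torch.nn as torch_nn

    print(flax_nn.LayerNorm().epsilon)  # 1e-06, the Flax default being overridden in this diff
    print(torch_nn.LayerNorm(16).eps)   # 1e-05, the PyTorch default the pretrained weights expect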
......@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
_import_structure = {
......@@ -47,6 +47,14 @@ if is_tf_available():
]
if is_flax_available():
_import_structure["modeling_flax_blenderbot"] = [
"FlaxBlenderbotForConditionalGeneration",
"FlaxBlenderbotModel",
"FlaxBlenderbotPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig
from .tokenization_blenderbot import BlenderbotTokenizer
......@@ -70,6 +78,13 @@ if TYPE_CHECKING:
TFBlenderbotPreTrainedModel,
)
if is_flax_available():
from .modeling_flax_blenderbot import (
FlaxBlenderbotForConditionalGeneration,
FlaxBlenderbotModel,
FlaxBlenderbotPreTrainedModel,
)
else:
import sys
......
......@@ -423,7 +423,7 @@ class FlaxMarianEncoderLayer(nn.Module):
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
......@@ -435,7 +435,7 @@ class FlaxMarianEncoderLayer(nn.Module):
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
......@@ -541,7 +541,7 @@ class FlaxMarianDecoderLayer(nn.Module):
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.encoder_attn = FlaxMarianAttention(
config=self.config,
embed_dim=self.embed_dim,
......@@ -549,7 +549,7 @@ class FlaxMarianDecoderLayer(nn.Module):
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
......@@ -558,7 +558,7 @@ class FlaxMarianDecoderLayer(nn.Module):
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
......
......@@ -429,7 +429,7 @@ class FlaxMBartEncoderLayer(nn.Module):
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
......@@ -441,7 +441,7 @@ class FlaxMBartEncoderLayer(nn.Module):
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
......@@ -545,7 +545,7 @@ class FlaxMBartDecoderLayer(nn.Module):
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.encoder_attn = FlaxMBartAttention(
config=self.config,
embed_dim=self.embed_dim,
......@@ -553,7 +553,7 @@ class FlaxMBartDecoderLayer(nn.Module):
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
......@@ -562,7 +562,7 @@ class FlaxMBartDecoderLayer(nn.Module):
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
......@@ -745,8 +745,8 @@ class FlaxMBartEncoder(nn.Module):
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.layers = FlaxMBartEncoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
self.layer_norm = nn.LayerNorm(dtype=self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
......@@ -821,8 +821,8 @@ class FlaxMBartDecoder(nn.Module):
)
self.layers = FlaxMBartDecoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
self.layer_norm = nn.LayerNorm(dtype=self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
......
......@@ -423,7 +423,7 @@ class FlaxPegasusEncoderLayer(nn.Module):
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
......@@ -435,7 +435,7 @@ class FlaxPegasusEncoderLayer(nn.Module):
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
......@@ -540,7 +540,7 @@ class FlaxPegasusDecoderLayer(nn.Module):
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.encoder_attn = FlaxPegasusAttention(
config=self.config,
embed_dim=self.embed_dim,
......@@ -548,7 +548,7 @@ class FlaxPegasusDecoderLayer(nn.Module):
dropout=self.config.attention_dropout,
dtype=self.dtype,
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
......@@ -557,7 +557,7 @@ class FlaxPegasusDecoderLayer(nn.Module):
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
......@@ -705,7 +705,7 @@ class FlaxPegasusEncoder(nn.Module):
self.config.max_position_embeddings, embed_dim, dtype=self.dtype
)
self.layers = FlaxPegasusEncoderLayerCollection(self.config, self.dtype)
self.layer_norm = nn.LayerNorm(dtype=self.dtype)
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
......@@ -775,7 +775,7 @@ class FlaxPegasusDecoder(nn.Module):
)
self.layers = FlaxPegasusDecoderLayerCollection(self.config, self.dtype)
self.layer_norm = nn.LayerNorm(dtype=self.dtype)
self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
def __call__(
self,
......
......@@ -477,6 +477,13 @@ else:
if is_tf_available():
import tensorflow as tf
if is_flax_available():
import jax
jax_device = jax.default_backend()
else:
jax_device = None
def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch."""
......
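
The module-level jax_device exposed here lets tests skip based on the active JAX backend. A minimal sketch of the intended use, mirroring the skipUnless guard in the Blenderbot test below (ExampleJaxTest is a hypothetical test case for illustration):

    import unittest
    from transformers.testing_utils import jax_device, require_flax

    @require_flax
    class ExampleJaxTest(unittest.TestCase):
        @unittest.skipUnless(jax_device != "cpu", "too slow on CPU")
        def test_needs_accelerator(self):
            self.assertTrue(True)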
......@@ -639,6 +639,42 @@ class FlaxBigBirdPreTrainedModel:
requires_backends(self, ["flax"])
class FlaxBlenderbotForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
def __call__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxBlenderbotModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
def __call__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxBlenderbotPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
def __call__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxCLIPModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
......
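
These dummy objects keep `from transformers import FlaxBlenderbotModel` importable in environments without JAX/Flax while failing fast on use. A hedged sketch of the resulting behaviour (only meaningful when Flax is not installed):

    from transformers import FlaxBlenderbotModel  # resolves to the dummy class above when Flax is missing

    try:
        FlaxBlenderbotModel()
    except ImportError as err:  # requires_backends raises ImportError naming the missing backend
        print(err)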
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import timeout_decorator # noqa
from transformers import BlenderbotConfig, is_flax_available
from transformers.testing_utils import jax_device, require_flax, slow
from .test_generation_flax_utils import FlaxGenerationTesterMixin
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
import os
# The slow tests are often failing with OOM error on GPU
# This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed
# but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
import jax
import jax.numpy as jnp
from transformers import BlenderbotTokenizer
from transformers.models.blenderbot.modeling_flax_blenderbot import (
FlaxBlenderbotForConditionalGeneration,
FlaxBlenderbotModel,
shift_tokens_right,
)
def prepare_blenderbot_inputs_dict(
config,
input_ids,
decoder_input_ids=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = np.where(input_ids != config.pad_token_id, 1, 0)
if decoder_attention_mask is None:
decoder_attention_mask = np.where(decoder_input_ids != config.pad_token_id, 1, 0)
if head_mask is None:
head_mask = np.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = np.ones((config.decoder_layers, config.decoder_attention_heads))
if cross_attn_head_mask is None:
cross_attn_head_mask = np.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": attention_mask,
}
class FlaxBlenderbotModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_labels=False,
vocab_size=99,
hidden_size=16,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=4,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=32,
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,
initializer_range=0.02,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.initializer_range = initializer_range
def prepare_config_and_inputs(self):
input_ids = np.clip(ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size), 3, self.vocab_size)
input_ids = np.concatenate((input_ids, 2 * np.ones((self.batch_size, 1), dtype=np.int64)), -1)
decoder_input_ids = shift_tokens_right(input_ids, 1, 2)
config = BlenderbotConfig(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
encoder_layers=self.num_hidden_layers,
decoder_layers=self.num_hidden_layers,
encoder_attention_heads=self.num_attention_heads,
decoder_attention_heads=self.num_attention_heads,
encoder_ffn_dim=self.intermediate_size,
decoder_ffn_dim=self.intermediate_size,
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
initializer_range=self.initializer_range,
use_cache=False,
)
inputs_dict = prepare_blenderbot_inputs_dict(config, input_ids, decoder_input_ids)
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict
def check_use_cache_forward(self, model_class_name, config, inputs_dict):
max_decoder_length = 20
model = model_class_name(config)
encoder_outputs = model.encode(inputs_dict["input_ids"])
decoder_input_ids, decoder_attention_mask = (
inputs_dict["decoder_input_ids"],
inputs_dict["decoder_attention_mask"],
)
past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs)
decoder_attention_mask = jnp.ones((decoder_input_ids.shape[0], max_decoder_length), dtype="i4")
decoder_position_ids = jnp.broadcast_to(
jnp.arange(decoder_input_ids.shape[-1] - 1)[None, :],
(decoder_input_ids.shape[0], decoder_input_ids.shape[-1] - 1),
)
outputs_cache = model.decode(
decoder_input_ids[:, :-1],
encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
decoder_position_ids=decoder_position_ids,
)
decoder_position_ids = jnp.array(decoder_input_ids.shape[0] * [[decoder_input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model.decode(
decoder_input_ids[:, -1:],
encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
past_key_values=outputs_cache.past_key_values,
decoder_position_ids=decoder_position_ids,
)
outputs = model.decode(decoder_input_ids, encoder_outputs)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict):
max_decoder_length = 20
model = model_class_name(config)
encoder_outputs = model.encode(inputs_dict["input_ids"])
decoder_input_ids, decoder_attention_mask = (
inputs_dict["decoder_input_ids"],
inputs_dict["decoder_attention_mask"],
)
decoder_attention_mask_cache = jnp.concatenate(
[
decoder_attention_mask,
jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])),
],
axis=-1,
)
past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs)
decoder_position_ids = jnp.broadcast_to(
jnp.arange(decoder_input_ids.shape[-1] - 1)[None, :],
(decoder_input_ids.shape[0], decoder_input_ids.shape[-1] - 1),
)
outputs_cache = model.decode(
decoder_input_ids[:, :-1],
encoder_outputs,
decoder_attention_mask=decoder_attention_mask_cache,
past_key_values=past_key_values,
decoder_position_ids=decoder_position_ids,
)
decoder_position_ids = jnp.array(decoder_input_ids.shape[0] * [[decoder_input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model.decode(
decoder_input_ids[:, -1:],
encoder_outputs,
past_key_values=outputs_cache.past_key_values,
decoder_attention_mask=decoder_attention_mask_cache,
decoder_position_ids=decoder_position_ids,
)
outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
@require_flax
class BlenderbotHeadTests(unittest.TestCase):
vocab_size = 99
def _get_config_and_data(self):
input_ids = np.array(
[
[71, 82, 18, 33, 46, 91, 2],
[68, 34, 26, 58, 30, 82, 2],
[5, 97, 17, 39, 94, 40, 2],
[76, 83, 94, 25, 70, 78, 2],
[87, 59, 41, 35, 48, 66, 2],
[55, 13, 16, 58, 5, 2, 1], # note padding
[64, 27, 31, 51, 12, 75, 2],
[52, 64, 86, 17, 83, 39, 2],
[48, 61, 9, 24, 71, 82, 2],
[26, 1, 60, 48, 22, 13, 2],
[21, 5, 62, 28, 14, 76, 2],
[45, 98, 37, 86, 59, 48, 2],
[70, 70, 50, 9, 28, 0, 2],
],
dtype=np.int64,
)
batch_size = input_ids.shape[0]
config = BlenderbotConfig(
vocab_size=self.vocab_size,
d_model=24,
encoder_layers=2,
decoder_layers=2,
encoder_attention_heads=2,
decoder_attention_heads=2,
encoder_ffn_dim=32,
decoder_ffn_dim=32,
max_position_embeddings=48,
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,
)
return config, input_ids, batch_size
# @timeout_decorator.timeout(1) # not working with the decorator so far
def test_lm_forward(self):
config, input_ids, batch_size = self._get_config_and_data()
lm_model = FlaxBlenderbotForConditionalGeneration(config)
outputs = lm_model(input_ids=input_ids)
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
self.assertEqual(outputs["logits"].shape, expected_shape)
def test_lm_uneven_forward(self):
config = BlenderbotConfig(
vocab_size=self.vocab_size,
d_model=14,
encoder_layers=2,
decoder_layers=2,
encoder_attention_heads=2,
decoder_attention_heads=2,
encoder_ffn_dim=8,
decoder_ffn_dim=8,
max_position_embeddings=48,
)
lm_model = FlaxBlenderbotForConditionalGeneration(config)
context = np.array([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]], dtype=np.int64)
summary = np.array([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]], dtype=np.int64)
outputs = lm_model(input_ids=context, decoder_input_ids=summary)
expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(outputs["logits"].shape, expected_shape)
def test_shift_tokens_right(self):
input_ids = np.array([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]], dtype=np.int64)
shifted = shift_tokens_right(input_ids, 1, 2)
n_pad_before = np.equal(input_ids, 1).astype(np.float32).sum()
n_pad_after = np.equal(shifted, 1).astype(np.float32).sum()
self.assertEqual(shifted.shape, input_ids.shape)
self.assertEqual(n_pad_after, n_pad_before - 1)
self.assertTrue(np.equal(shifted[:, 0], 2).all())
@require_flax
class FlaxBlenderbotModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
is_encoder_decoder = True
all_model_classes = (
(
FlaxBlenderbotModel,
FlaxBlenderbotForConditionalGeneration,
)
if is_flax_available()
else ()
)
all_generative_model_classes = (FlaxBlenderbotForConditionalGeneration,) if is_flax_available() else ()
def setUp(self):
self.model_tester = FlaxBlenderbotModelTester(self)
def test_use_cache_forward(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
self.model_tester.check_use_cache_forward(model_class, config, inputs_dict)
def test_use_cache_forward_with_attn_mask(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
self.model_tester.check_use_cache_forward_with_attn_mask(model_class, config, inputs_dict)
def test_encode(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@jax.jit
def encode_jitted(input_ids, attention_mask=None, **kwargs):
return model.encode(input_ids=input_ids, attention_mask=attention_mask)
with self.subTest("JIT Enabled"):
jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
def test_decode(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
model = model_class(config)
encoder_outputs = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"])
prepared_inputs_dict = {
"decoder_input_ids": inputs_dict["decoder_input_ids"],
"decoder_attention_mask": inputs_dict["decoder_attention_mask"],
"encoder_outputs": encoder_outputs,
}
@jax.jit
def decode_jitted(decoder_input_ids, decoder_attention_mask, encoder_outputs):
return model.decode(
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
)
with self.subTest("JIT Enabled"):
jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple()
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = decode_jitted(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("facebook/blenderbot-400M-distill")
# the model expects an eos token to be present in input_ids
input_ids = np.ones((1, 1)) * model.config.eos_token_id
outputs = model(input_ids)
self.assertIsNotNone(outputs)
@unittest.skipUnless(jax_device != "cpu", "3B test too slow on CPU.")
@slow
def test_generation_from_short_input_same_as_parlai_3B(self):
FASTER_GEN_KWARGS = dict(num_beams=1, early_stopping=True, min_length=15, max_length=25)
TOK_DECODE_KW = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
model = FlaxBlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-3B", from_pt=True)
tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-3B")
src_text = ["Sam"]
model_inputs = tokenizer(src_text, return_tensors="jax")
generated_utterances = model.generate(**model_inputs, **FASTER_GEN_KWARGS)
tgt_text = 'Sam is a great name. It means "sun" in Gaelic.'
generated_txt = tokenizer.batch_decode(generated_utterances, **TOK_DECODE_KW)
assert generated_txt[0].strip() == tgt_text
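
To exercise the new test module locally, a typical invocation for this repository (RUN_SLOW=1 is only needed to enable the @slow cases):

    RUN_SLOW=1 python -m pytest tests/test_modeling_flax_blenderbot.py -v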