Unverified commit 566b083e, authored by Sam Shleifer and committed by GitHub

TFMarian, TFMbart, TFPegasus, TFBlenderbot (#7987)



* Start plumbing

* Marian close

* Small stubs for all children

* Fixed bart

* marian working

* pegasus test is good, but failing

* Checkin tests

* More model files

* Subtle marian, pegasus integration test failures

* Works well

* rm print

* boom boom

* Still failing model2doc

* merge master

* Equivalence test failing, all others fixed

* cleanup

* Fix embed_scale

* Cleanup marian pipeline test

* Undo extra changes

* Smaller delta

* Cleanup model testers

* undo delta

* fix tests import structure

* cross test decorator

* Cleaner set_weights

* Respect authorized_unexpected_keys

* No warnings

* No warnings

* style

* Nest tf import

* black

* Apply suggestions from code review
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* functional dropout

* fixup

* Fixup

* style_doc

* embs

* shape list

* delete slow force_token_id_to_be_generated func

* fixup
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 6279072f
@@ -95,3 +95,12 @@ See :obj:`transformers.BartForConditionalGeneration` for arguments to `forward`
.. autoclass:: transformers.BlenderbotForConditionalGeneration
    :members:

TFBlenderbotForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
See :obj:`transformers.TFBartForConditionalGeneration` for arguments to `forward` and `generate`
.. autoclass:: transformers.TFBlenderbotForConditionalGeneration
:members:
@@ -129,3 +129,9 @@ MarianMTModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MarianMTModel

TFMarianMTModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFMarianMTModel
@@ -79,4 +79,11 @@ MBartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MBartForConditionalGeneration
    :members:

TFMBartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFMBartForConditionalGeneration
:members:
@@ -95,3 +95,9 @@ PegasusForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.PegasusForConditionalGeneration

TFPegasusForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFPegasusForConditionalGeneration
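
For context on the four newly documented heads: they load and generate like their PyTorch counterparts. A minimal, hedged sketch (the checkpoint is a hub checkpoint exercised by the tests in this PR; from_pt=True assumes only PyTorch weights are published):

from transformers import AutoTokenizer, TFMarianMTModel

model_name = "Helsinki-NLP/opus-mt-en-ROMANCE"  # checkpoint also used in the tests below
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFMarianMTModel.from_pretrained(model_name, from_pt=True)  # convert the PyTorch weights

batch = tokenizer.prepare_seq2seq_batch(src_texts=[">>fr<< Hello, world!"], return_tensors="tf")
generated = model.generate(batch.input_ids, attention_mask=batch.attention_mask, num_beams=2)
print(tokenizer.batch_decode(generated.numpy(), skip_special_tokens=True))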
@@ -670,6 +670,7 @@ if is_tf_available():
        TFBertModel,
        TFBertPreTrainedModel,
    )
    from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration
    from .modeling_tf_camembert import (
        TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFCamembertForMaskedLM,
@@ -750,6 +751,8 @@ if is_tf_available():
        TFLxmertPreTrainedModel,
        TFLxmertVisualFeatureEncoder,
    )
    from .modeling_tf_marian import TFMarianMTModel
    from .modeling_tf_mbart import TFMBartForConditionalGeneration
    from .modeling_tf_mobilebert import (
        TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFMobileBertForMaskedLM,
@@ -771,6 +774,7 @@ if is_tf_available():
        TFOpenAIGPTModel,
        TFOpenAIGPTPreTrainedModel,
    )
    from .modeling_tf_pegasus import TFPegasusForConditionalGeneration
    from .modeling_tf_roberta import (
        TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFRobertaForMaskedLM,
...
@@ -427,7 +427,6 @@ class DecoderLayer(nn.Module):
        output_attentions=False,
    ):
        residual = x
        if layer_state is None:
            layer_state = {}
        if self.normalize_before:
@@ -447,7 +446,7 @@ class DecoderLayer(nn.Module):
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)
        # Cross-Attention Block
        residual = x
        assert self.encoder_attn.cache_key != self.self_attn.cache_key
        if self.normalize_before:
@@ -628,7 +627,6 @@ class BartDecoder(nn.Module):
        encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
...
@@ -41,6 +41,10 @@ from .configuration_auto import (
    XLNetConfig,
    replace_list_option_in_docstrings,
)
from .configuration_blenderbot import BlenderbotConfig
from .configuration_marian import MarianConfig
from .configuration_mbart import MBartConfig
from .configuration_pegasus import PegasusConfig
from .configuration_utils import PretrainedConfig
from .file_utils import add_start_docstrings
from .modeling_tf_albert import (
@@ -63,6 +67,7 @@ from .modeling_tf_bert import (
    TFBertLMHeadModel,
    TFBertModel,
)
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration
from .modeling_tf_camembert import (
    TFCamembertForMaskedLM,
    TFCamembertForMultipleChoice,
@@ -108,6 +113,8 @@ from .modeling_tf_funnel import (
)
from .modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
from .modeling_tf_longformer import TFLongformerForMaskedLM, TFLongformerForQuestionAnswering, TFLongformerModel
from .modeling_tf_marian import TFMarianMTModel
from .modeling_tf_mbart import TFMBartForConditionalGeneration
from .modeling_tf_mobilebert import (
    TFMobileBertForMaskedLM,
    TFMobileBertForMultipleChoice,
@@ -118,6 +125,7 @@ from .modeling_tf_mobilebert import (
    TFMobileBertModel,
)
from .modeling_tf_openai import TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration
from .modeling_tf_roberta import (
    TFRobertaForMaskedLM,
    TFRobertaForMultipleChoice,
@@ -210,6 +218,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
        (T5Config, TFT5ForConditionalGeneration),
        (DistilBertConfig, TFDistilBertForMaskedLM),
        (AlbertConfig, TFAlbertForMaskedLM),
        (MarianConfig, TFMarianMTModel),
        (BartConfig, TFBartForConditionalGeneration),
        (CamembertConfig, TFCamembertForMaskedLM),
        (XLMRobertaConfig, TFXLMRobertaForMaskedLM),
@@ -261,8 +270,16 @@ TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
    ]
)

TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
    [
(T5Config, TFT5ForConditionalGeneration),
(MarianConfig, TFMarianMTModel),
(MBartConfig, TFMBartForConditionalGeneration),
(PegasusConfig, TFPegasusForConditionalGeneration),
(BlenderbotConfig, TFBlenderbotForConditionalGeneration),
(BartConfig, TFBartForConditionalGeneration),
]
)

TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
...
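
The ordering inside TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING matters: the auto classes return the first entry whose config class matches, and MarianConfig, MBartConfig, PegasusConfig and BlenderbotConfig all subclass BartConfig, so they must precede the (BartConfig, TFBartForConditionalGeneration) entry. A simplified sketch of the dispatch (the real loop lives in modeling_tf_auto.py):

def pick_tf_seq2seq_class(config, mapping):
    # The first isinstance match wins, which is why subclass configs are listed first.
    for config_class, model_class in mapping.items():
        if isinstance(config, config_class):
            return model_class
    raise ValueError(f"Unrecognized configuration class {config.__class__}")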
(Diff collapsed: the large rewrite of src/transformers/modeling_tf_bart.py is not shown here.)
# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF BlenderBot model, ported from the fairseq repo."""
from .configuration_blenderbot import BlenderbotConfig
from .file_utils import add_start_docstrings, is_tf_available
from .modeling_tf_bart import BART_START_DOCSTRING, LARGE_NEGATIVE, TFBartForConditionalGeneration
from .utils import logging
if is_tf_available():
import tensorflow as tf
_CONFIG_FOR_DOC = "BlenderbotConfig"
START_DOCSTRING = BART_START_DOCSTRING.replace(
"inherits from :class:`~transformers.TFPreTrainedModel`",
"inherits from :class:`~transformers.TFBartForConditionalGeneration`",
).replace("BartConfig", _CONFIG_FOR_DOC)
logger = logging.get_logger(__name__)
@add_start_docstrings("Blenderbot model for open domain dialogue", START_DOCSTRING)
class TFBlenderbotForConditionalGeneration(TFBartForConditionalGeneration):
config_class = BlenderbotConfig
def adjust_logits_during_generation(self, logits, cur_len, max_length):
"""Never predict pad_token_id. Predict </s> when max_length is reached."""
vocab_range = tf.constant(range(self.config.vocab_size))
logits = tf.where(vocab_range == self.config.pad_token_id, LARGE_NEGATIVE, logits)
if cur_len == max_length - 1:
logits = tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
return logits
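
A toy check of what adjust_logits_during_generation guarantees (invented numbers; -1e8 stands in for the LARGE_NEGATIVE constant imported from modeling_tf_bart):

import tensorflow as tf

LARGE_NEGATIVE = -1e8  # stand-in for the constant from modeling_tf_bart
pad_token_id, eos_token_id, vocab_size = 0, 2, 5

logits = tf.zeros((1, vocab_size))
vocab_range = tf.constant(range(vocab_size))
# The pad token can never win the argmax...
logits = tf.where(vocab_range == pad_token_id, LARGE_NEGATIVE, logits)
# ...and on the final step everything except </s> is masked out.
logits = tf.where(vocab_range != eos_token_id, LARGE_NEGATIVE, logits)
assert int(tf.argmax(logits, axis=-1)[0]) == eos_token_id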
# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF Marian model, ported from the fairseq repo."""
from .configuration_marian import MarianConfig
from .file_utils import add_start_docstrings, is_tf_available
from .modeling_tf_bart import BART_START_DOCSTRING, LARGE_NEGATIVE, TFBartForConditionalGeneration
from .utils import logging
if is_tf_available():
import tensorflow as tf
_CONFIG_FOR_DOC = "MarianConfig"
START_DOCSTRING = BART_START_DOCSTRING.replace(
"inherits from :class:`~transformers.TFPreTrainedModel`",
"inherits from :class:`~transformers.TFBartForConditionalGeneration`",
).replace("BartConfig", _CONFIG_FOR_DOC)
logger = logging.get_logger(__name__)
@add_start_docstrings("Marian model for machine translation", START_DOCSTRING)
class TFMarianMTModel(TFBartForConditionalGeneration):
authorized_missing_keys = [
r"model.encoder.embed_positions.weight",
r"model.decoder.embed_positions.weight",
]
config_class = MarianConfig
def adjust_logits_during_generation(self, logits, cur_len, max_length):
"""Never predict pad_token_id. Predict </s> when max_length is reached."""
vocab_range = tf.constant(range(self.config.vocab_size))
logits = tf.where(vocab_range == self.config.pad_token_id, LARGE_NEGATIVE, logits)
if cur_len == max_length - 1:
logits = tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
return logits
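
The authorized_missing_keys above are Marian's static sinusoidal position tables: a deterministic function of position and dimension, they are rebuilt at load time instead of being stored in the checkpoint, so from_pretrained should stay silent about them. A sketch of the filtering idea (illustrative only; the real regex filtering happens inside from_pretrained):

import re

missing_keys = ["model.encoder.embed_positions.weight", "model.decoder.embed_positions.weight"]
authorized = [r"model.encoder.embed_positions.weight", r"model.decoder.embed_positions.weight"]
unexplained = [k for k in missing_keys if not any(re.search(pattern, k) for pattern in authorized)]
assert unexplained == []  # nothing left to warn about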
# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF mBART model, originally from fairseq."""
from .configuration_mbart import MBartConfig
from .file_utils import add_start_docstrings
from .modeling_tf_bart import BART_START_DOCSTRING, TFBartForConditionalGeneration
from .utils import logging
_CONFIG_FOR_DOC = "MBartConfig"
START_DOCSTRING = BART_START_DOCSTRING.replace(
"inherits from :class:`~transformers.TFPreTrainedModel`",
"inherits from :class:`~transformers.TFBartForConditionalGeneration`",
).replace("BartConfig", _CONFIG_FOR_DOC)
logger = logging.get_logger(__name__)
@add_start_docstrings("mBART (multilingual BART) model for machine translation", START_DOCSTRING)
class TFMBartForConditionalGeneration(TFBartForConditionalGeneration):
config_class = MBartConfig
# All the code is in src/transformers/modeling_tf_bart.py
# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF Pegasus model, ported from the fairseq repo."""
from .configuration_pegasus import PegasusConfig
from .file_utils import add_start_docstrings
from .modeling_tf_bart import BART_START_DOCSTRING, TFBartForConditionalGeneration
from .utils import logging
_CONFIG_FOR_DOC = "PegasusConfig"
START_DOCSTRING = BART_START_DOCSTRING.replace(
"inherits from :class:`~transformers.TFPreTrainedModel`",
"inherits from :class:`~transformers.TFBartForConditionalGeneration`",
).replace("BartConfig", _CONFIG_FOR_DOC)
logger = logging.get_logger(__name__)
@add_start_docstrings("Pegasus model for summarization", START_DOCSTRING)
class TFPegasusForConditionalGeneration(TFBartForConditionalGeneration):
authorized_missing_keys = [
r"final_logits_bias",
r"model.encoder.embed_positions.weight",
r"model.decoder.embed_positions.weight",
]
config_class = PegasusConfig
# All the code is in src/transformers/modeling_tf_bart.py
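
Note the pattern shared by all four new model files: each class is a thin subclass of TFBartForConditionalGeneration that pins config_class (plus, where needed, authorized_missing_keys or a generation tweak), while the actual modeling code stays in modeling_tf_bart.py. A hypothetical fifth member of the family would look like this (sketch, not part of this PR):

from transformers import BartConfig
from transformers.modeling_tf_bart import TFBartForConditionalGeneration

class TFMyBartVariantForConditionalGeneration(TFBartForConditionalGeneration):
    config_class = BartConfig  # a real variant would pin its own config subclass here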
@@ -325,6 +325,15 @@ class TFBertPreTrainedModel:
        requires_tf(self)
class TFBlenderbotForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -797,6 +806,24 @@ class TFLxmertVisualFeatureEncoder:
        requires_tf(self)
class TFMarianMTModel:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
class TFMBartForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -922,6 +949,15 @@ class TFOpenAIGPTPreTrainedModel:
        requires_tf(self)
class TFPegasusForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
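
The dummy objects above keep `from transformers import TFMarianMTModel` and friends importable in a TensorFlow-free environment while failing loudly on first use. A simplified stand-in for the requires_tf guard (the real helper lives in transformers; the message is paraphrased):

def requires_tf(obj):
    name = getattr(obj, "__name__", obj.__class__.__name__)
    raise ImportError(f"{name} requires the TensorFlow library, but it was not found in your environment.")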
@@ -138,7 +138,7 @@ class MarianIntegrationTest(unittest.TestCase):
        )
        self.assertEqual(self.model.device, model_inputs.input_ids.device)
        generated_ids = self.model.generate(
            model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
        )
        generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        return generated_words
@@ -244,6 +244,8 @@ class TestMarian_RU_FR(MarianIntegrationTest):
@require_sentencepiece
@require_tokenizers
class TestMarian_MT_EN(MarianIntegrationTest):
"""Cover low resource/high perplexity setting. This breaks without adjust_logits_generation overwritten"""
src = "mt" src = "mt"
tgt = "en" tgt = "en"
src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."] src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."]
...
@@ -17,7 +17,9 @@
import tempfile
import unittest

import numpy as np

from transformers import BartConfig, BartTokenizer, is_tf_available
from transformers.file_utils import cached_property
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, slow
@@ -28,12 +30,16 @@ from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
if is_tf_available():
    import tensorflow as tf

    from transformers import TFBartForConditionalGeneration, TFBartModel
    from transformers.modeling_tf_bart import TFSinusoidalPositionalEmbedding
@require_tf
class TFBartModelTester:
config_cls = BartConfig
config_updates = {}
hidden_act = "gelu"
    def __init__(self, parent):
        self.parent = parent
        self.batch_size = 13
@@ -45,14 +51,13 @@ class ModelTester:
        self.num_hidden_layers = 5
        self.num_attention_heads = 4
        self.intermediate_size = 37
        self.hidden_dropout_prob = 0.1
        self.attention_probs_dropout_prob = 0.1
        self.max_position_embeddings = 20
        self.eos_token_ids = [2]
        self.pad_token_id = 1
        self.bos_token_id = 0
    def prepare_config_and_inputs_for_common(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
@@ -60,7 +65,7 @@ class ModelTester:
        input_ids = tf.concat([input_ids, eos_tensor], axis=1)
        input_ids = tf.clip_by_value(input_ids, 3, self.vocab_size + 1)

        config = self.config_cls(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
@@ -76,6 +81,7 @@ class ModelTester:
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.pad_token_id,
            **self.config_updates,
        )
        inputs_dict = prepare_bart_inputs_dict(config, input_ids)
        return config, inputs_dict
@@ -101,9 +107,10 @@ class TestTFBart(TFModelTesterMixin, unittest.TestCase):
    all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    model_tester_cls = TFBartModelTester

    def setUp(self):
        self.model_tester = self.model_tester_cls(self)
        self.config_tester = ConfigTester(self, config_class=BartConfig)

    def test_config(self):
@@ -120,7 +127,7 @@ class TestTFBart(TFModelTesterMixin, unittest.TestCase):
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
        model_class = self.all_generative_model_classes[0]
        input_ids = {
            "decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
            "input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
@@ -354,3 +361,29 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase):
        expected = np.array([[-0.0828, -0.0251, -0.0674], [0.1277, 0.3311, -0.0255], [0.2613, -0.0840, -0.2763]])
        assert np.allclose(features[0, :3, :3].numpy(), expected, atol=1e-3)
@require_tf
class TestTFSinusoidalPositionalEmbeddings(unittest.TestCase):
desired_weights = [
[0, 0, 0, 0, 0],
[0.84147096, 0.82177866, 0.80180490, 0.78165019, 0.76140374],
[0.90929741, 0.93651021, 0.95829457, 0.97505713, 0.98720258],
]
def test_positional_emb_cache_logic(self):
input_ids = _long_tensor([[4, 10]])
emb1 = TFSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6)
no_cache = emb1(input_ids, use_cache=False)
yes_cache = emb1(input_ids, use_cache=True)
self.assertEqual((1, 1, 6), yes_cache.shape) # extra dim to allow broadcasting, feel free to delete!
np.testing.assert_almost_equal(no_cache[-1].numpy(), yes_cache[0][0].numpy())
def test_positional_emb_weights_against_marian(self):
emb1 = TFSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512)
emb1.build(None)
weights = emb1.embeddings.numpy()
for i, (expected_weight, actual_weight) in enumerate(zip(self.desired_weights, weights)):
for j in range(5):
self.assertAlmostEqual(expected_weight[j], actual_weight[j], places=3)
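
The desired_weights above can be reproduced with plain numpy under the Marian/fairseq sinusoid layout assumed here (sin values fill the first half of each row, cos the second half); the exponent schedule below matches the tested values to about three decimals:

import numpy as np

dim, n_pos = 512, 3
positions = np.arange(n_pos)[:, None]                         # shape (3, 1)
inv_freq = 1.0 / 10000 ** (np.arange(dim // 2) / (dim // 2))  # shape (256,)
table = np.concatenate([np.sin(positions * inv_freq), np.cos(positions * inv_freq)], axis=1)
print(table[:3, :5])  # first three rows match desired_weights to ~1e-3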
# coding=utf-8
# Copyright 2020 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from tests.test_configuration_common import ConfigTester
from tests.test_modeling_tf_bart import TFBartModelTester
from tests.test_modeling_tf_common import TFModelTesterMixin
from transformers import BlenderbotConfig, BlenderbotSmallTokenizer, is_tf_available
from transformers.file_utils import cached_property
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_tokenizers, slow
if is_tf_available():
import tensorflow as tf
from transformers import TFAutoModelForSeq2SeqLM, TFBlenderbotForConditionalGeneration
class ModelTester(TFBartModelTester):
config_updates = dict(
normalize_before=True,
static_position_embeddings=True,
do_blenderbot_90_layernorm=True,
normalize_embeddings=True,
)
config_cls = BlenderbotConfig
@require_tf
class TestTFBlenderbotCommon(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
model_tester_cls = ModelTester
is_encoder_decoder = True
test_pruning = False
def setUp(self):
self.model_tester = self.model_tester_cls(self)
self.config_tester = ConfigTester(self, config_class=BlenderbotConfig)
def test_config(self):
self.config_tester.run_common_tests()
def test_inputs_embeds(self):
# inputs_embeds not supported
pass
def test_saved_model_with_hidden_states_output(self):
        # Should be uncommented during Patrick's TF refactor
pass
def test_saved_model_with_attentions_output(self):
        # Should be uncommented during Patrick's TF refactor
pass
def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
model_class = self.all_generative_model_classes[0]
input_ids = {
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
}
# Prepare our model
model = model_class(config)
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
# Let's load it from the disk to be sure we can use pretrained weights
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
outputs_dict = model(input_ids)
hidden_states = outputs_dict[0]
# Add a dense layer on top to test integration with other keras modules
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
# Compile extended model
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
@is_pt_tf_cross_test
@require_tokenizers
class TFBlenderbot90MIntegrationTests(unittest.TestCase):
src_text = [
"Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like i'm going to throw up.\nand why is that?"
]
model_name = "facebook/blenderbot-90M"
@cached_property
def tokenizer(self):
return BlenderbotSmallTokenizer.from_pretrained(self.model_name)
@cached_property
def model(self):
model = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name, from_pt=True)
return model
@slow
def test_90_generation_from_long_input(self):
model_inputs = self.tokenizer(self.src_text, return_tensors="tf")
generated_ids = self.model.generate(
model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,
num_beams=2,
use_cache=True,
)
generated_words = self.tokenizer.batch_decode(generated_ids.numpy(), skip_special_tokens=True)[0]
assert generated_words in (
"i don't know. i just feel like i'm going to throw up. it's not fun.",
"i'm not sure. i just feel like i've been feeling like i have to be in a certain place",
"i'm not sure. i just feel like i've been in a bad situation.",
)
# coding=utf-8
# Copyright 2020 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import warnings
from transformers import AutoTokenizer, MarianConfig, MarianTokenizer, TranslationPipeline, is_tf_available
from transformers.file_utils import cached_property
from transformers.testing_utils import is_pt_tf_cross_test, require_sentencepiece, require_tf, require_tokenizers, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_bart import TFBartModelTester
from .test_modeling_tf_common import TFModelTesterMixin
if is_tf_available():
import tensorflow as tf
from transformers import TFAutoModelForSeq2SeqLM, TFMarianMTModel
class ModelTester(TFBartModelTester):
config_updates = dict(static_position_embeddings=True, add_bias_logits=True)
config_cls = MarianConfig
@require_tf
class TestTFMarianCommon(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
model_tester_cls = ModelTester
is_encoder_decoder = True
test_pruning = False
def setUp(self):
self.model_tester = self.model_tester_cls(self)
self.config_tester = ConfigTester(self, config_class=MarianConfig)
def test_config(self):
self.config_tester.run_common_tests()
def test_inputs_embeds(self):
# inputs_embeds not supported
pass
def test_saved_model_with_hidden_states_output(self):
        # Should be uncommented during Patrick's TF refactor
pass
def test_saved_model_with_attentions_output(self):
pass
def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
model_class = self.all_generative_model_classes[0]
input_ids = {
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
}
# Prepare our model
model = model_class(config)
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
# Let's load it from the disk to be sure we can use pre-trained weights
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
outputs_dict = model(input_ids)
hidden_states = outputs_dict[0]
# Add a dense layer on top to test integration with other keras modules
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
# Compile extended model
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
class AbstractMarianIntegrationTest(unittest.TestCase):
maxDiff = 1000 # show more chars for failing integration tests
@classmethod
def setUpClass(cls) -> None:
cls.model_name = f"Helsinki-NLP/opus-mt-{cls.src}-{cls.tgt}"
return cls
@cached_property
def tokenizer(self) -> MarianTokenizer:
return AutoTokenizer.from_pretrained(self.model_name)
@property
def eos_token_id(self) -> int:
return self.tokenizer.eos_token_id
@cached_property
def model(self):
warnings.simplefilter("error")
model: TFMarianMTModel = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name, from_pt=True)
assert isinstance(model, TFMarianMTModel)
c = model.config
self.assertListEqual(c.bad_words_ids, [[c.pad_token_id]])
self.assertEqual(c.max_length, 512)
self.assertEqual(c.decoder_start_token_id, c.pad_token_id)
return model
def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs):
generated_words = self.translate_src_text(**tokenizer_kwargs)
self.assertListEqual(self.expected_text, generated_words)
def translate_src_text(self, **tokenizer_kwargs):
model_inputs = self.tokenizer.prepare_seq2seq_batch(
src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf"
)
generated_ids = self.model.generate(
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
)
generated_words = self.tokenizer.batch_decode(generated_ids.numpy(), skip_special_tokens=True)
return generated_words
@require_sentencepiece
@require_tokenizers
@is_pt_tf_cross_test
class TestMarian_MT_EN(AbstractMarianIntegrationTest):
"""Cover low resource/high perplexity setting. This breaks if pad_token_id logits not set to LARGE_NEGATIVE."""
src = "mt"
tgt = "en"
src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."]
expected_text = ["Touching gently, Jesus healed a man who was affected by the sad disease of leprosy."]
@slow
def test_batch_generation_mt_en(self):
self._assert_generated_batch_equal_expected()
@is_pt_tf_cross_test
@require_sentencepiece
@require_tokenizers
class TestMarian_en_zh(AbstractMarianIntegrationTest):
src = "en"
tgt = "zh"
src_text = ["My name is Wolfgang and I live in Berlin"]
expected_text = ["我叫沃尔夫冈 我住在柏林"]
@slow
def test_batch_generation_en_zh(self):
self._assert_generated_batch_equal_expected()
@is_pt_tf_cross_test
@require_sentencepiece
@require_tokenizers
class TestMarian_en_ROMANCE(AbstractMarianIntegrationTest):
"""Multilingual on target side."""
src = "en"
tgt = "ROMANCE"
src_text = [
">>fr<< Don't spend so much time watching TV.",
">>pt<< Your message has been sent.",
">>es<< He's two years older than me.",
]
expected_text = [
"Ne passez pas autant de temps à regarder la télé.",
"A sua mensagem foi enviada.",
"Es dos años más viejo que yo.",
]
@slow
def test_batch_generation_en_ROMANCE_multi(self):
self._assert_generated_batch_equal_expected()
@slow
def test_pipeline(self):
pipeline = TranslationPipeline(self.model, self.tokenizer, framework="tf")
output = pipeline(self.src_text)
self.assertEqual(self.expected_text, [x["translation_text"] for x in output])
# coding=utf-8
# Copyright 2020 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from tests.test_configuration_common import ConfigTester
from tests.test_modeling_tf_bart import TFBartModelTester
from tests.test_modeling_tf_common import TFModelTesterMixin
from transformers import AutoTokenizer, MBartConfig, is_tf_available
from transformers.file_utils import cached_property
from transformers.testing_utils import is_pt_tf_cross_test, require_sentencepiece, require_tf, require_tokenizers, slow
if is_tf_available():
import tensorflow as tf
from transformers import TFAutoModelForSeq2SeqLM, TFMBartForConditionalGeneration
class ModelTester(TFBartModelTester):
config_updates = dict(normalize_before=True, add_final_layer_norm=True)
config_cls = MBartConfig
@require_tf
class TestTFMBartCommon(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
model_tester_cls = ModelTester
is_encoder_decoder = True
test_pruning = False
def setUp(self):
self.model_tester = self.model_tester_cls(self)
self.config_tester = ConfigTester(self, config_class=MBartConfig)
def test_config(self):
self.config_tester.run_common_tests()
def test_inputs_embeds(self):
# inputs_embeds not supported
pass
def test_saved_model_with_hidden_states_output(self):
        # Should be uncommented during Patrick's TF refactor
pass
def test_saved_model_with_attentions_output(self):
        # Should be uncommented during Patrick's TF refactor
pass
def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
model_class = self.all_generative_model_classes[0]
input_ids = {
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
}
# Prepare our model
model = model_class(config)
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
# Let's load it from the disk to be sure we can use pretrained weights
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
outputs_dict = model(input_ids)
hidden_states = outputs_dict[0]
# Add a dense layer on top to test integration with other keras modules
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
# Compile extended model
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
@is_pt_tf_cross_test
@require_sentencepiece
@require_tokenizers
class TestMBartEnRO(unittest.TestCase):
src_text = [
" UN Chief Says There Is No Military Solution in Syria",
]
expected_text = [
"Şeful ONU declară că nu există o soluţie militară în Siria",
]
model_name = "facebook/mbart-large-en-ro"
@cached_property
def tokenizer(self):
return AutoTokenizer.from_pretrained(self.model_name)
@cached_property
def model(self):
model = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name, from_pt=True)
return model
def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs):
generated_words = self.translate_src_text(**tokenizer_kwargs)
self.assertListEqual(self.expected_text, generated_words)
def translate_src_text(self, **tokenizer_kwargs):
model_inputs = self.tokenizer.prepare_seq2seq_batch(
src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf"
)
generated_ids = self.model.generate(
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2
)
generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
return generated_words
@slow
def test_batch_generation_en_ro(self):
self._assert_generated_batch_equal_expected()
# coding=utf-8
# Copyright 2020 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from transformers import AutoTokenizer, PegasusConfig, is_tf_available
from transformers.file_utils import cached_property
from transformers.testing_utils import is_pt_tf_cross_test, require_sentencepiece, require_tf, require_tokenizers, slow
from .test_configuration_common import ConfigTester
from .test_modeling_pegasus import PGE_ARTICLE, XSUM_ENTRY_LONGER
from .test_modeling_tf_bart import TFBartModelTester
from .test_modeling_tf_common import TFModelTesterMixin
if is_tf_available():
import tensorflow as tf
from transformers import TFAutoModelForSeq2SeqLM, TFPegasusForConditionalGeneration
class ModelTester(TFBartModelTester):
config_updates = dict(
normalize_before=True,
static_position_embeddings=True,
)
hidden_act = "relu"
config_cls = PegasusConfig
@require_tf
class TestTFPegasusCommon(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
model_tester_cls = ModelTester
is_encoder_decoder = True
test_pruning = False
def setUp(self):
self.model_tester = self.model_tester_cls(self)
self.config_tester = ConfigTester(self, config_class=PegasusConfig)
def test_config(self):
self.config_tester.run_common_tests()
def test_inputs_embeds(self):
# inputs_embeds not supported
pass
def test_saved_model_with_hidden_states_output(self):
        # Should be uncommented during Patrick's TF refactor
pass
def test_saved_model_with_attentions_output(self):
        # Should be uncommented during Patrick's TF refactor
pass
def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
model_class = self.all_generative_model_classes[0]
input_ids = {
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
}
# Prepare our model
model = model_class(config)
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
# Let's load it from the disk to be sure we can use pretrained weights
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
outputs_dict = model(input_ids)
hidden_states = outputs_dict[0]
# Add a dense layer on top to test integration with other keras modules
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
# Compile extended model
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
@is_pt_tf_cross_test
@require_sentencepiece
@require_tokenizers
class TFPegasusIntegrationTests(unittest.TestCase):
src_text = [PGE_ARTICLE, XSUM_ENTRY_LONGER]
expected_text = [
"California's largest electricity provider has cut power to hundreds of thousands of customers in an effort to reduce the risk of wildfires.",
'N-Dubz have revealed they\'re "grateful" to have been nominated for four Mobo Awards.',
] # differs slightly from pytorch, likely due to numerical differences in linear layers
model_name = "google/pegasus-xsum"
@cached_property
def tokenizer(self):
return AutoTokenizer.from_pretrained(self.model_name)
@cached_property
def model(self):
model = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name, from_pt=True)
return model
def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs):
generated_words = self.translate_src_text(**tokenizer_kwargs)
assert self.expected_text == generated_words
def translate_src_text(self, **tokenizer_kwargs):
model_inputs = self.tokenizer.prepare_seq2seq_batch(
src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf"
)
generated_ids = self.model.generate(
model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,
num_beams=2,
use_cache=True,
)
generated_words = self.tokenizer.batch_decode(generated_ids.numpy(), skip_special_tokens=True)
return generated_words
@slow
def test_batch_generation(self):
self._assert_generated_batch_equal_expected()
@@ -67,6 +67,7 @@ MODEL_NAME_TO_DOC_FILE = {
    "xlm_prophetnet": "xlmprophetnet.rst",
    "xlm_roberta": "xlmroberta.rst",
    "bert_generation": "bertgeneration.rst",
    "marian": "marian.rst",
}
# This is to make sure the transformers module imported is the one in the repo.
@@ -148,7 +149,6 @@ def get_model_doc_files():
    _ignore_modules = [
        "auto",
        "dialogpt",
        "retribert",
    ]
    doc_files = []
@@ -245,6 +245,7 @@ def check_models_are_documented(module, doc_file):
def _get_model_name(module):
    """ Get the model name for the module defining it."""
    splits = module.__name__.split("_")
    # Special case for transfo_xl
    if splits[-1] == "xl":
        return "_".join(splits[-2:])
...
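
With "marian" removed from _ignore_modules and mapped in MODEL_NAME_TO_DOC_FILE, the repo checker now insists that every public Marian class, including the new TFMarianMTModel, is documented in marian.rst. A rough sketch of the lookup this enables (helper name and fallback are assumptions; the real logic is in utils/check_repo.py):

def doc_file_for(module_name, model_name_to_doc_file):
    model_name = module_name.split("_")[-1]  # e.g. "modeling_tf_marian" -> "marian"
    return model_name_to_doc_file.get(model_name, f"{model_name}.rst")

assert doc_file_for("modeling_tf_marian", {"marian": "marian.rst"}) == "marian.rst"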