Unverified Commit cd9274d0 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[FlaxBert] Add ForCausalLM (#16995)

* [FlaxBert] Add ForCausalLM

* make style

* fix output attentions

* Add RobertaForCausalLM

* remove comment

* fix fx-to-pt model loading

* remove comment

* add modeling tests

* add enc-dec model tests

* add big_bird

* add electra

* make style

* make repo-consitency

* add to docs

* remove roberta test

* quality

* amend cookiecutter

* fix attention_mask bug in flax bert model tester

* tighten pt-fx thresholds to 1e-5

* add 'copied from' statements

* amend 'copied from' statements

* amend 'copied from' statements

* quality
parent 31616b8d
......@@ -33,6 +33,7 @@ if is_flax_available():
AutoTokenizer,
EncoderDecoderConfig,
FlaxBartForCausalLM,
FlaxBertForCausalLM,
FlaxBertModel,
FlaxEncoderDecoderModel,
FlaxGPT2LMHeadModel,
......@@ -545,6 +546,43 @@ class FlaxBartEncoderDecoderModelTest(FlaxEncoderDecoderMixin, unittest.TestCase
return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "facebook/bart-base")
@require_flax
class FlaxBertEncoderDecoderModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = FlaxBertModel(config)
decoder_model = FlaxBertForCausalLM(decoder_config)
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = FlaxBertModelTester(self, batch_size=13)
model_tester_decoder = FlaxBertModelTester(self, batch_size=13)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
(config, input_ids, token_type_ids, attention_mask) = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_attention_mask,
encoder_hidden_states,
encoder_attention_mask,
) = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
return {
"config": config,
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"encoder_hidden_states": encoder_hidden_states,
}
def get_pretrained_model(self):
return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased")
@require_flax
class FlaxEncoderDecoderModelTest(unittest.TestCase):
def get_from_encoderdecoder_pretrained_model(self):
......
......@@ -19,11 +19,12 @@ import numpy as np
from transformers import RobertaConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
from ..test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_flax_available():
from transformers.models.roberta.modeling_flax_roberta import (
FlaxRobertaForCausalLM,
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
......@@ -112,6 +113,22 @@ class FlaxRobertaModelTester(unittest.TestCase):
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
return config, inputs_dict
def prepare_config_and_inputs_for_decoder(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, token_type_ids, attention_mask = config_and_inputs
config.is_decoder = True
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
return (
config,
input_ids,
token_type_ids,
encoder_hidden_states,
encoder_attention_mask,
)
@require_flax
class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
......@@ -121,6 +138,7 @@ class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
FlaxRobertaModel,
FlaxRobertaForCausalLM,
FlaxRobertaForMaskedLM,
FlaxRobertaForSequenceClassification,
FlaxRobertaForTokenClassification,
......
......@@ -22,6 +22,7 @@ from transformers import is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device
from ..bart.test_modeling_flax_bart import FlaxBartStandaloneDecoderModelTester
from ..bert.test_modeling_flax_bert import FlaxBertModelTester
from ..gpt2.test_modeling_flax_gpt2 import FlaxGPT2ModelTester
from ..test_modeling_flax_common import floats_tensor, ids_tensor, random_attention_mask
from ..wav2vec2.test_modeling_flax_wav2vec2 import FlaxWav2Vec2ModelTester
......@@ -34,6 +35,7 @@ if is_flax_available():
from flax.traverse_util import flatten_dict
from transformers import (
FlaxBartForCausalLM,
FlaxBertForCausalLM,
FlaxGPT2LMHeadModel,
FlaxSpeechEncoderDecoderModel,
FlaxWav2Vec2Model,
......@@ -807,3 +809,118 @@ class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
@require_flax
class FlaxWav2Vec2BertModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model_and_inputs(self):
model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
"facebook/wav2vec2-large-lv60", "bert-large-uncased"
)
batch_size = 13
input_values = floats_tensor([batch_size, 512], model.config.encoder.vocab_size)
attention_mask = random_attention_mask([batch_size, 512])
decoder_input_ids = ids_tensor([batch_size, 4], model.config.decoder.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
inputs = {
"inputs": input_values,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
return model, inputs
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = FlaxWav2Vec2Model(config)
decoder_model = FlaxBertForCausalLM(decoder_config)
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = FlaxWav2Vec2ModelTester(self, batch_size=13)
model_tester_decoder = FlaxBertModelTester(self, batch_size=13)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
(config, inputs, attention_mask) = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_attention_mask,
encoder_hidden_states,
encoder_attention_mask,
) = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
return {
"config": config,
"inputs": inputs,
"attention_mask": attention_mask,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"encoder_hidden_states": encoder_hidden_states,
}
@slow
def test_flaxwav2vec2bert_pt_flax_equivalence(self):
pt_model = SpeechEncoderDecoderModel.from_pretrained("speech-seq2seq/wav2vec2-2-bert-large")
fx_model = FlaxSpeechEncoderDecoderModel.from_pretrained("speech-seq2seq/wav2vec2-2-bert-large", from_pt=True)
pt_model.to(torch_device)
pt_model.eval()
# prepare inputs
batch_size = 13
input_values = floats_tensor([batch_size, 512], fx_model.config.encoder.vocab_size)
attention_mask = random_attention_mask([batch_size, 512])
decoder_input_ids = ids_tensor([batch_size, 4], fx_model.config.decoder.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
inputs_dict = {
"inputs": input_values,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
pt_logits = pt_outputs.logits
pt_outputs = pt_outputs.to_tuple()
fx_outputs = fx_model(**inputs_dict)
fx_logits = fx_outputs.logits
fx_outputs = fx_outputs.to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits.numpy(), 4e-2)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = FlaxSpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
fx_logits_loaded = fx_outputs_loaded.logits
fx_outputs_loaded = fx_outputs_loaded.to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits_loaded, pt_logits.numpy(), 4e-2)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
pt_logits_loaded = pt_outputs_loaded.logits
pt_outputs_loaded = pt_outputs_loaded.to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
......@@ -91,6 +91,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
"TrOCRDecoderWrapper", # Building part of bigger (tested) model.
"SeparableConv1D", # Building part of bigger (tested) model.
"FlaxBartForCausalLM", # Building part of bigger (tested) model.
"FlaxBertForCausalLM", # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM.
]
# Update this list with test files that don't have a tester with a `all_model_classes` variable and which don't
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment