Unverified Commit 9e795eac authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

fix bert2bert test (#10063)

parent 31563e05
...@@ -24,15 +24,9 @@ if is_datasets_available(): ...@@ -24,15 +24,9 @@ if is_datasets_available():
class Seq2seqTrainerTester(TestCasePlus): class Seq2seqTrainerTester(TestCasePlus):
@slow @slow
@require_datasets
@require_torch @require_torch
@require_datasets
def test_finetune_bert2bert(self): def test_finetune_bert2bert(self):
"""
Currently fails with:
ImportError: To be able to use this metric, you need to install the following dependencies['absl', 'nltk', 'rouge_score']
"""
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny") bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
...@@ -47,8 +41,6 @@ class Seq2seqTrainerTester(TestCasePlus): ...@@ -47,8 +41,6 @@ class Seq2seqTrainerTester(TestCasePlus):
train_dataset = train_dataset.select(range(32)) train_dataset = train_dataset.select(range(32))
val_dataset = val_dataset.select(range(16)) val_dataset = val_dataset.select(range(16))
rouge = datasets.load_metric("rouge")
batch_size = 4 batch_size = 4
def _map_to_encoder_decoder_inputs(batch): def _map_to_encoder_decoder_inputs(batch):
...@@ -78,15 +70,9 @@ class Seq2seqTrainerTester(TestCasePlus): ...@@ -78,15 +70,9 @@ class Seq2seqTrainerTester(TestCasePlus):
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True) pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True) label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[ accuracy = sum([int(pred_str[i] == label_str[i]) for i in range(len(pred_str))]) / len(pred_str)
"rouge2"
].mid
return { return {"accuracy": accuracy}
"rouge2_precision": round(rouge_output.precision, 4),
"rouge2_recall": round(rouge_output.recall, 4),
"rouge2_fmeasure": round(rouge_output.fmeasure, 4),
}
# map train dataset # map train dataset
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment