Unverified commit 72d363d9 authored by Suraj Patil, committed by GitHub

[examples/s2s] clean up finetune_trainer (#7509)

parent bd262158
@@ -2,37 +2,29 @@ import logging
 import os
 import sys
 from dataclasses import dataclass, field
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Optional

-import numpy as np
-import torch
 from seq2seq_trainer import Seq2SeqTrainer
 from transformers import (
     AutoConfig,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
-    BartTokenizer,
-    EvalPrediction,
     HfArgumentParser,
     MBartTokenizer,
-    T5Tokenizer,
     TrainingArguments,
     set_seed,
 )
-from transformers.modeling_bart import shift_tokens_right
 from transformers.trainer_utils import EvaluationStrategy
 from utils import (
     LegacySeq2SeqDataset,
+    Seq2SeqDataCollator,
     Seq2SeqDataset,
     assert_all_frozen,
-    calculate_bleu,
-    calculate_rouge,
+    build_compute_metrics_fn,
     freeze_embeds,
     freeze_params,
     lmap,
     save_json,
-    trim_batch,
     use_task_specific_params,
     write_txt_file,
 )
@@ -41,66 +33,6 @@ from utils import (

 logger = logging.getLogger(__name__)

-class Seq2SeqDataCollator:
-    def __init__(self, tokenizer, data_args, tpu_num_cores=None):
-        self.tokenizer = tokenizer
-        self.pad_token_id = tokenizer.pad_token_id
-        assert self.pad_token_id is not None, "self.pad_token_id must be defined"
-        self.data_args = data_args
-        self.tpu_num_cores = tpu_num_cores
-        self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
-
-    def __call__(self, batch) -> Dict[str, torch.Tensor]:
-        if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
-            batch = self._encode(batch)
-            input_ids, attention_mask, labels = (
-                batch["input_ids"],
-                batch["attention_mask"],
-                batch["labels"],
-            )
-        else:
-            input_ids = torch.stack([x["input_ids"] for x in batch])
-            attention_mask = torch.stack([x["attention_mask"] for x in batch])
-            labels = torch.stack([x["labels"] for x in batch])
-
-            labels = trim_batch(labels, self.pad_token_id)
-            input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
-
-        if isinstance(self.tokenizer, T5Tokenizer):
-            decoder_input_ids = self._shift_right_t5(labels)
-        else:
-            decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
-
-        batch = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "decoder_input_ids": decoder_input_ids,
-            "labels": labels,
-        }
-        return batch
-
-    def _shift_right_t5(self, input_ids):
-        # shift inputs to the right
-        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
-        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
-        shifted_input_ids[..., 0] = self.pad_token_id
-        return shifted_input_ids
-
-    def _encode(self, batch) -> Dict[str, torch.Tensor]:
-        batch_encoding = self.tokenizer.prepare_seq2seq_batch(
-            [x["src_texts"] for x in batch],
-            src_lang=self.data_args.src_lang,
-            tgt_texts=[x["tgt_texts"] for x in batch],
-            tgt_lang=self.data_args.tgt_lang,
-            max_length=self.data_args.max_source_length,
-            max_target_length=self.data_args.max_target_length,
-            padding="max_length" if self.tpu_num_cores is not None else "longest",  # TPU hack
-            return_tensors="pt",
-            add_prefix_space=self.add_prefix_space,
-        )
-        return batch_encoding.data
 @dataclass
 class Seq2SeqTrainingArguments(TrainingArguments):
     """
@@ -271,34 +203,6 @@ def main():
         ), "mBart requires --tgt_lang and --src_lang"
         model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]

-    def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
-        def non_pad_len(tokens: np.ndarray) -> int:
-            return np.count_nonzero(tokens != tokenizer.pad_token_id)
-
-        def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
-            pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
-            label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
-            pred_str = lmap(str.strip, pred_str)
-            label_str = lmap(str.strip, label_str)
-            return pred_str, label_str
-
-        def summarization_metrics(pred: EvalPrediction) -> Dict:
-            pred_str, label_str = decode_pred(pred)
-            rouge: Dict = calculate_rouge(pred_str, label_str)
-            summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
-            rouge.update({"gen_len": summ_len})
-            return rouge
-
-        def translation_metrics(pred: EvalPrediction) -> Dict:
-            pred_str, label_str = decode_pred(pred)
-            bleu: Dict = calculate_bleu(pred_str, label_str)
-            gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
-            bleu.update({"gen_len": gen_len})
-            return bleu
-
-        compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
-        return compute_metrics_fn
     if model_args.freeze_embeds:
         freeze_embeds(model)
     if model_args.freeze_encoder:
@@ -349,13 +253,17 @@ def main():
     )

     # Initialize our Trainer
+    compute_metrics_fn = (
+        build_compute_metrics_fn(data_args.task, tokenizer) if training_args.predict_with_generate else None
+    )
     trainer = Seq2SeqTrainer(
         model=model,
+        config=config,
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
-        compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None,
+        compute_metrics=compute_metrics_fn,
         data_args=data_args,
     )
...
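
The sketch below is not part of the commit; it is a minimal illustration of how the pieces that moved into utils.py (Seq2SeqDataCollator, build_compute_metrics_fn) are wired together the way the cleaned-up trainer script now does. It assumes it runs from the example directory so that utils and seq2seq_trainer are importable; the t5-small checkpoint, the field values, and the SimpleNamespace stand-in for the script's data-args dataclass are illustrative choices, not names taken from the diff.

```python
# Hedged wiring sketch under the assumptions stated above.
from types import SimpleNamespace

from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer

from utils import Seq2SeqDataCollator, build_compute_metrics_fn  # the example's utils module

model_name = "t5-small"  # illustrative seq2seq checkpoint
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config)

# Stand-in for the script's data args: only the fields read by the collator,
# the metrics builder, and the trainer are filled in here.
data_args = SimpleNamespace(
    task="summarization",
    src_lang=None,
    tgt_lang=None,
    max_source_length=512,
    max_target_length=56,
    val_max_target_length=142,
)

data_collator = Seq2SeqDataCollator(tokenizer, data_args, tpu_num_cores=None)
compute_metrics = build_compute_metrics_fn(data_args.task, tokenizer)
# Both objects are then handed to Seq2SeqTrainer(..., config=config,
# data_collator=data_collator, compute_metrics=compute_metrics, data_args=data_args),
# mirroring the trainer construction shown in the hunk above.
```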
@@ -20,11 +20,13 @@ logger = logging.getLogger(__name__)

 class Seq2SeqTrainer(Trainer):
-    def __init__(self, data_args, *args, **kwargs):
+    def __init__(self, config, data_args, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.config = config
         self.data_args = data_args
         self.max_gen_length = data_args.val_max_target_length
-        self.pad_token_id = self.model.config.pad_token_id
+        self.pad_token_id = self.config.pad_token_id
+        self.vocab_size = self.config.vocab_size

     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
@@ -53,7 +55,7 @@ class Seq2SeqTrainer(Trainer):
         if self.args.label_smoothing == 0:
             # Same behavior as modeling_bart.py
             loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
-            assert logits.shape[-1] == self.model.config.vocab_size
+            assert logits.shape[-1] == self.vocab_size
             loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
         else:
             lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
...
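
For context on the hunk above: with label_smoothing == 0 the loss is plain token-level cross entropy over flattened logits, and the assert now checks against the vocab size cached from the passed-in config rather than reaching through self.model. A small self-contained illustration with toy tensors (assumed shapes, not code from the commit):

```python
import torch

vocab_size, ignore_index = 10, 0          # ignore_index plays the role of pad_token_id
logits = torch.randn(2, 5, vocab_size)    # (batch, seq_len, vocab) decoder logits
labels = torch.randint(1, vocab_size, (2, 5))  # padded positions would carry ignore_index

# Same shape check and loss call as the label_smoothing == 0 branch above.
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
assert logits.shape[-1] == vocab_size
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
print(loss.item())  # scalar cross entropy averaged over non-ignored tokens
```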
@@ -7,7 +7,7 @@ import pickle
 import socket
 from logging import getLogger
 from pathlib import Path
-from typing import Callable, Dict, Iterable, List, Union
+from typing import Callable, Dict, Iterable, List, Tuple, Union

 import git
 import numpy as np
@@ -19,8 +19,9 @@ from torch import nn
 from torch.utils.data import Dataset, Sampler

 from sentence_splitter import add_newline_to_end_of_each_sentence
-from transformers import BartTokenizer
+from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
 from transformers.file_utils import cached_property
+from transformers.modeling_bart import shift_tokens_right

 try:
@@ -62,6 +63,35 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
     return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}
+def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]:
+    def non_pad_len(tokens: np.ndarray) -> int:
+        return np.count_nonzero(tokens != tokenizer.pad_token_id)
+
+    def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
+        pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
+        label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
+        pred_str = lmap(str.strip, pred_str)
+        label_str = lmap(str.strip, label_str)
+        return pred_str, label_str
+
+    def summarization_metrics(pred: EvalPrediction) -> Dict:
+        pred_str, label_str = decode_pred(pred)
+        rouge: Dict = calculate_rouge(pred_str, label_str)
+        summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
+        rouge.update({"gen_len": summ_len})
+        return rouge
+
+    def translation_metrics(pred: EvalPrediction) -> Dict:
+        pred_str, label_str = decode_pred(pred)
+        bleu: Dict = calculate_bleu(pred_str, label_str)
+        gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
+        bleu.update({"gen_len": gen_len})
+        return bleu
+
+    compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
+    return compute_metrics_fn

 def trim_batch(
     input_ids,
     pad_token_id,
@@ -230,6 +260,68 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
         return batch_encoding
+
+class Seq2SeqDataCollator:
+    def __init__(self, tokenizer, data_args, tpu_num_cores=None):
+        self.tokenizer = tokenizer
+        self.pad_token_id = tokenizer.pad_token_id
+        assert (
+            self.pad_token_id is not None
+        ), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
+        self.data_args = data_args
+        self.tpu_num_cores = tpu_num_cores
+        self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
+
+    def __call__(self, batch) -> Dict[str, torch.Tensor]:
+        if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
+            batch = self._encode(batch)
+            input_ids, attention_mask, labels = (
+                batch["input_ids"],
+                batch["attention_mask"],
+                batch["labels"],
+            )
+        else:
+            input_ids = torch.stack([x["input_ids"] for x in batch])
+            attention_mask = torch.stack([x["attention_mask"] for x in batch])
+            labels = torch.stack([x["labels"] for x in batch])
+
+            labels = trim_batch(labels, self.pad_token_id)
+            input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
+
+        if isinstance(self.tokenizer, T5Tokenizer):
+            decoder_input_ids = self._shift_right_t5(labels)
+        else:
+            decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
+
+        batch = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "decoder_input_ids": decoder_input_ids,
+            "labels": labels,
+        }
+        return batch
+
+    def _shift_right_t5(self, input_ids):
+        # shift inputs to the right
+        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+        shifted_input_ids[..., 0] = self.pad_token_id
+        return shifted_input_ids
+
+    def _encode(self, batch) -> Dict[str, torch.Tensor]:
+        batch_encoding = self.tokenizer.prepare_seq2seq_batch(
+            [x["src_texts"] for x in batch],
+            src_lang=self.data_args.src_lang,
+            tgt_texts=[x["tgt_texts"] for x in batch],
+            tgt_lang=self.data_args.tgt_lang,
+            max_length=self.data_args.max_source_length,
+            max_target_length=self.data_args.max_target_length,
+            padding="max_length" if self.tpu_num_cores is not None else "longest",  # TPU hack
+            return_tensors="pt",
+            add_prefix_space=self.add_prefix_space,
+        )
+        return batch_encoding.data

 class SortishSampler(Sampler):
     "Go through the text data by order of src length with a bit of randomness. From fastai repo."
...
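
The collator that moved into utils.py builds decoder_input_ids by right-shifting the labels: _shift_right_t5 for T5 tokenizers, shift_tokens_right from modeling_bart otherwise. A standalone illustration of the T5-style shift, using toy token ids rather than anything from the diff:

```python
import torch

pad_token_id = 0
labels = torch.tensor([[42, 7, 13, 1]])  # toy target ids; 1 standing in for eos

# Mirrors Seq2SeqDataCollator._shift_right_t5: drop the last token, shift the rest
# right by one position, and start the decoder sequence with the pad token id.
shifted = labels.new_zeros(labels.shape)
shifted[..., 1:] = labels[..., :-1].clone()
shifted[..., 0] = pad_token_id
print(shifted)  # tensor([[ 0, 42,  7, 13]])
```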