Commit 9e68d075 authored by Suraj Patil, committed by GitHub

Seq2SeqTrainer (#6769)


Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent d9d0f114
@@ -189,6 +189,34 @@ from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
```
### Fine-tuning using Seq2SeqTrainer
To use `Seq2SeqTrainer` for fine-tuning, use the `finetune_trainer.py` script. It subclasses `Trainer` to extend it for seq2seq training. Apart from the `Trainer`-related `TrainingArguments`, it shares the same argument names as the `finetune.py` file. One notable difference is that computing generative metrics (BLEU, ROUGE) is optional and is controlled by the `--predict_with_generate` argument; set it to calculate BLEU and ROUGE metrics.
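For example, a minimal run might look like this (a sketch; the data directory, output directory and model are placeholders, and any other `finetune_trainer.py` argument can be passed the same way):

```bash
./builtin_trainer/finetune.sh \
    --model_name_or_path facebook/mbart-large-cc25 \
    --data_dir $ENRO_DIR \
    --output_dir enro_finetune --overwrite_output_dir \
    --task translation --src_lang en_XX --tgt_lang ro_RO
```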
To see all the possible command line options, run:
```bash
./builtin_trainer/finetune.sh --help # This calls python finetune_trainer.py --help
```
**At the moment, `Seq2SeqTrainer` does not support *with teacher* distillation.**
All `Seq2SeqTrainer`-based fine-tuning scripts are included in the `builtin_trainer` directory.
#### TPU Training
`Seq2SeqTrainer` supports TPU training with a few caveats:
1. Since the `generate` method does not work on TPU at the moment, `predict_with_generate` cannot be used. You should use `--prediction_loss_only` to only calculate the loss, and not set `--do_predict` or `--predict_with_generate`.
2. All sequences should be padded to equal length, otherwise training becomes extremely slow. (`finetune_trainer.py` does this automatically when running on TPU.)
We provide a very simple launcher script named `xla_spawn.py` that lets you run our example scripts on multiple TPU cores without any boilerplate. Just pass a `--num_cores` flag to this script, followed by your regular training script and its arguments (this is similar to the `torch.distributed.launch` helper for `torch.distributed`).
The `builtin_trainer/finetune_tpu.sh` script provides the minimal arguments needed for TPU training.
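For instance (a sketch; the model, data directory and output directory are placeholders):

```bash
./builtin_trainer/finetune_tpu.sh \
    --model_name_or_path sshleifer/student_marian_en_ro_6_3 \
    --data_dir $ENRO_DIR \
    --output_dir marian_enro_tpu --overwrite_output_dir \
    --task translation
```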
The following command fine-tunes `sshleifer/student_marian_en_ro_6_3` on a TPU v3-8 and should complete one epoch in ~5-6 minutes.
```bash
./builtin_trainer/train_distil_marian_enro_tpu.sh
```
### Evaluation Commands
To create summaries for each article in the dataset we use `run_eval.py`. Here are a few commands that run evaluation for different tasks and models.
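For example, a translation evaluation might look like this (an illustrative sketch; the paths are placeholders and the exact options are listed by `python run_eval.py --help`):

```bash
python run_eval.py sshleifer/student_marian_en_ro_6_3 \
    $ENRO_DIR/test.source enro_test_generations.txt \
    --reference_path $ENRO_DIR/test.target \
    --score_path enro_test_bleu.json \
    --task translation
```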
......
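# ---- builtin_trainer/finetune.sh: generic launcher that forwards all extra arguments to finetune_trainer.py ----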
# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
# run ./builtin_trainer/finetune.sh --help to see all the possible options
python finetune_trainer.py \
--learning_rate=3e-5 \
--fp16 \
--do_train --do_eval --do_predict --evaluate_during_training \
--predict_with_generate \
--n_val 1000 \
"$@"
export TPU_NUM_CORES=8
# the proper usage is documented in the README, you need to specify data_dir, output_dir and model_name_or_path
# run ./builtin_trainer/finetune_tpu.sh --help to see all the possible options
python xla_spawn.py --num_cores $TPU_NUM_CORES \
finetune_trainer.py \
--learning_rate=3e-5 \
--do_train --do_eval --evaluate_during_training \
--prediction_loss_only \
--n_val 1000 \
"$@"
export WANDB_PROJECT=distil-marian
export BS=64
export GAS=1
export m=sshleifer/student_marian_en_ro_6_3
export MAX_LEN=128
python finetune_trainer.py \
--tokenizer_name $m --model_name_or_path $m \
--data_dir $ENRO_DIR \
--output_dir marian_en_ro_6_3 --overwrite_output_dir \
--learning_rate=3e-4 \
--warmup_steps 500 --sortish_sampler \
--fp16 \
--gradient_accumulation_steps=$GAS \
--per_device_train_batch_size=$BS --per_device_eval_batch_size=$BS \
--freeze_encoder --freeze_embeds \
--num_train_epochs=6 \
--save_steps 3000 --eval_steps 3000 \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--do_train --do_eval --do_predict --evaluate_during_training \
--predict_with_generate --logging_first_step \
--task translation --label_smoothing 0.1 \
--run_name marian_en_ro_6_3 \
"$@"
export WANDB_PROJECT=distil-marian
export BS=64
export m=sshleifer/student_marian_en_ro_6_3
export MAX_LEN=128
export TPU_NUM_CORES=8
python xla_spawn.py --num_cores $TPU_NUM_CORES \
finetune_trainer.py \
--tokenizer_name $m --model_name_or_path $m \
--data_dir $ENRO_DIR \
--output_dir marian_en_ro_6_3 --overwrite_output_dir \
--learning_rate=3e-4 \
--warmup_steps 500 \
--per_device_train_batch_size=$BS --per_device_eval_batch_size=$BS \
--freeze_encoder --freeze_embeds \
--num_train_epochs=6 \
--save_steps 500 --eval_steps 500 \
--logging_first_step --logging_steps 200 \
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
--do_train --do_eval --evaluate_during_training \
--prediction_loss_only \
--task translation --label_smoothing 0.1 \
--run_name marian_en_ro_6_3 \
"$@"
export WANDB_PROJECT=distilbart-cnn
export BS=32
export GAS=1
export m=sshleifer/student_cnn_12_6
export tok=facebook/bart-large
export MAX_TGT_LEN=142
python finetune_trainer.py \
--model_name_or_path $m --tokenizer_name $tok \
--data_dir $CNN_DIR \
--output_dir distilbart-cnn-12-6 --overwrite_output_dir \
--learning_rate=3e-5 \
--warmup_steps 500 --sortish_sampler \
--fp16 \
--n_val 500 \
--gradient_accumulation_steps=$GAS \
--per_device_train_batch_size=$BS --per_device_eval_batch_size=$BS \
--freeze_encoder --freeze_embeds \
--num_train_epochs=2 \
--save_steps 3000 --eval_steps 3000 \
--logging_first_step \
--max_target_length $MAX_TGT_LEN --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
--do_train --do_eval --do_predict --evaluate_during_training \
--predict_with_generate \
--run_name distilbart-cnn-12-6 \
"$@"
python finetune_trainer.py \
--model_name_or_path=facebook/mbart-large-cc25 \
--data_dir $ENRO_DIR \
--output_dir mbart_cc25_enro --overwrite_output_dir \
--learning_rate=3e-5 \
--warmup_steps 500 \
--fp16 \
--label_smoothing 0.1 \
--adam_eps 1e-06 \
--src_lang en_XX --tgt_lang ro_RO \
--freeze_embeds \
--per_device_train_batch_size=4 --per_device_eval_batch_size=4 \
--max_source_length 128 --max_target_length 128 \
--val_max_target_length 128 --test_max_target_length 128 \
--sortish_sampler \
--num_train_epochs 6 \
--save_steps 25000 --eval_steps 25000 --logging_steps 1000 \
--do_train --do_eval --do_predict --evaluate_during_training \
--predict_with_generate --logging_first_step \
--task translation \
--run_name mbart_en_ro \
"$@"
import json
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
from seq2seq_trainer import Seq2SeqTrainer
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
BartTokenizer,
EvalPrediction,
HfArgumentParser,
MBartTokenizer,
T5Tokenizer,
TrainingArguments,
set_seed,
)
from transformers.modeling_bart import shift_tokens_right
from utils import (
LegacySeq2SeqDataset,
Seq2SeqDataset,
assert_all_frozen,
calculate_bleu,
calculate_rouge,
freeze_params,
lmap,
trim_batch,
use_task_specific_params,
)
logger = logging.getLogger(__name__)
class Seq2SeqDataCollator:
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
self.tokenizer = tokenizer
self.pad_token_id = tokenizer.pad_token_id
self.data_args = data_args
self.tpu_num_cores = tpu_num_cores
self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
def __call__(self, batch) -> Dict[str, torch.Tensor]:
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
batch = self._encode(batch)
input_ids, attention_mask, labels = (
batch["input_ids"],
batch["attention_mask"],
batch["labels"],
)
else:
input_ids = torch.stack([x["input_ids"] for x in batch])
attention_mask = torch.stack([x["attention_mask"] for x in batch])
labels = torch.stack([x["labels"] for x in batch])
labels = trim_batch(labels, self.pad_token_id)
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
if isinstance(self.tokenizer, T5Tokenizer):
decoder_input_ids = self._shift_right_t5(labels)
else:
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"labels": labels,
}
return batch
def _shift_right_t5(self, input_ids):
decoder_start_token_id = self.pad_token_id
assert (
decoder_start_token_id is not None
), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"
# shift inputs to the right
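# e.g. labels [A, B, C, </s>] become decoder inputs [<pad>, A, B, C]; T5 uses pad_token_id as the decoder start token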
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = decoder_start_token_id
return shifted_input_ids
def _encode(self, batch) -> Dict[str, torch.Tensor]:
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
[x["src_texts"] for x in batch],
src_lang=self.data_args.src_lang,
tgt_texts=[x["tgt_texts"] for x in batch],
tgt_lang=self.data_args.tgt_lang,
max_length=self.data_args.max_source_length,
max_target_length=self.data_args.max_target_length,
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
return_tensors="pt",
add_prefix_space=self.add_prefix_space,
)
return batch_encoding.data
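# Illustration (a sketch, not used by the script): how this collator is typically wired
# into a torch DataLoader. `dataset` is assumed to be one of the Seq2SeqDataset classes
# imported above, yielding dicts the collator understands.
def _example_collator_usage(dataset, tokenizer, data_args, batch_size=8):
    from torch.utils.data import DataLoader

    collator = Seq2SeqDataCollator(tokenizer, data_args, tpu_num_cores=None)
    # each batch is a dict with input_ids, attention_mask, decoder_input_ids and labels
    return DataLoader(dataset, batch_size=batch_size, collate_fn=collator)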
@dataclass
class Seq2SeqTrainingArguments(TrainingArguments):
"""
Parameters:
label_smoothing (:obj:`float`, `optional`, defaults to 0):
The label smoothing epsilon to apply (if not zero).
sortish_sampler (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to use SortishSampler or not. It sorts the inputs according to lengths in order to minimize the padding size.
predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to use generate to calculate generative metrics (ROUGE, BLEU).
"""
label_smoothing: Optional[float] = field(
default=0.0, metadata={"help": "The label smoothing epsilon to apply (if not zero)."}
)
sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
predict_with_generate: bool = field(
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
freeze_encoder: bool = field(default=False, metadata={"help": "Whether to freeze the encoder."})
freeze_embeds: bool = field(default=False, metadata={"help": "Whether to freeze the embeddings."})
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
data_dir: str = field(
metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
)
task: Optional[str] = field(
default="summarization",
metadata={"help": "Task name, summarization (or summarization_{dataset} for pegasus) or translation"},
)
max_source_length: Optional[int] = field(
default=1024,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
max_target_length: Optional[int] = field(
default=128,
metadata={
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
val_max_target_length: Optional[int] = field(
default=142,
metadata={
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
test_max_target_length: Optional[int] = field(
default=142,
metadata={
"help": "The maximum total sequence length for test target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
n_train: Optional[int] = field(default=-1, metadata={"help": "# training examples. -1 means use all."})
n_val: Optional[int] = field(default=-1, metadata={"help": "# validation examples. -1 means use all."})
n_test: Optional[int] = field(default=-1, metadata={"help": "# test examples. -1 means use all."})
src_lang: Optional[str] = field(default=None, metadata={"help": "Source language id for translation."})
tgt_lang: Optional[str] = field(default=None, metadata={"help": "Target language id for translation."})
eval_beams: Optional[int] = field(default=None, metadata={"help": "# num_beams to use for evaluation."})
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.local_rank,
training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)
# Set seed
set_seed(training_args.seed)
# Load pretrained model and tokenizer
#
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
)
model = AutoModelForSeq2SeqLM.from_pretrained(
model_args.model_name_or_path,
from_tf=".ckpt" in model_args.model_name_or_path,
config=config,
cache_dir=model_args.cache_dir,
)
# use task specific params
use_task_specific_params(model, data_args.task)
# set num_beams for evaluation
if data_args.eval_beams is not None:
model.config.num_beams = data_args.eval_beams
assert model.config.num_beams >= 1, f"got eval_beams={model.config.num_beams}. Need an integer >= 1"
# set max length for generation
model.config.max_generate_length = data_args.val_max_target_length
# set decoder_start_token_id for MBart
if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
model.config.decoder_start_token_id = decoder_start_token_id
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
def non_pad_len(tokens: np.ndarray) -> int:
return np.count_nonzero(tokens != tokenizer.pad_token_id)
def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
pred_str = lmap(str.strip, pred_str)
label_str = lmap(str.strip, label_str)
return pred_str, label_str
def summarization_metrics(pred: EvalPrediction) -> Dict:
pred_str, label_str = decode_pred(pred)
rouge: Dict = calculate_rouge(pred_str, label_str)
summ_len = np.mean(lmap(non_pad_len, pred.predictions))
rouge.update({"gen_len": summ_len})
return rouge
def translation_metrics(pred: EvalPrediction) -> Dict:
pred_str, label_str = decode_pred(pred)
bleu: Dict = calculate_bleu(pred_str, label_str)
gen_len = np.mean(lmap(non_pad_len, pred.predictions))
bleu.update({"gen_len": gen_len})
return bleu
compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
return compute_metrics_fn
def freeze_embeds(model: torch.nn.Module):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
try:
freeze_params(model.model.shared)
for d in [model.model.encoder, model.model.decoder]:
freeze_params(d.embed_positions)
freeze_params(d.embed_tokens)
except AttributeError:
freeze_params(model.shared)
for d in [model.encoder, model.decoder]:
freeze_params(d.embed_tokens)
if model_args.freeze_embeds:
freeze_embeds(model)
if model_args.freeze_encoder:
freeze_params(model.get_encoder())
assert_all_frozen(model.get_encoder())
dataset_class = Seq2SeqDataset if hasattr(tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
# Get datasets
train_dataset = (
dataset_class(
tokenizer,
type_path="train",
data_dir=data_args.data_dir,
n_obs=data_args.n_train,
max_target_length=data_args.max_target_length,
max_source_length=data_args.max_source_length,
prefix=model.config.prefix or "",
)
if training_args.do_train
else None
)
eval_dataset = (
dataset_class(
tokenizer,
type_path="val",
data_dir=data_args.data_dir,
n_obs=data_args.n_val,
max_target_length=data_args.val_max_target_length,
max_source_length=data_args.max_source_length,
prefix=model.config.prefix or "",
)
if training_args.do_eval
else None
)
test_dataset = (
dataset_class(
tokenizer,
type_path="test",
data_dir=data_args.data_dir,
n_obs=data_args.n_test,
max_target_length=data_args.test_max_target_length,
max_source_length=data_args.max_source_length,
prefix=model.config.prefix or "",
)
if training_args.do_predict
else None
)
# Initialize our Trainer
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None,
)
# Training
if training_args.do_train:
trainer.train(
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
)
trainer.save_model()
# For convenience, we also re-save the tokenizer to the same directory,
# so that you can share your model easily on huggingface.co/models =)
if trainer.is_world_process_zero():
tokenizer.save_pretrained(training_args.output_dir)
# Evaluation
eval_results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")
result = trainer.evaluate()
output_eval_file = os.path.join(training_args.output_dir, "eval_results.json")
if trainer.is_world_process_zero():
logger.info("***** Eval results *****")
for key, value in result.items():
logger.info(" %s = %s", key, value)
with open(output_eval_file, "w") as f:
json.dump(result, f)
eval_results.update(result)
if training_args.do_predict:
logging.info("*** Test ***")
test_output = trainer.predict(test_dataset=test_dataset)
test_metrics = test_output.metrics
test_metrics = {k.replace("eval", "test"): v for k, v in test_metrics.items()}
output_test_file = os.path.join(training_args.output_dir, "test_results.json")
if trainer.is_world_process_zero():
logger.info("***** Test results *****")
for key, value in test_metrics.items():
logger.info(" %s = %s", key, value)
with open(output_test_file, "w") as f:
json.dump(test_metrics, f)
if training_args.predict_with_generate:
test_preds = tokenizer.batch_decode(test_output.predictions, skip_special_tokens=True)
test_preds = lmap(str.strip, test_preds)
output_test_pred_file = os.path.join(training_args.output_dir, "test_generations.txt")
with open(output_test_pred_file, "w") as f:
f.write("\n".join(test_preds))
return eval_results
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()
import logging
from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
from torch.utils.data import DistributedSampler, RandomSampler
from transformers import Trainer
from transformers.file_utils import is_torch_tpu_available
from transformers.trainer import get_tpu_sampler
try:
from .utils import label_smoothed_nll_loss
except ImportError:
from utils import label_smoothed_nll_loss
logger = logging.getLogger(__name__)
class Seq2SeqTrainer(Trainer):
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
return None
elif is_torch_tpu_available():
return get_tpu_sampler(self.train_dataset)
else:
if self.args.sortish_sampler:
self.train_dataset.make_sortish_sampler(
self.args.per_device_train_batch_size, distributed=self.args.n_gpu > 1
)
return (
RandomSampler(self.train_dataset)
if self.args.local_rank == -1
else DistributedSampler(self.train_dataset)
)
def compute_loss(self, model, inputs):
labels = inputs.pop("labels")
outputs = model(**inputs, use_cache=False)
logits = outputs[0]
return self._compute_loss(logits, labels, ignore_index=model.config.pad_token_id)
def _compute_loss(self, logits, labels, ignore_index):
if self.args.label_smoothing == 0:
# Same behavior as modeling_bart.py
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
assert logits.shape[-1] == self.model.config.vocab_size
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
else:
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
loss, nll_loss = label_smoothed_nll_loss(
lprobs, labels, self.args.label_smoothing, ignore_index=ignore_index
)
return loss
def prediction_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on :obj:`model` using :obj:`inputs`.
Subclass and override to inject custom behavior.
Args:
model (:obj:`nn.Module`):
The model to evaluate.
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
prediction_loss_only (:obj:`bool`):
Whether or not to return the loss only.
Return:
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
A tuple with the loss, logits and labels (each being optional).
"""
inputs = self._prepare_inputs(inputs)
max_length = (
model.config.max_generate_length
if hasattr(model.config, "max_generate_length")
else model.config.max_position_embeddings
)
with torch.no_grad():
if self.args.predict_with_generate and not self.args.prediction_loss_only:
generated_tokens = model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
use_cache=True,
num_beams=model.config.num_beams,
max_length=max_length,
)
# in case the batch is shorter than max length, the output should be padded
generated_tokens = self._pad_tensors_to_max_len(
generated_tokens, max_length, model.config.pad_token_id
)
labels_out = inputs.get("labels")
outputs = model(**inputs)
logits = outputs[1]
loss = self._compute_loss(logits, labels_out, model.config.pad_token_id)
loss = loss.mean().item()
if self.args.prediction_loss_only:
logits = None
else:
logits = generated_tokens if self.args.predict_with_generate else logits
if self.args.prediction_loss_only:
return (loss, None, None)
labels_out = labels_out.detach()
labels = self._pad_tensors_to_max_len(labels_out, max_length, model.config.pad_token_id)
return (loss, logits.detach(), labels)
def _pad_tensors_to_max_len(self, tensor, max_length, pad_token_id):
padded_tensor = pad_token_id * torch.ones(
(tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
)
padded_tensor[:, : tensor.shape[-1]] = tensor
return padded_tensor
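# Illustrative sketch (an assumption, not the canonical implementation) of the
# fairseq-style `label_smoothed_nll_loss` helper imported from utils above. It is
# given a different name so it does not shadow the real helper used by _compute_loss.
def _label_smoothed_nll_loss_sketch(lprobs, target, epsilon, ignore_index=-100):
    # lprobs: log-probabilities over the vocabulary, target: matching token ids
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        # zero out the contribution of padding positions
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    nll_loss = nll_loss.sum()
    smooth_loss = smooth_loss.sum()
    # mix the true-token NLL with a uniform distribution over the vocabulary
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss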
import os
import sys
import tempfile
from unittest.mock import patch
from transformers import BartForConditionalGeneration, MarianMTModel
from transformers.testing_utils import slow
from .finetune_trainer import main
from .test_seq2seq_examples import MBART_TINY
from .utils import load_json
MODEL_NAME = MBART_TINY
# TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1"
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
@slow
def test_model_download():
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
BartForConditionalGeneration.from_pretrained(MODEL_NAME)
MarianMTModel.from_pretrained(MARIAN_MODEL)
@slow
def test_finetune_trainer():
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
output_dir = tempfile.mkdtemp(prefix="marian_output")
max_len = "128"
num_train_epochs = 4
eval_steps = 2
argv = [
"--model_name_or_path",
MARIAN_MODEL,
"--data_dir",
data_dir,
"--output_dir",
output_dir,
"--overwrite_output_dir",
"--n_train",
"8",
"--n_val",
"8",
"--max_source_length",
max_len,
"--max_target_length",
max_len,
"--val_max_target_length",
max_len,
"--do_train",
"--do_eval",
"--do_predict",
"--num_train_epochs",
str(num_train_epochs),
"--per_device_train_batch_size",
"4",
"--per_device_eval_batch_size",
"4",
"--learning_rate",
"3e-4",
"--warmup_steps",
"8",
"--evaluate_during_training",
"--predict_with_generate",
"--logging_steps",
"0",
"--save_steps",
str(eval_steps),
"--eval_steps",
str(eval_steps),
"--sortish_sampler",
"--label_smoothing",
"0.1",
"--task",
"translation",
]
testargs = ["finetune_trainer.py"] + argv
with patch.object(sys, "argv", testargs):
main()
# Check metrics
logs = load_json(os.path.join(output_dir, "log_history.json"))
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
first_step_stats = eval_metrics[0]
last_step_stats = eval_metrics[-1]
assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"]  # BLEU should improve over training
assert isinstance(last_step_stats["eval_bleu"], float)
# test if do_predict saves generations and metrics
contents = os.listdir(output_dir)
contents = {os.path.basename(p) for p in contents}
assert "test_generations.txt" in contents
assert "test_results.json" in contents
"""
A simple launcher script for TPU training
Inspired by https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py
::
>>> python xla_spawn.py --num_cores=NUM_CORES_YOU_HAVE
YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
arguments of your training script)
"""
import importlib
import sys
from argparse import REMAINDER, ArgumentParser
from pathlib import Path
import torch_xla.distributed.xla_multiprocessing as xmp
def parse_args():
"""
Helper function parsing the command line options
@retval ArgumentParser
"""
parser = ArgumentParser(
description=(
"PyTorch TPU distributed training launch "
"helper utility that will spawn up "
"multiple distributed processes"
)
)
# Optional arguments for the launch helper
parser.add_argument("--num_cores", type=int, default=1, help="Number of TPU cores to use (1 or 8).")
# positional
parser.add_argument(
"training_script",
type=str,
help=(
"The full path to the single TPU training "
"program/script to be launched in parallel, "
"followed by all the arguments for the "
"training script"
),
)
# rest from the training program
parser.add_argument("training_script_args", nargs=REMAINDER)
return parser.parse_args()
def main():
args = parse_args()
# Import training_script as a module.
script_fpath = Path(args.training_script)
sys.path.append(str(script_fpath.parent.resolve()))
mod_name = script_fpath.stem
mod = importlib.import_module(mod_name)
# Patch sys.argv
sys.argv = [args.training_script] + args.training_script_args + ["--tpu_num_cores", str(args.num_cores)]
xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores)
if __name__ == "__main__":
main()