Unverified Commit 48f23f92 authored by Sam Shleifer, committed by GitHub

[s2sTrainer] test + code cleanup (#7467)

parent 097049b8
@@ -26,6 +26,7 @@ from utils import (
     calculate_bleu,
     calculate_rouge,
     flatten_list,
+    freeze_embeds,
     freeze_params,
     get_git_info,
     label_smoothed_nll_loss,
@@ -90,7 +91,7 @@ class SummarizationModule(BaseTransformer):
         assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
         assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
         if self.hparams.freeze_embeds:
-            self.freeze_embeds()
+            freeze_embeds(self.model)
         if self.hparams.freeze_encoder:
             freeze_params(self.model.get_encoder())
             assert_all_frozen(self.model.get_encoder())
@@ -105,29 +106,12 @@ class SummarizationModule(BaseTransformer):
             Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
         )
         self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
-        assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
         if self.hparams.eval_max_gen_length is not None:
             self.eval_max_length = self.hparams.eval_max_gen_length
         else:
             self.eval_max_length = self.model.config.max_length
         self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
-
-    def freeze_embeds(self):
-        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
-        if self.model_type == "t5":
-            freeze_params(self.model.shared)
-            for d in [self.model.encoder, self.model.decoder]:
-                freeze_params(d.embed_tokens)
-        elif self.model_type == "fsmt":
-            for d in [self.model.model.encoder, self.model.model.decoder]:
-                freeze_params(d.embed_positions)
-                freeze_params(d.embed_tokens)
-        else:
-            freeze_params(self.model.model.shared)
-            for d in [self.model.model.encoder, self.model.model.decoder]:
-                freeze_params(d.embed_positions)
-                freeze_params(d.embed_tokens)

     def forward(self, input_ids, **kwargs):
         return self.model(input_ids, **kwargs)
...
-import json
 import logging
 import os
 import sys
@@ -29,10 +28,13 @@ from utils import (
     assert_all_frozen,
     calculate_bleu,
     calculate_rouge,
+    freeze_embeds,
     freeze_params,
     lmap,
+    save_json,
     trim_batch,
     use_task_specific_params,
+    write_txt_file,
 )
@@ -43,6 +45,7 @@ class Seq2SeqDataCollator:
     def __init__(self, tokenizer, data_args, tpu_num_cores=None):
         self.tokenizer = tokenizer
         self.pad_token_id = tokenizer.pad_token_id
+        assert self.pad_token_id is not None, "self.pad_token_id must be defined"
         self.data_args = data_args
         self.tpu_num_cores = tpu_num_cores
         self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
@@ -65,10 +68,8 @@ class Seq2SeqDataCollator:
         if isinstance(self.tokenizer, T5Tokenizer):
             decoder_input_ids = self._shift_right_t5(labels)
-            labels = labels
         else:
             decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
-            labels = labels

         batch = {
             "input_ids": input_ids,
@@ -79,17 +80,10 @@ class Seq2SeqDataCollator:
         return batch

     def _shift_right_t5(self, input_ids):
-        decoder_start_token_id = self.pad_token_id
-        assert (
-            decoder_start_token_id is not None
-        ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"
         # shift inputs to the right
         shifted_input_ids = input_ids.new_zeros(input_ids.shape)
         shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
-        shifted_input_ids[..., 0] = decoder_start_token_id
+        shifted_input_ids[..., 0] = self.pad_token_id
         return shifted_input_ids

     def _encode(self, batch) -> Dict[str, torch.Tensor]:
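For context, the simplified `_shift_right_t5` above just prepends T5's decoder start token (the pad token) and drops the last label position. A minimal standalone sketch of the same tensor manipulation, using made-up token ids:

import torch

labels = torch.tensor([[8774, 6, 13, 1]])   # made-up token ids; 1 = T5's </s>
pad_token_id = 0                            # for T5 the pad token doubles as the decoder start token

shifted = labels.new_zeros(labels.shape)    # same dtype/device as labels
shifted[..., 1:] = labels[..., :-1].clone() # shift everything one slot to the right
shifted[..., 0] = pad_token_id              # decoder starts from the pad token
print(shifted)                              # tensor([[   0, 8774,    6,   13]])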
@@ -267,17 +261,15 @@ def main():
     use_task_specific_params(model, data_args.task)

     # set num_beams for evaluation
-    if data_args.eval_beams is not None:
-        model.config.num_beams = data_args.eval_beams
-    assert model.config.num_beams >= 1, f"got eval_beams={model.config.num_beams}. Need an integer >= 1"
-
-    # set max length for generation
-    model.config.max_generate_length = data_args.val_max_target_length
+    if data_args.eval_beams is None:
+        data_args.eval_beams = model.config.num_beams

     # set decoder_start_token_id for MBart
     if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
-        decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
-        model.config.decoder_start_token_id = decoder_start_token_id
+        assert (
+            data_args.tgt_lang is not None and data_args.src_lang is not None
+        ), "mBart requires --tgt_lang and --src_lang"
+        model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]

     def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
         def non_pad_len(tokens: np.ndarray) -> int:
@@ -293,32 +285,20 @@ def main():
         def summarization_metrics(pred: EvalPrediction) -> Dict:
             pred_str, label_str = decode_pred(pred)
             rouge: Dict = calculate_rouge(pred_str, label_str)
-            summ_len = np.mean(lmap(non_pad_len, pred.predictions))
+            summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
             rouge.update({"gen_len": summ_len})
             return rouge

         def translation_metrics(pred: EvalPrediction) -> Dict:
             pred_str, label_str = decode_pred(pred)
             bleu: Dict = calculate_bleu(pred_str, label_str)
-            gen_len = np.mean(lmap(non_pad_len, pred.predictions))
+            gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
             bleu.update({"gen_len": gen_len})
             return bleu

         compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
         return compute_metrics_fn
-    def freeze_embeds(model: torch.nn.Module):
-        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
-        try:
-            freeze_params(model.model.shared)
-            for d in [model.model.encoder, model.model.decoder]:
-                freeze_params(d.embed_positions)
-                freeze_params(d.embed_tokens)
-        except AttributeError:
-            freeze_params(model.shared)
-            for d in [model.encoder, model.decoder]:
-                freeze_params(d.embed_tokens)

     if model_args.freeze_embeds:
         freeze_embeds(model)
     if model_args.freeze_encoder:
@@ -376,6 +356,7 @@ def main():
         eval_dataset=eval_dataset,
         data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
         compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None,
+        data_args=data_args,
     )
     # Training
@@ -396,41 +377,36 @@ def main():
         result = trainer.evaluate()
-        output_eval_file = os.path.join(training_args.output_dir, "eval_results.json")
         if trainer.is_world_process_zero():
             logger.info("***** Eval results *****")
             for key, value in result.items():
                 logger.info(" %s = %s", key, value)
-            with open(output_eval_file, "w") as f:
-                json.dump(result, f)
+            save_json(result, os.path.join(training_args.output_dir, "eval_results.json"))
             eval_results.update(result)

     if training_args.do_predict:
         logging.info("*** Test ***")

         test_output = trainer.predict(test_dataset=test_dataset)
-        test_metrics = test_output.metrics
-        test_metrics = {k.replace("eval", "test"): v for k, v in test_metrics.items()}
-        output_test_file = os.path.join(training_args.output_dir, "test_results.json")
+        test_metrics = {k.replace("eval", "test"): v for k, v in test_output.metrics.items()}

         if trainer.is_world_process_zero():
             logger.info("***** Test results *****")
             for key, value in test_metrics.items():
                 logger.info(" %s = %s", key, value)
-            with open(output_test_file, "w") as f:
-                json.dump(test_metrics, f)
+            save_json(test_metrics, os.path.join(training_args.output_dir, "test_results.json"))
+            eval_results.update(test_metrics)

             if training_args.predict_with_generate:
-                test_preds = tokenizer.batch_decode(test_output.predictions, skip_special_tokens=True)
+                test_preds = tokenizer.batch_decode(
+                    test_output.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
+                )
                 test_preds = lmap(str.strip, test_preds)
-                output_test_pred_file = os.path.join(training_args.output_dir, "test_generations.txt")
-                with open(output_test_pred_file, "w") as f:
-                    f.write("\n".join(test_preds))
+                write_txt_file(test_preds, os.path.join(training_args.output_dir, "test_generations.txt"))

+    if trainer.is_world_process_zero():
+        save_json(eval_results, "all_results.json")
     return eval_results
...
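The refactor above swaps the inline json.dump and file-write calls for the save_json and write_txt_file helpers imported from utils. Their bodies are not part of this diff; the following is only a rough sketch of the behaviour implied by the inline code they replace (names taken from the import list, everything else assumed):

import json
from typing import List

def save_json(content, path: str) -> None:
    # assumed equivalent of the removed `with open(...) as f: json.dump(result, f)`
    with open(path, "w") as f:
        json.dump(content, f)

def write_txt_file(lines: List[str], path: str) -> None:
    # assumed equivalent of the removed `f.write("\n".join(test_preds))`
    with open(path, "w") as f:
        f.write("\n".join(lines))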
@@ -20,6 +20,12 @@ logger = logging.getLogger(__name__)

 class Seq2SeqTrainer(Trainer):
+    def __init__(self, data_args, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.data_args = data_args
+        self.max_gen_length = data_args.val_max_target_length
+        self.pad_token_id = self.model.config.pad_token_id
+
     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
             return None
@@ -41,7 +47,7 @@ class Seq2SeqTrainer(Trainer):
         labels = inputs.pop("labels")
         outputs = model(**inputs, use_cache=False)
         logits = outputs[0]
-        return self._compute_loss(logits, labels, ignore_index=model.config.pad_token_id)
+        return self._compute_loss(logits, labels, ignore_index=self.pad_token_id)

     def _compute_loss(self, logits, labels, ignore_index):
         if self.args.label_smoothing == 0:
@@ -81,41 +87,32 @@ class Seq2SeqTrainer(Trainer):
         """
         inputs = self._prepare_inputs(inputs)

-        max_length = (
-            model.config.max_generate_length
-            if hasattr(model.config, "max_generate_length")
-            else model.config.max_position_embeddings
-        )

         with torch.no_grad():
             if self.args.predict_with_generate and not self.args.prediction_loss_only:
                 generated_tokens = model.generate(
                     inputs["input_ids"],
                     attention_mask=inputs["attention_mask"],
                     use_cache=True,
-                    num_beams=model.config.num_beams,
-                    max_length=max_length,
+                    num_beams=self.data_args.eval_beams,
+                    max_length=self.max_gen_length,
                 )
                 # in case the batch is shorter than max length, the output should be padded
                 generated_tokens = self._pad_tensors_to_max_len(
-                    generated_tokens, max_length, model.config.pad_token_id
+                    generated_tokens, self.max_gen_length, self.pad_token_id
                 )

             labels_out = inputs.get("labels")
-            outputs = model(**inputs)
-            logits = outputs[1]
-            loss = self._compute_loss(logits, labels_out, model.config.pad_token_id)
+            # Call forward again to get loss # TODO: avoidable?
+            outputs = model(**inputs, use_cache=False)
+            loss = self._compute_loss(outputs[1], labels_out, self.pad_token_id)
             loss = loss.mean().item()
             if self.args.prediction_loss_only:
-                logits = None
-            else:
-                logits = generated_tokens if self.args.predict_with_generate else logits
-
-            if self.args.prediction_loss_only:
                 return (loss, None, None)
+
+            logits = generated_tokens if self.args.predict_with_generate else outputs[1]

         labels_out = labels_out.detach()
-        labels = self._pad_tensors_to_max_len(labels_out, max_length, model.config.pad_token_id)
+        labels = self._pad_tensors_to_max_len(labels_out, self.max_gen_length, self.pad_token_id)
         return (loss, logits.detach(), labels)

     def _pad_tensors_to_max_len(self, tensor, max_length, pad_token_id):
...
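As the diff's own comment notes, prediction_step pads both the generated ids and the labels out to self.max_gen_length so that tensors from differently sized batches can be gathered and compared during evaluation. The body of _pad_tensors_to_max_len is unchanged and not shown here; a hypothetical standalone version of what such a helper does:

import torch

def pad_tensors_to_max_len(tensor: torch.Tensor, max_length: int, pad_token_id: int) -> torch.Tensor:
    # hypothetical standalone stand-in for Seq2SeqTrainer._pad_tensors_to_max_len
    padded = tensor.new_full((tensor.shape[0], max_length), pad_token_id)  # all-pad canvas
    padded[:, : tensor.shape[-1]] = tensor                                 # copy real tokens in front
    return padded

generated = torch.tensor([[5, 6, 7]])
print(pad_tensors_to_max_len(generated, max_length=5, pad_token_id=0))    # tensor([[5, 6, 7, 0, 0]])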
@@ -3,36 +3,54 @@ import sys
 import tempfile
 from unittest.mock import patch

-from transformers import BartForConditionalGeneration, MarianMTModel
 from transformers.testing_utils import slow
+from transformers.trainer_utils import set_seed

 from .finetune_trainer import main
 from .test_seq2seq_examples import MBART_TINY
 from .utils import load_json


-MODEL_NAME = MBART_TINY
-# TODO(SS): MODEL_NAME = "sshleifer/student_mbart_en_ro_1_1"
+set_seed(42)
 MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"


-@slow
-def test_model_download():
-    """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
-    BartForConditionalGeneration.from_pretrained(MODEL_NAME)
-    MarianMTModel.from_pretrained(MARIAN_MODEL)
+def test_finetune_trainer():
+    output_dir = run_trainer(1, "12", MBART_TINY, 1)
+    logs = load_json(os.path.join(output_dir, "log_history.json"))
+    eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
+    first_step_stats = eval_metrics[0]
+    assert "eval_bleu" in first_step_stats


 @slow
-def test_finetune_trainer():
+def test_finetune_trainer_slow():
+    # TODO(SS): This will fail on devices with more than 1 GPU.
+    # There is a missing call to __init__process_group somewhere
+    output_dir = run_trainer(eval_steps=2, max_len="32", model_name=MARIAN_MODEL, num_train_epochs=3)
+
+    # Check metrics
+    logs = load_json(os.path.join(output_dir, "log_history.json"))
+    eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
+    first_step_stats = eval_metrics[0]
+    last_step_stats = eval_metrics[-1]
+    assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"]  # model learned nothing
+    assert isinstance(last_step_stats["eval_bleu"], float)
+
+    # test if do_predict saves generations and metrics
+    contents = os.listdir(output_dir)
+    contents = {os.path.basename(p) for p in contents}
+    assert "test_generations.txt" in contents
+    assert "test_results.json" in contents
+
+
+def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
     data_dir = "examples/seq2seq/test_data/wmt_en_ro"
-    output_dir = tempfile.mkdtemp(prefix="marian_output")
-    max_len = "128"
-    num_train_epochs = 4
-    eval_steps = 2
+    output_dir = tempfile.mkdtemp(prefix="test_output")
     argv = [
         "--model_name_or_path",
-        MARIAN_MODEL,
+        model_name,
         "--data_dir",
         data_dir,
         "--output_dir",
@@ -72,25 +90,17 @@ def test_finetune_trainer():
         "--sortish_sampler",
         "--label_smoothing",
        "0.1",
-        # "--eval_beams",
-        # "2",
         "--task",
         "translation",
+        "--tgt_lang",
+        "ro_RO",
+        "--src_lang",
+        "en_XX",
     ]

     testargs = ["finetune_trainer.py"] + argv
     with patch.object(sys, "argv", testargs):
         main()
-
-    # Check metrics
-    logs = load_json(os.path.join(output_dir, "log_history.json"))
-    eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
-    first_step_stats = eval_metrics[0]
-    last_step_stats = eval_metrics[-1]
-    assert first_step_stats["eval_bleu"] < last_step_stats["eval_bleu"]  # model learned nothing
-    assert isinstance(last_step_stats["eval_bleu"], float)
-
-    # test if do_predict saves generations and metrics
-    contents = os.listdir(output_dir)
-    contents = {os.path.basename(p) for p in contents}
-    assert "test_generations.txt" in contents
-    assert "test_results.json" in contents
+    return output_dir
@@ -441,6 +441,25 @@ def freeze_params(model: nn.Module):
         par.requires_grad = False


+def freeze_embeds(model):
+    """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
+    model_type = model.config.model_type
+
+    if model_type == "t5":
+        freeze_params(model.shared)
+        for d in [model.encoder, model.decoder]:
+            freeze_params(d.embed_tokens)
+    elif model_type == "fsmt":
+        for d in [model.model.encoder, model.model.decoder]:
+            freeze_params(d.embed_positions)
+            freeze_params(d.embed_tokens)
+    else:
+        freeze_params(model.model.shared)
+        for d in [model.model.encoder, model.model.decoder]:
+            freeze_params(d.embed_positions)
+            freeze_params(d.embed_tokens)
+
+
 def grad_status(model: nn.Module) -> Iterable:
     return (par.requires_grad for par in model.parameters())
...
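A minimal usage sketch of the relocated freeze_embeds helper, assuming examples/seq2seq is importable and a tiny random BART checkpoint; this is illustrative and not part of the commit:

from transformers import BartForConditionalGeneration

from utils import freeze_embeds  # examples/seq2seq/utils.py after this change

model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random")
freeze_embeds(model)

# BART takes the final branch above: shared, positional, and token embeddings stop
# receiving gradients, while the rest of the network keeps requires_grad=True.
assert not model.model.shared.weight.requires_grad
assert not model.model.encoder.embed_positions.weight.requires_grad
assert any(p.requires_grad for p in model.parameters())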