"...git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "56ee176c246100cdb1c0ed17338fe89704467e65"
Unverified Commit e9a2f772, authored by Sam Shleifer, committed by GitHub

[s2s] --eval_max_generate_length (#7018)

parent df4594a9
@@ -16,5 +16,5 @@ python distillation.py \
     --train_batch_size=$BS --eval_batch_size=$BS \
     --tokenizer_name Helsinki-NLP/opus-mt-en-ro \
     --warmup_steps 500 --logger_name wandb \
-    --fp16_opt_level O1 --task translation --normalize_hidden \
+    --fp16_opt_level O1 --task translation --normalize_hidden --num_sanity_val_steps=0 \
     "$@"
@@ -13,5 +13,5 @@ python distillation.py \
     --train_batch_size=$BS --eval_batch_size=$BS \
     --tokenizer_name $m --model_name_or_path $m \
     --warmup_steps 500 --sortish_sampler --logger_name wandb \
-    --gpus 1 --fp16_opt_level=O1 --task translation \
+    --gpus 1 --fp16_opt_level=O1 --task translation --num_sanity_val_steps=0 \
     "$@"
@@ -11,7 +11,6 @@ from typing import Dict, List, Tuple
 import numpy as np
 import pytorch_lightning as pl
 import torch
-from packaging import version
 from torch.utils.data import DataLoader
 
 from lightning_base import BaseTransformer, add_generic_args, generic_train
@@ -94,6 +93,9 @@ class SummarizationModule(BaseTransformer):
             "val": self.hparams.val_max_target_length,
             "test": self.hparams.test_max_target_length,
         }
+        if self.hparams.sortish_sampler and self.hparams.gpus > 1:
+            self.hparams.sortish_sampler = False
+            warnings.warn("ignoring sortish_sampler as it is unsupported on multiple GPUs")
         assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
         assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
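The guard added above downgrades the setting instead of letting a later assertion fail. A minimal, self-contained sketch of the same logic (the `Namespace` here is a stand-in for the Lightning `hparams` object, not the real module):

```python
# Sketch only: reproduces the sortish_sampler guard added in __init__ above.
# Namespace(...) stands in for the real hparams; nothing here is the actual module.
import warnings
from argparse import Namespace


def apply_sortish_guard(hparams: Namespace) -> Namespace:
    # Sortish sampling is only supported on a single GPU, so disable it and warn.
    if hparams.sortish_sampler and hparams.gpus > 1:
        hparams.sortish_sampler = False
        warnings.warn("ignoring sortish_sampler as it is unsupported on multiple GPUs")
    return hparams


hparams = apply_sortish_guard(Namespace(sortish_sampler=True, gpus=2))
assert hparams.sortish_sampler is False  # downgraded rather than asserting later
```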
@@ -114,6 +116,10 @@ class SummarizationModule(BaseTransformer):
         )
         self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
         assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
+        if self.hparams.eval_max_gen_length is not None:
+            self.eval_max_length = self.hparams.eval_max_gen_length
+        else:
+            self.eval_max_length = self.model.config.max_length
         self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
 
     def freeze_embeds(self):
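The precedence is: an explicit `--eval_max_gen_length` wins, otherwise the model config's `max_length` is used. A hedged sketch of that resolution with a stub config object (the value 142 is illustrative, not taken from the commit):

```python
# Sketch only: same precedence as the lines added above.
from argparse import Namespace


class StubConfig:
    max_length = 142  # illustrative stand-in for self.model.config.max_length


def resolve_eval_max_length(hparams: Namespace, config: StubConfig) -> int:
    # Explicit CLI value takes priority; otherwise fall back to the config default.
    if hparams.eval_max_gen_length is not None:
        return hparams.eval_max_gen_length
    return config.max_length


assert resolve_eval_max_length(Namespace(eval_max_gen_length=56), StubConfig()) == 56
assert resolve_eval_max_length(Namespace(eval_max_gen_length=None), StubConfig()) == 142
```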
@@ -209,12 +215,15 @@ class SummarizationModule(BaseTransformer):
 
     def _generative_step(self, batch: dict) -> dict:
         t0 = time.time()
+
+        # parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
         generated_ids = self.model.generate(
             batch["input_ids"],
             attention_mask=batch["attention_mask"],
             use_cache=True,
             decoder_start_token_id=self.decoder_start_token_id,
             num_beams=self.eval_beams,
+            max_length=self.eval_max_length,
         )
         gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
         preds: List[str] = self.ids_to_clean_text(generated_ids)
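`max_length` is an existing `generate()` argument that hard-caps the length of the generated sequence, so passing `self.eval_max_length` bounds validation and test generation. A hedged, standalone sketch; the checkpoint name and input text are illustrative and not part of the commit:

```python
# Sketch only: shows max_length capping beam-search output in transformers' generate().
# "sshleifer/tiny-mbart" is a tiny test checkpoint; any seq2seq model would do.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-mbart")
model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/tiny-mbart")

batch = tokenizer(["A short source sentence to translate."], return_tensors="pt")
generated_ids = model.generate(
    batch["input_ids"],
    attention_mask=batch["attention_mask"],
    use_cache=True,
    num_beams=2,
    max_length=20,  # plays the role of self.eval_max_length
)
assert generated_ids.shape[1] <= 20  # output length never exceeds the cap
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
```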
@@ -248,7 +257,7 @@ class SummarizationModule(BaseTransformer):
         dataset = self.get_dataset(type_path)
         sampler = None
         if self.hparams.sortish_sampler and type_path == "train":
-            assert self.hparams.gpus <= 1  # TODO: assert earlier
+            assert self.hparams.gpus <= 1  # this should never break because of the assertion in __init__
             sampler = dataset.make_sortish_sampler(batch_size)
             shuffle = False
@@ -321,6 +330,7 @@ class SummarizationModule(BaseTransformer):
         parser.add_argument(
             "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
         )
+        parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
         parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
         parser.add_argument(
             "--early_stopping_patience",
@@ -356,8 +366,6 @@ def main(args, model=None) -> SummarizationModule:
         model: SummarizationModule = SummarizationModule(args)
     else:
         model: SummarizationModule = TranslationModule(args)
-    if version.parse(torch.__version__) == version.parse("1.6") and args.fp16:
-        warnings.warn("FP16 only seems to work with torch 1.5+apex")
     dataset = Path(args.data_dir).name
     if (
         args.logger_name == "default"
@@ -34,6 +34,7 @@ CHEAP_ARGS = {
     "supervise_forward": True,
     "normalize_hidden": True,
     "label_smoothing": 0.2,
+    "eval_max_gen_length": None,
     "eval_beams": 1,
     "val_metric": "loss",
     "save_top_k": 1,