"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "b6779d8df310bcac115d9949fcc6c7502b4c9551"
Unverified commit 529850ae authored by Nathan Raw, committed by GitHub

Lightning Updates for v0.8.5 (#5798)


Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent 615be03f
 import argparse
 import logging
 import os
-import random
 from pathlib import Path
 from typing import Any, Dict

-import numpy as np
 import pytorch_lightning as pl
-import torch
-from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
+from pytorch_lightning.utilities import rank_zero_info

 from transformers import (
     AdamW,
@@ -42,14 +39,6 @@ MODEL_MODES = {
 }


-def set_seed(args: argparse.Namespace):
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    if args.gpus > 0:
-        torch.cuda.manual_seed_all(args.seed)
-
-
 class BaseTransformer(pl.LightningModule):
     def __init__(
         self,
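The hand-rolled `set_seed` helper above is dropped because `generic_train` now calls `pl.seed_everything(args.seed)` (see further down in this diff). A minimal sketch of the replacement, assuming the pytorch_lightning 0.8.x API:

```python
# Minimal sketch (assuming pytorch_lightning 0.8.x): seed_everything covers
# what set_seed did by hand, seeding python's `random`, numpy, and torch
# (including CUDA) in one call.
import pytorch_lightning as pl

pl.seed_everything(42)
```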
@@ -63,7 +52,11 @@ class BaseTransformer(pl.LightningModule):
     ):
         """Initialize a model, tokenizer and config."""
         super().__init__()
-        self.hparams = hparams  # TODO: move to self.save_hyperparameters()
+        # TODO: move to self.save_hyperparameters()
+        # self.save_hyperparameters()
+        # can also expand arguments into trainer signature for easier reading
+        self.hparams = hparams
         self.step_count = 0
         self.tfmr_ckpts = {}
         self.output_dir = Path(self.hparams.output_dir)
@@ -114,17 +107,12 @@ class BaseTransformer(pl.LightningModule):
         ]
         optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
         self.opt = optimizer
-        return [optimizer]
-
-    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
-        if self.trainer.use_tpu:
-            xm.optimizer_step(optimizer)
-        else:
-            optimizer.step()
-        optimizer.zero_grad()
-        self.lr_scheduler.step()
-        lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
-        self.logger.log_metrics(lrs)
+
+        scheduler = get_linear_schedule_with_warmup(
+            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
+        )
+        # By default, PL will only step every epoch.
+        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
+        return [optimizer], [scheduler]
     def test_step(self, batch, batch_nb):
         return self.validation_step(batch, batch_nb)
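`configure_optimizers` now owns the scheduler and hands it to Lightning as a dict; `"interval": "step"` makes PL step it every batch instead of once per epoch, which is what the deleted `optimizer_step` override used to do manually. A minimal standalone sketch of the same convention (a toy module, not this repo's code):

```python
# Toy sketch of the ([optimizers], [scheduler dicts]) convention assumed
# by PL 0.8.x; TinyModule and the LambdaLR schedule are hypothetical.
import torch
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {"loss": torch.nn.functional.cross_entropy(self.layer(x), y)}

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=3e-4)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 0.95 ** step)
        # "interval": "step" -> scheduler.step() after every optimizer step,
        # not once per epoch (PL's default).
        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]
```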
@@ -132,26 +120,24 @@ class BaseTransformer(pl.LightningModule):
     def test_epoch_end(self, outputs):
         return self.validation_end(outputs)
-    def train_dataloader(self):
+    def setup(self, step):
         train_batch_size = self.hparams.train_batch_size
-        dataloader = self.load_dataset("train", train_batch_size)
-
-        t_total = (
-            (len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.n_gpu)))
-            // self.hparams.gradient_accumulation_steps
-            * float(self.hparams.num_train_epochs)
+        dataloader = self.get_dataloader("train", train_batch_size)
+        self.train_loader = dataloader
+        self.total_steps = (
+            (len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.gpus)))
+            // self.hparams.accumulate_grad_batches
+            * float(self.hparams.max_epochs)
         )
-        scheduler = get_linear_schedule_with_warmup(
-            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
-        )
-        self.lr_scheduler = scheduler
-        return dataloader
+
+    def train_dataloader(self):
+        return self.train_loader

     def val_dataloader(self):
-        return self.load_dataset("dev", self.hparams.eval_batch_size)
+        return self.get_dataloader("dev", self.hparams.eval_batch_size)

     def test_dataloader(self):
-        return self.load_dataset("test", self.hparams.eval_batch_size)
+        return self.get_dataloader("test", self.hparams.eval_batch_size)
     def _feature_file(self, mode):
         return os.path.join(
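`total_steps` has to exist before `configure_optimizers` builds the warmup schedule, which is why the computation moves out of `train_dataloader` and into the `setup` hook (Lightning calls `setup` first during `fit`). The arithmetic itself, as a standalone sketch with hypothetical numbers:

```python
# Standalone sketch of the self.total_steps arithmetic above; the argument
# values are hypothetical.
def total_training_steps(dataset_size, train_batch_size, gpus, accumulate_grad_batches, max_epochs):
    steps_per_epoch = dataset_size // (train_batch_size * max(1, gpus))
    return steps_per_epoch // accumulate_grad_batches * max_epochs


# 100_000 // (32 * 2) = 1562 batches/epoch; // 4 accumulation = 390 updates; * 3 epochs = 1170
assert total_training_steps(100_000, train_batch_size=32, gpus=2, accumulate_grad_batches=4, max_epochs=3) == 1170
```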
@@ -201,16 +187,16 @@ class BaseTransformer(pl.LightningModule):
         parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
         parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
         parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
-        parser.add_argument(
-            "--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
-        )
+        parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
         parser.add_argument("--train_batch_size", default=32, type=int)
         parser.add_argument("--eval_batch_size", default=32, type=int)
 class LoggingCallback(pl.Callback):
-    @rank_zero_only
+    def on_batch_end(self, trainer, pl_module):
+        lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
+        pl_module.logger.log_metrics(lrs)
+
     def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
         rank_zero_info("***** Validation results *****")
         metrics = trainer.callback_metrics
@@ -219,16 +205,15 @@ class LoggingCallback(pl.Callback):
             if key not in ["log", "progress_bar"]:
                 rank_zero_info("{} = {}\n".format(key, str(metrics[key])))

-    @rank_zero_only
     def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
-        logger.info("***** Test results *****")
+        rank_zero_info("***** Test results *****")
         metrics = trainer.callback_metrics
         # Log and save results to file
         output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
         with open(output_test_results_file, "w") as writer:
             for key in sorted(metrics):
                 if key not in ["log", "progress_bar"]:
-                    logger.info("{} = {}\n".format(key, str(metrics[key])))
+                    rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
                     writer.write("{} = {}\n".format(key, str(metrics[key])))
...@@ -251,26 +236,23 @@ def add_generic_args(parser, root_dir) -> None: ...@@ -251,26 +236,23 @@ def add_generic_args(parser, root_dir) -> None:
parser.add_argument( parser.add_argument(
"--fp16_opt_level", "--fp16_opt_level",
type=str, type=str,
default="O1", default="O2",
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html", "See details at https://nvidia.github.io/apex/amp.html",
) )
parser.add_argument("--fast_dev_run", action="store_true") parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int, default=0)
parser.add_argument("--gpus", type=int, default=1) parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
parser.add_argument("--n_tpu_cores", type=int, default=0)
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--do_train", action="store_true", help="Whether to run training.") parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.") parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
parser.add_argument( parser.add_argument(
"--gradient_accumulation_steps", "--gradient_accumulation_steps",
dest="accumulate_grad_batches",
type=int, type=int,
default=1, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.", help="Number of updates steps to accumulate before performing a backward/update pass.",
) )
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
parser.add_argument("--val_check_interval", default=1.0, type=float)
def generic_train( def generic_train(
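Several flags above now carry `dest=` so the old CLI spelling keeps working while the parsed value lands on the attribute name `pl.Trainer` expects (`max_epochs`, `gradient_clip_val`, `accumulate_grad_batches`, `tpu_cores`). A self-contained illustration of the remapping, not taken from the repo:

```python
# Self-contained illustration of argparse dest remapping: the user types
# the old flag, the parsed Namespace exposes the Trainer's argument name.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float)
args = parser.parse_args(["--max_grad_norm", "0.5"])

assert args.gradient_clip_val == 0.5
assert not hasattr(args, "max_grad_norm")  # only the dest name is stored
```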
@@ -283,10 +265,13 @@ def generic_train(
     logging_callback=None,
     **extra_train_kwargs
 ):
+    pl.seed_everything(args.seed)
+
     # init model
-    set_seed(args)
     odir = Path(model.hparams.output_dir)
     odir.mkdir(exist_ok=True)
+
+    # add custom checkpoints
     if checkpoint_callback is None:
         checkpoint_callback = pl.callbacks.ModelCheckpoint(
             filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
@@ -296,38 +281,25 @@ def generic_train(
     train_params = {}

+    # TODO: remove with PyTorch 1.6 since pl uses native amp
     if args.fp16:
-        train_params["use_amp"] = args.fp16
+        train_params["precision"] = 16
         train_params["amp_level"] = args.fp16_opt_level

-    if args.n_tpu_cores > 0:
-        global xm
-        import torch_xla.core.xla_model as xm
-
-        train_params["num_tpu_cores"] = args.n_tpu_cores
-        train_params["gpus"] = 0
-
     if args.gpus > 1:
         train_params["distributed_backend"] = "ddp"

-    trainer = pl.Trainer(
+    trainer = pl.Trainer.from_argparse_args(
+        args,
+        weights_summary=None,
+        callbacks=[logging_callback] + extra_callbacks,
         logger=logger,
-        accumulate_grad_batches=args.gradient_accumulation_steps,
-        gpus=args.gpus,
-        max_epochs=args.num_train_epochs,
-        early_stop_callback=early_stopping_callback,
-        gradient_clip_val=args.max_grad_norm,
         checkpoint_callback=checkpoint_callback,
-        callbacks=[logging_callback] + extra_callbacks,
-        fast_dev_run=args.fast_dev_run,
-        val_check_interval=args.val_check_interval,
-        weights_summary=None,
-        resume_from_checkpoint=args.resume_from_checkpoint,
+        early_stop_callback=early_stopping_callback,
         **train_params,
     )

     if args.do_train:
         trainer.fit(model)

-    trainer.logger.log_hyperparams(args)
-    trainer.logger.save()
-
     return trainer
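`pl.Trainer.from_argparse_args(args, ...)` is what lets this hunk drop the explicit `gpus=`, `max_epochs=`, `fast_dev_run=`, `val_check_interval=`, and `resume_from_checkpoint=` keywords: any Trainer constructor argument found on the Namespace is forwarded automatically, and explicit keywords still win. A rough sketch of the round trip, assuming the PL 0.8.x API:

```python
# Rough sketch (assuming PL 0.8.x): add_argparse_args registers every
# Trainer.__init__ argument as a CLI flag; from_argparse_args reads them back.
import argparse
import pytorch_lightning as pl

parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args(["--max_epochs", "3", "--gradient_clip_val", "1.0"])

# Explicit keywords take precedence over values parsed into `args`.
trainer = pl.Trainer.from_argparse_args(args, weights_summary=None)
```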
@@ -5,7 +5,7 @@ psutil
 sacrebleu
 rouge-score
 tensorflow_datasets
-pytorch-lightning==0.8.1
+pytorch-lightning==0.8.5
 matplotlib
 git-python==1.0.3
 faiss
...
@@ -60,7 +60,7 @@ Summarization Tips:
 - If you want to run experiments on improving the summarization finetuning process, try the XSUM Shared Task (below). It's faster to train than CNNDM because the summaries are shorter.
 - For CNN/DailyMail, the default `val_max_target_length` and `test_max_target_length` will truncate the ground truth labels, resulting in slightly higher rouge scores. To get accurate rouge scores, you should rerun calculate_rouge on the `{output_dir}/test_generations.txt` file saved by `trainer.test()`
 - `--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100` is a reasonable setting for XSUM.
-- `wandb` can be used by specifying `--logger wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
+- `wandb` can be used by specifying `--logger_name wandb`. It is useful for reproducibility. Specify the environment variable `WANDB_PROJECT='hf_xsum'` to do the XSUM shared task.
 - If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
 (It rarely makes sense to start from `bart-large` unless you are researching finetuning methods.)
@@ -124,7 +124,7 @@ model = AutoModelForSeq2SeqLM.from_pretrained(f'{output_dir}/best_tfmr')
 ```

 #### XSUM Shared Task
-Compare XSUM results with others by using `--logger wandb_shared`. This requires `wandb` registration.
+Compare XSUM results with others by using `--logger_name wandb_shared`. This requires `wandb` registration.

 Here is an example command, but you can do whatever you want. Hopefully this will make debugging and collaboration easier!

 ```bash
@@ -135,7 +135,7 @@ WANDB_PROJECT='hf_xsum' ./finetune.sh \
     --train_batch_size 16 --eval_batch_size 16 --freeze_embeds --freeze_encoder \
     --num_train_epochs 6 \
     --max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
-    --logger wandb
+    --logger_name wandb
 ```

 You can see your wandb logs [here](https://app.wandb.ai/sshleifer/hf_xsum?workspace=user-)
...
@@ -221,8 +221,8 @@ class SummarizationModule(BaseTransformer):
         dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
         t_total = (
             (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
-            // self.hparams.gradient_accumulation_steps
-            * float(self.hparams.num_train_epochs)
+            // self.hparams.accumulate_grad_batches
+            * float(self.hparams.max_epochs)
         )
         scheduler = get_linear_schedule_with_warmup(
             self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
@@ -279,7 +279,7 @@ class SummarizationModule(BaseTransformer):
         parser.add_argument("--freeze_encoder", action="store_true")
         parser.add_argument("--freeze_embeds", action="store_true")
         parser.add_argument("--sortish_sampler", action="store_true", default=False)
-        parser.add_argument("--logger", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
+        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
         parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
         parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.")
         parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
@@ -288,7 +288,6 @@ class SummarizationModule(BaseTransformer):
         )
         parser.add_argument("--src_lang", type=str, default="", required=False)
         parser.add_argument("--tgt_lang", type=str, default="", required=False)
-
         return parser
@@ -318,22 +317,24 @@ def main(args, model=None) -> SummarizationModule:
         model: SummarizationModule = SummarizationModule(args)
     else:
         model: SummarizationModule = TranslationModule(args)
+    dataset = Path(args.data_dir).name
     if (
-        args.logger == "default"
+        args.logger_name == "default"
         or args.fast_dev_run
         or str(args.output_dir).startswith("/tmp")
         or str(args.output_dir).startswith("/var")
     ):
         logger = True  # don't pollute wandb logs unnecessarily
-    elif args.logger == "wandb":
+    elif args.logger_name == "wandb":
         from pytorch_lightning.loggers import WandbLogger

-        logger = WandbLogger(name=model.output_dir.name)
-    elif args.logger == "wandb_shared":
+        logger = WandbLogger(name=model.output_dir.name, project=dataset)
+    elif args.logger_name == "wandb_shared":
         from pytorch_lightning.loggers import WandbLogger

-        logger = WandbLogger(name=model.output_dir.name)
+        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

     trainer: pl.Trainer = generic_train(
         model,
         args,
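The new `project=` argument groups runs by dataset: `Path(args.data_dir).name` becomes the wandb project, with an `hf_` prefix for the shared task. A short sketch with a hypothetical data directory:

```python
# Sketch of the project naming above; the data_dir path is hypothetical.
from pathlib import Path
from pytorch_lightning.loggers import WandbLogger

data_dir = "/home/user/data/xsum"
dataset = Path(data_dir).name                                  # -> "xsum"
logger = WandbLogger(name="my_run", project=f"hf_{dataset}")   # shared-task project "hf_xsum"
```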
@@ -352,13 +353,17 @@ def main(args, model=None) -> SummarizationModule:
         model.hparams.test_checkpoint = checkpoints[-1]
         trainer.resume_from_checkpoint = checkpoints[-1]
     trainer.logger.log_hyperparams(model.hparams)
-    trainer.test(model)
+
+    # this breaks in DDP, known lightning issue. See evaluate_checkpoint to recover metrics.
+    # test() without a model tests using the best checkpoint automatically
+    trainer.test()
     return model
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
+    parser = pl.Trainer.add_argparse_args(parser)
     parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
     args = parser.parse_args()

     main(args)
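As the new comment says, `trainer.test()` with no model argument makes Lightning restore the best checkpoint tracked by the `ModelCheckpoint` callback before testing, rather than evaluating whatever weights happen to be in memory. A sketch of the two behaviors, assuming PL 0.8.x:

```python
# Sketch (assuming PL 0.8.x): how best-checkpoint testing differs from
# testing the in-memory model. The fit/test calls are commented out
# because no real model is constructed here.
import pytorch_lightning as pl

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    filepath="output/", prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
)
trainer = pl.Trainer(checkpoint_callback=checkpoint_callback, max_epochs=1)
# trainer.fit(model)
# trainer.test(model)  # old: evaluates the final in-memory weights
# trainer.test()       # new: reloads the best val_loss checkpoint, then evaluates
```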
@@ -10,5 +10,4 @@ python finetune.py \
     --do_predict \
     --n_val 1000 \
     --val_check_interval 0.1 \
-    --sortish_sampler \
     $@
@@ -26,7 +26,7 @@ logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger()
 CUDA_AVAILABLE = torch.cuda.is_available()
 CHEAP_ARGS = {
-    "logger": "default",
+    "logger_name": "default",
     "length_penalty": 0.5,
     "cache_dir": "",
     "task": "summarization",
@@ -48,7 +48,7 @@ CHEAP_ARGS = {
     "max_grad_norm": 1.0,
     "do_train": True,
     "do_predict": True,
-    "gradient_accumulation_steps": 1,
+    "accumulate_grad_batches": 1,
     "server_ip": "",
     "server_port": "",
     "seed": 42,
@@ -60,7 +60,7 @@ CHEAP_ARGS = {
     "weight_decay": 0.0,
     "adam_epsilon": 1e-08,
     "warmup_steps": 0,
-    "num_train_epochs": 1,
+    "max_epochs": 1,
     "train_batch_size": 2,
     "eval_batch_size": 2,
     "max_source_length": 12,
@@ -122,7 +122,7 @@ class TestSummarizationDistiller(unittest.TestCase):
         updates = dict(
             student_encoder_layers=2,
             student_decoder_layers=1,
-            num_train_epochs=4,
+            max_epochs=4,
             val_check_interval=0.25,
             alpha_hid=2.0,
             model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
@@ -156,7 +156,7 @@ class TestSummarizationDistiller(unittest.TestCase):
         default_updates = dict(
             train_batch_size=1,
             eval_batch_size=2,
-            num_train_epochs=2,
+            max_epochs=2,
             alpha_mlm=0.2,
             alpha_ce=0.8,
             do_predict=True,
@@ -187,7 +187,7 @@ class TestSummarizationDistiller(unittest.TestCase):
         self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
         self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
         self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
-        desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
+        desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) + 1)
         self.assertEqual(len(metrics["val"]), desired_n_evals)
         self.assertEqual(len(metrics["test"]), 1)
         return model
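The test fixtures mirror the parsed `Namespace`, so every `dest=` rename above has to be repeated in `CHEAP_ARGS` and the `updates` dicts; the keys are attribute names, not CLI flag spellings. A minimal sketch of why, with a trimmed hypothetical dict:

```python
# Minimal sketch (trimmed, hypothetical values): tests typically build
# argparse.Namespace(**CHEAP_ARGS), so keys must match argparse dest names.
import argparse

CHEAP_ARGS = {"logger_name": "default", "accumulate_grad_batches": 1, "max_epochs": 1}
args = argparse.Namespace(**CHEAP_ARGS)

assert args.max_epochs == 1  # would fail if the key were still "num_train_epochs"
```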
...
@@ -17,5 +17,5 @@ python finetune.py \
     --model_name_or_path facebook/mbart-large-cc25 \
     --task translation \
     --warmup_steps 500 \
-    --logger wandb --sortish_sampler \
+    --logger_name wandb --sortish_sampler \
     $@