Unverified commit 84c265ff, authored by Sam Shleifer, committed by GitHub

[lightning_base] fix s2s logging, only make train_loader once (#6404)

parent 72add6c9
@@ -150,15 +150,20 @@ class BaseTransformer(pl.LightningModule):
     def test_epoch_end(self, outputs):
         return self.validation_end(outputs)

-    def setup(self, step):
-        train_batch_size = self.hparams.train_batch_size
-        dataloader = self.get_dataloader("train", train_batch_size)
-        self.train_loader = dataloader
-        self.total_steps = (
-            (len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.gpus)))
-            // self.hparams.accumulate_grad_batches
-            * float(self.hparams.max_epochs)
-        )
+    @property
+    def total_steps(self) -> int:
+        """The number of total training steps that will be run. Used for lr scheduler purposes."""
+        num_devices = max(1, self.hparams.gpus)  # TODO: consider num_tpu_cores
+        effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
+        dataset_size = len(self.train_loader.dataset)
+        return (dataset_size / effective_batch_size) * self.hparams.max_epochs
+
+    def setup(self, mode):
+        if mode == "fit":
+            self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
+
+    def get_dataloader(self, type_path, batch_size, shuffle=False):
+        raise NotImplementedError("You must implement this for your task")

     def train_dataloader(self):
         return self.train_loader
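For orientation, the new total_steps property derives the step count from the effective batch size instead of caching it in setup(). A small worked example of that formula with assumed numbers (none of these values come from the diff):

# Assumed numbers, purely illustrative.
dataset_size = 64_000
train_batch_size, accumulate_grad_batches, num_devices, max_epochs = 8, 2, 1, 3

effective_batch_size = train_batch_size * accumulate_grad_batches * num_devices  # 16
total_steps = (dataset_size / effective_batch_size) * max_epochs                 # 4000 * 3
assert total_steps == 12_000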
@@ -304,6 +309,13 @@ def add_generic_args(parser, root_dir) -> None:
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
+    parser.add_argument(
+        "--data_dir",
+        default=None,
+        type=str,
+        required=True,
+        help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
+    )


 def generic_train(
......
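Since --data_dir is now registered by add_generic_args, the per-script copies of that flag (removed further down) are no longer needed. A rough sketch, under the assumption that the example scripts keep their usual entry-point shape, of how a parser picks up the shared flags:

import argparse
import os

from lightning_base import add_generic_args

parser = argparse.ArgumentParser()
add_generic_args(parser, os.getcwd())  # registers --data_dir, --seed, gradient accumulation, etc.
# parser = NERTransformer.add_model_specific_args(parser, os.getcwd())  # task-specific flags, per script
# args = parser.parse_args()  # --data_dir is now accepted (and required) by every example script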
@@ -10,14 +10,7 @@ from torch import nn
 from torch.nn import functional as F

 from lightning_base import generic_train
-from transformers import (
-    AdamW,
-    BartConfig,
-    BartForConditionalGeneration,
-    MBartTokenizer,
-    T5Config,
-    T5ForConditionalGeneration,
-)
+from transformers import BartConfig, BartForConditionalGeneration, MBartTokenizer, T5Config, T5ForConditionalGeneration


 try:
@@ -158,24 +151,6 @@ class BartSummarizationDistiller(SummarizationModule):
         )
         return loss_ce, s_logits_slct, t_logits_slct

-    def configure_optimizers(self):
-        "Prepare optimizer and schedule (linear warmup and decay)"
-        model = self.model
-        no_decay = ["bias", "LayerNorm.weight"]
-        optimizer_grouped_parameters = [
-            {
-                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
-                "weight_decay": self.hparams.weight_decay,
-            },
-            {
-                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
-                "weight_decay": 0.0,
-            },
-        ]
-        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
-        self.opt = optimizer
-        return [optimizer]
-
     @staticmethod
     def add_model_specific_args(parser, root_dir):
         SummarizationModule.add_model_specific_args(parser, root_dir)
......
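With its configure_optimizers override removed, BartSummarizationDistiller falls back to the optimizer setup it inherits. The deleted body shows the pattern it duplicated; here is a standalone sketch of that weight-decay grouping (build_optimizer is an invented name, not from the codebase):

from torch import nn
from transformers import AdamW

def build_optimizer(model: nn.Module, lr: float, weight_decay: float, eps: float) -> AdamW:
    # Parameters named "bias" or "LayerNorm.weight" are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    return AdamW(grouped_parameters, lr=lr, eps=eps)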
@@ -3,7 +3,6 @@ import glob
 import logging
 import os
 import time
-import warnings
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple
@@ -14,7 +13,7 @@ import torch
 from torch.utils.data import DataLoader

 from lightning_base import BaseTransformer, add_generic_args, generic_train
-from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
+from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration


 try:
@@ -252,17 +251,6 @@ class SummarizationModule(BaseTransformer):

     def train_dataloader(self) -> DataLoader:
         dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
-        t_total = (
-            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
-            // self.hparams.accumulate_grad_batches
-            * float(self.hparams.max_epochs)
-        )
-        scheduler = get_linear_schedule_with_warmup(
-            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
-        )
-        if max(scheduler.get_last_lr()) > 0:
-            warnings.warn("All learning rates are 0")
-        self.lr_scheduler = scheduler
         return dataloader

     def val_dataloader(self) -> DataLoader:
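train_dataloader no longer recomputes t_total and rebuilds the scheduler on every call. A hedged sketch of where that logic can live instead, sized once from the new total_steps property (this mirrors the deleted code but is not the literal replacement in the base class):

from transformers import AdamW, get_linear_schedule_with_warmup

def configure_optimizers(self):
    # Build the optimizer and a linear-warmup schedule once; Lightning steps the
    # scheduler once per optimizer step via the returned dict.
    optimizer = AdamW(self.model.parameters(), lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
    )
    return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]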
@@ -303,12 +291,6 @@ class SummarizationModule(BaseTransformer):
             help="The maximum total input sequence length after tokenization. Sequences longer "
             "than this will be truncated, sequences shorter will be padded.",
         )
-        parser.add_argument(
-            "--data_dir",
-            type=str,
-            required=True,
-            help="The input data dir. Should contain train.source, train.target, val.source, val.target, test.source, test.target",
-        )
         parser.add_argument("--freeze_encoder", action="store_true")
         parser.add_argument("--freeze_embeds", action="store_true")
         parser.add_argument("--sortish_sampler", action="store_true", default=False)
......
@@ -75,7 +75,7 @@ class GLUETransformer(BaseTransformer):
             logger.info("Saving features into cached file %s", cached_features_file)
             torch.save(features, cached_features_file)

-    def get_dataloader(self, mode: int, batch_size: int, shuffle: bool = False) -> DataLoader:
+    def get_dataloader(self, mode: str, batch_size: int, shuffle: bool = False) -> DataLoader:
         "Load datasets. Called after prepare data."

         # We test on dev set to compare to benchmarks without having to submit to GLUE server
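The corrected signature matches the hook BaseTransformer now declares (get_dataloader raising NotImplementedError), with the split passed as a string such as "train" or "dev". A minimal, purely illustrative subclass satisfying that contract (MyTaskTransformer and the random tensors are invented for the example):

import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning_base import BaseTransformer

class MyTaskTransformer(BaseTransformer):
    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        # A real module would read the split named by type_path from self.hparams.data_dir;
        # random tensors stand in for tokenized features here.
        dataset = TensorDataset(torch.randint(0, 100, (32, 16)), torch.randint(0, 2, (32,)))
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)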
@@ -161,13 +161,6 @@ class GLUETransformer(BaseTransformer):
             type=int,
             help="The number of GPUs allocated for this, it is by default 0 meaning none",
         )
-        parser.add_argument(
-            "--data_dir",
-            default=None,
-            type=str,
-            required=True,
-            help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
-        )
         parser.add_argument(
             "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
......
@@ -104,8 +104,7 @@ class NERTransformer(BaseTransformer):
         )

     def validation_step(self, batch, batch_nb):
-        "Compute validation"
-
+        """Compute validation""" ""
         inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
         if self.config.model_type != "distilbert":
             inputs["token_type_ids"] = (
@@ -191,14 +190,6 @@ class NERTransformer(BaseTransformer):
             help="The number of GPUs allocated for this, it is by default 0 meaning none",
         )
-        parser.add_argument(
-            "--data_dir",
-            default=None,
-            type=str,
-            required=True,
-            help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
-        )
         parser.add_argument(
             "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
         )
......
@@ -4,6 +4,7 @@ import unittest
 from unittest.mock import patch

 import run_ner
+from transformers.testing_utils import slow


 logging.basicConfig(level=logging.INFO)
@@ -12,6 +13,7 @@ logger = logging.getLogger()


 class ExamplesTests(unittest.TestCase):
+    @slow
     def test_run_ner(self):
         stream_handler = logging.StreamHandler(sys.stdout)
         logger.addHandler(stream_handler)
@@ -31,3 +33,23 @@ class ExamplesTests(unittest.TestCase):
         with patch.object(sys, "argv", ["run.py"] + testargs):
             result = run_ner.main()
             self.assertLess(result["eval_loss"], 1.5)
+
+    def test_run_ner_pl(self):
+        stream_handler = logging.StreamHandler(sys.stdout)
+        logger.addHandler(stream_handler)
+
+        testargs = """
+            --model_name distilbert-base-german-cased
+            --output_dir ./tests/fixtures/tests_samples/temp_dir
+            --overwrite_output_dir
+            --data_dir ./tests/fixtures/tests_samples/GermEval
+            --labels ./tests/fixtures/tests_samples/GermEval/labels.txt
+            --max_seq_length 128
+            --num_train_epochs 6
+            --logging_steps 1
+            --do_train
+            --do_eval
+            """.split()
+        with patch.object(sys, "argv", ["run.py"] + testargs):
+            result = run_ner.main()
+            self.assertLess(result["eval_loss"], 1.5)