Unverified Commit fe326bd5 authored by Ola Piktus, committed by GitHub

Remove dependency on examples/seq2seq from rag (#7395)


Co-authored-by: Your Name <you@example.com>
parent ad39271a
import logging
import os
from pathlib import Path
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
from utils import save_json
def count_trainable_parameters(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    return params
logger = logging.getLogger(__name__)
@@ -28,3 +41,76 @@ def get_checkpoint_callback(output_dir, metric):
        period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
    )
    return checkpoint_callback
def get_early_stopping_callback(metric, patience):
    return EarlyStopping(
        monitor=f"val_{metric}",  # does this need avg?
        mode="min" if "loss" in metric else "max",
        patience=patience,
        verbose=True,
    )
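
# Illustrative example (not part of this commit): a metric whose name contains "loss"
# is minimized, any other metric is maximized, e.g.
#     get_early_stopping_callback("em", patience=3)    # monitors "val_em", mode="max"
#     get_early_stopping_callback("loss", patience=3)  # monitors "val_loss", mode="min"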
class Seq2SeqLoggingCallback(pl.Callback):
    def on_batch_end(self, trainer, pl_module):
        lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
        pl_module.logger.log_metrics(lrs)

    @rank_zero_only
    def _write_logs(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
    ) -> None:
        logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
        metrics = trainer.callback_metrics
        trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
        # Log results
        od = Path(pl_module.hparams.output_dir)
        if type_path == "test":
            results_file = od / "test_results.txt"
            generations_file = od / "test_generations.txt"
        else:
            # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
            # If people want this it will be easy enough to add back.
            results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
            generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
            results_file.parent.mkdir(exist_ok=True)
            generations_file.parent.mkdir(exist_ok=True)
        with open(results_file, "a+") as writer:
            for key in sorted(metrics):
                if key in ["log", "progress_bar", "preds"]:
                    continue
                val = metrics[key]
                if isinstance(val, torch.Tensor):
                    val = val.item()
                msg = f"{key}: {val:.6f}\n"
                writer.write(msg)
        if not save_generations:
            return
        if "preds" in metrics:
            content = "\n".join(metrics["preds"])
            generations_file.open("w+").write(content)

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        try:
            npars = pl_module.model.model.num_parameters()
        except AttributeError:
            npars = pl_module.model.num_parameters()
        n_trainable_pars = count_trainable_parameters(pl_module)
        # mp stands for million parameters
        trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})

    @rank_zero_only
    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        save_json(pl_module.metrics, pl_module.metrics_save_path)
        return self._write_logs(trainer, pl_module, "test")

    @rank_zero_only
    def on_validation_end(self, trainer: pl.Trainer, pl_module):
        save_json(pl_module.metrics, pl_module.metrics_save_path)
        # Uncommenting this will save val generations
        # return self._write_logs(trainer, pl_module, "valid")
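
# Illustrative usage sketch (not part of this commit): these callbacks are meant to be
# handed to a pytorch_lightning Trainer, roughly like
#     callbacks = [Seq2SeqLoggingCallback(), get_checkpoint_callback(output_dir, metric)]
#     trainer = pl.Trainer(callbacks=callbacks)
# where output_dir and metric are placeholders; the exact wiring depends on the PL version.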
@@ -34,22 +34,23 @@ from transformers import logging as transformers_logging
sys.path.append(os.path.join(os.getcwd()))  # noqa: E402 # noqa: E402 # isort:skip
from examples.lightning_base import BaseTransformer, add_generic_args, generic_train  # noqa: E402 # isort:skip
from examples.rag.callbacks import (  # noqa: E402 # isort:skip
    get_checkpoint_callback,
    get_early_stopping_callback,
    Seq2SeqLoggingCallback,
)
from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever  # noqa: E402 # isort:skip
from examples.rag.utils import (  # noqa: E402 # isort:skip
    Seq2SeqDataset,
    calculate_exact_match,
    is_rag_model,
    set_extra_model_params,
)
from examples.seq2seq.callbacks import Seq2SeqLoggingCallback, get_early_stopping_callback  # noqa: E402 # isort:skip
from examples.seq2seq.utils import (  # noqa: E402 # isort:skip
    flatten_list,
    get_git_info,
    is_rag_model,
    lmap,
    pickle_save,
    save_git_info,
    save_json,
    set_extra_model_params,
    Seq2SeqDataset,
)
logging.basicConfig(level=logging.INFO)
@@ -303,11 +304,6 @@ class GenerativeQAModule(BaseTransformer):
    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        dataset = self.get_dataset(type_path)
        sampler = None
        if self.hparams.sortish_sampler and type_path == "train":
            assert self.hparams.gpus <= 1  # TODO: assert earlier
            sampler = dataset.make_sortish_sampler(batch_size)
            shuffle = False
        dataloader = DataLoader(
            dataset,
@@ -315,7 +311,6 @@ class GenerativeQAModule(BaseTransformer):
            collate_fn=dataset.collate_fn,
            shuffle=shuffle,
            num_workers=self.num_workers,
            sampler=sampler,
        )
        return dataloader
@@ -379,7 +374,6 @@ class GenerativeQAModule(BaseTransformer):
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")
...
import itertools
import json
import linecache
import os
import pickle
import re
import socket
import string
from collections import Counter
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List
import git
import torch
from torch.utils.data import Dataset
from examples.seq2seq.utils import SortishSampler, trim_batch
from transformers import BartTokenizer, RagTokenizer, T5Tokenizer
@@ -27,6 +32,19 @@ def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=Tru
    )
def trim_batch(
    input_ids,
    pad_token_id,
    attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id"""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
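
# Illustrative example (not part of this commit), assuming pad_token_id == 0:
#     input_ids = torch.tensor([[5, 6, 0, 0], [7, 0, 0, 0]])
#     trim_batch(input_ids, pad_token_id=0)  # -> tensor([[5, 6], [7, 0]])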
class Seq2SeqDataset(Dataset):
    def __init__(
        self,
@@ -114,13 +132,52 @@ class Seq2SeqDataset(Dataset):
        }
        return batch

    def make_sortish_sampler(self, batch_size):
        return SortishSampler(self.src_lens, batch_size)
logger = getLogger(__name__)
def flatten_list(summary_ids: List[List]):
    return [x for x in itertools.chain.from_iterable(summary_ids)]

def save_git_info(folder_path: str) -> None:
    """Save git information to output_dir/git_log.json"""
    repo_infos = get_git_info()
    save_json(repo_infos, os.path.join(folder_path, "git_log.json"))

def save_json(content, path, indent=4, **json_dump_kwargs):
    with open(path, "w") as f:
        json.dump(content, f, indent=indent, **json_dump_kwargs)

def load_json(path):
    with open(path) as f:
        return json.load(f)

def get_git_info():
    repo = git.Repo(search_parent_directories=True)
    repo_infos = {
        "repo_id": str(repo),
        "repo_sha": str(repo.head.object.hexsha),
        "repo_branch": str(repo.active_branch),
        "hostname": str(socket.gethostname()),
    }
    return repo_infos

def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))

def pickle_save(obj, path):
    """pickle.dump(obj, path)"""
    with open(path, "wb") as f:
        return pickle.dump(obj, f)
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
...