Unverified Commit e33085d6, authored by Shamane Siri, committed by GitHub

updated the original RAG implementation to be compatible with the latest PyTorch-Lightning (#11806)

* updated the original RAG implementation to be compatible with the latest PL version

* updated the requirements.txt file

* execute make style

* code quality test

* code quality

* conflict resolved in requirements.txt

* code quality

* changed the MyDDP class name to CustomDDP
parent 70f88eec
callbacks_rag.py
 import logging
-import os
 from pathlib import Path

 import numpy as np
@@ -34,9 +33,10 @@ def get_checkpoint_callback(output_dir, metric):
     )
     checkpoint_callback = ModelCheckpoint(
-        filepath=os.path.join(output_dir, exp),
+        dirpath=output_dir,
+        filename=exp,
         monitor=f"val_{metric}",
-        mode="max",
+        mode="min",
         save_top_k=3,
         period=1,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
...
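For reference, a minimal sketch of the `ModelCheckpoint` signature this hunk migrates to: PL >= 1.0 splits the old `filepath` argument into `dirpath` plus `filename`. The directory, file, and metric names below are placeholders, not values from the commit:

```python
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath="outputs",    # directory only; no longer joined into a single filepath
    filename="best_val",  # hypothetical file name pattern
    monitor="val_loss",   # hypothetical metric; the commit monitors f"val_{metric}"
    mode="min",
    save_top_k=3,         # keep the three best checkpoints
)
```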
distributed_ray_retriever.py
@@ -3,7 +3,6 @@ import random
 import ray
 from transformers import RagConfig, RagRetriever, RagTokenizer
-from transformers.file_utils import requires_datasets, requires_faiss
 from transformers.models.rag.retrieval_rag import CustomHFIndex
@@ -134,8 +133,6 @@ class RagRayDistributedRetriever(RagRetriever):
     @classmethod
     def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
-        requires_datasets(cls)
-        requires_faiss(cls)
         config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
         rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
         question_encoder_tokenizer = rag_tokenizer.question_encoder
...
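The deleted guards relied on `requires_datasets` and `requires_faiss`, which were dropped from `transformers.file_utils`; newer transformers releases consolidate such checks into a single `requires_backends` helper. A hedged sketch of an equivalent guard, assuming a transformers version that ships `requires_backends`:

```python
# Assumption: transformers provides requires_backends (it replaced the
# per-library requires_datasets / requires_faiss helpers around v4.6).
from transformers.file_utils import requires_backends


class RetrieverStub:
    @classmethod
    def from_pretrained(cls, name, **kwargs):
        # Raises ImportError with install hints if datasets or faiss is absent.
        requires_backends(cls, ["datasets", "faiss"])
        return cls()
```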
finetune_rag.py
@@ -13,8 +13,8 @@ import numpy as np
 import pytorch_lightning as pl
 import torch
 import torch.distributed as dist
-from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator
-from pytorch_lightning.cluster_environments import TorchElasticEnvironment
+import torch.distributed as torch_distrib
+from pytorch_lightning.plugins.training_type import DDPPlugin
 from torch.utils.data import DataLoader
 from transformers import (
@@ -36,7 +36,6 @@ if is_ray_available():
     import ray
     from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever

 from callbacks_rag import (  # noqa: E402 # isort:skipq
     get_checkpoint_callback,
     get_early_stopping_callback,
@@ -74,27 +73,19 @@ class AttrDict(dict):
         self.__dict__ = self


-# In PTL >v1.0, `init_ddp_connection` method in the `LightningModule`
-# is no longer used, and is moved into DDPAccelerator instead.
-# We override DDPAccelerator to add our custom logic for initializing the
-# retriever.
-# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/backends/test_accelerator_connector.py
-class CustomAccel(DDPAccelerator):
-    def __init__(self, trainer=None, **kwargs):
-        # Trainer is set later.
-        super().__init__(trainer, **kwargs)
-
-    def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
-        logger.info("Custom init_ddp_connection.")
-        module = self.trainer.model
-        if self.cluster_environment is None:
-            self.cluster_environment = TorchElasticEnvironment()
-        self.distributed_port = module.hparams.distributed_port
-        os.environ["MASTER_PORT"] = str(self.distributed_port)
-        super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
+class CustomDDP(DDPPlugin):
+    def init_ddp_connection(self, global_rank=None, world_size=None) -> None:
+        module = self.model
+        global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
+        world_size = world_size if world_size is not None else self.cluster_environment.world_size()
+        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
+        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
+        if not torch.distributed.is_initialized():
+            logger.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
+            torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)

         if module.is_rag_model:
+            self.distributed_port = module.hparams.distributed_port
             if module.distributed_retriever == "pytorch":
                 module.model.rag.retriever.init_retrieval(self.distributed_port)
             elif module.distributed_retriever == "ray" and global_rank == 0:
@@ -594,7 +585,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
         checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
         early_stopping_callback=es_callback,
         logger=training_logger,
-        accelerator=CustomAccel() if args.gpus > 1 else None,
+        custom_ddp_plugin=CustomDDP() if args.gpus > 1 else None,
         profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
     )
     pickle_save(model.hparams, model.output_dir / "hparams.pkl")
...
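For context, this is how a `DDPPlugin` subclass such as `CustomDDP` is handed to a PL 1.3 `Trainer`; `generic_train` (in `lightning_base.py` below) forwards it through the new `custom_ddp_plugin` argument into `plugins=[...]`. A minimal sketch assuming pytorch-lightning==1.3.x APIs:

```python
import pytorch_lightning as pl
from pytorch_lightning.plugins.training_type import DDPPlugin


class NoopCustomDDP(DDPPlugin):
    # Toy subclass: run extra setup right after the process group exists.
    def init_ddp_connection(self, global_rank=None, world_size=None) -> None:
        super().init_ddp_connection(global_rank, world_size)
        # e.g. initialize a distributed retriever here, as CustomDDP above does


trainer = pl.Trainer(gpus=2, accelerator="ddp", plugins=[NoopCustomDDP()])
```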
lightning_base.py
@@ -167,8 +167,8 @@ class BaseTransformer(pl.LightningModule):
         effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
         return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs

-    def setup(self, mode):
-        if mode == "test":
+    def setup(self, stage):
+        if stage == "test":
             self.dataset_size = len(self.test_dataloader().dataset)
         else:
             self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
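The `mode` to `stage` rename follows the PL >= 1.0 `LightningModule.setup(stage)` hook, which receives `"fit"` or `"test"`. A toy sketch of the same pattern (not from the commit):

```python
import pytorch_lightning as pl


class ToyModule(pl.LightningModule):
    def setup(self, stage):
        # Lightning calls this hook once per stage on every process.
        if stage == "test":
            pass  # e.g. record the test dataset size, as BaseTransformer does
        else:  # "fit" covers both training and validation
            pass  # e.g. build the train dataloader
```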
@@ -341,6 +341,7 @@ def generic_train(
     args: argparse.Namespace,
     early_stopping_callback=None,
     logger=True,  # can pass WandbLogger() here
+    custom_ddp_plugin=None,
     extra_callbacks=[],
     checkpoint_callback=None,
     logging_callback=None,
@@ -370,18 +371,17 @@ def generic_train(
         train_params["amp_level"] = args.fp16_opt_level

     if args.gpus > 1:
-        train_params["distributed_backend"] = "ddp"
+        train_params["accelerator"] = "ddp"

     train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
-    train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
-    train_params["profiler"] = extra_train_kwargs.get("profiler", None)
+    train_params["profiler"] = None  # extra_train_kwargs.get("profiler", None) # get unwanted logs

     trainer = pl.Trainer.from_argparse_args(
         args,
         weights_summary=None,
-        callbacks=[logging_callback] + extra_callbacks,
+        callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback],
+        plugins=[custom_ddp_plugin],
         logger=logger,
-        checkpoint_callback=checkpoint_callback,
         **train_params,
     )
...
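Two PL 1.3 conventions show up in this hunk: the checkpoint callback now rides in the `callbacks` list instead of the deprecated `checkpoint_callback` keyword, and training-type plugins are passed through `plugins`. A hedged sketch with placeholder values:

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins.training_type import DDPPlugin

ckpt = ModelCheckpoint(dirpath="outputs", monitor="val_loss", mode="min")  # placeholders
trainer = pl.Trainer(
    callbacks=[ckpt],       # replaces checkpoint_callback=ckpt
    plugins=[DDPPlugin()],  # where CustomDDP would be injected
    gpus=2,
    accelerator="ddp",
)
```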
requirements.txt
@@ -3,5 +3,5 @@ datasets >= 1.0.1
 psutil >= 5.7.0
 torch >= 1.4.0
 transformers
-pytorch-lightning==1.0.4
+pytorch-lightning==1.3.1
 GitPython
\ No newline at end of file