"benchmark/git@developer.sourcefind.cn:change/sglang.git" did not exist on "791b3bfabb61fad34e35f944e2833817d8e6929f"
Unverified commit e33085d6 authored by Shamane Siri, committed by GitHub

Updated the original RAG implementation to be compatible with the latest PyTorch Lightning (#11806)

* updated the original RAG implementation to be compatible with the latest PL version

* updated the requirements.txt file

* execute make style

* code quality test

* code quality

* resolved conflict in requirements.txt

* code quality

* changed the MyDDP class name to CustomDDP
parent 70f88eec
callbacks_rag.py

 import logging
 import os
-from pathlib import Path

 import numpy as np
@@ -34,9 +33,10 @@ def get_checkpoint_callback(output_dir, metric):
     )
     checkpoint_callback = ModelCheckpoint(
-        filepath=os.path.join(output_dir, exp),
+        dirpath=output_dir,
+        filename=exp,
         monitor=f"val_{metric}",
-        mode="max",
+        mode="min",
         save_top_k=3,
         period=1,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
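For reference, pytorch-lightning deprecated ModelCheckpoint's `filepath` argument in favour of `dirpath` plus `filename`, and the old argument is gone by 1.3, which is what this hunk migrates to. A minimal sketch of the new call, using placeholder directory and metric names rather than values from the commit:

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Placeholder values for illustration only.
output_dir = "outputs/rag_finetune"
metric = "em"

checkpoint_callback = ModelCheckpoint(
    dirpath=output_dir,             # directory only; replaces filepath=os.path.join(output_dir, exp)
    filename="checkpoint-{epoch}",  # the filename pattern is now a separate argument
    monitor=f"val_{metric}",
    mode="min",                     # direction of "better"; use "max" for scores such as accuracy
    save_top_k=3,
    period=1,                       # checkpoint cadence in validation epochs
)
```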
distributed_ray_retriever.py

@@ -3,7 +3,6 @@ import random

 import ray
 from transformers import RagConfig, RagRetriever, RagTokenizer
-from transformers.file_utils import requires_datasets, requires_faiss
 from transformers.models.rag.retrieval_rag import CustomHFIndex
@@ -134,8 +133,6 @@ class RagRayDistributedRetriever(RagRetriever):
     @classmethod
     def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
-        requires_datasets(cls)
-        requires_faiss(cls)
         config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
         rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
         question_encoder_tokenizer = rag_tokenizer.question_encoder
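The `requires_datasets` / `requires_faiss` guards are dropped because those helpers were removed from `transformers.file_utils` in later releases. If an explicit backend check is still wanted, a hedged sketch using the `requires_backends` helper that replaced them (assuming a transformers version that ships it):

```python
from transformers import RagRetriever
from transformers.file_utils import requires_backends

# Assumption: requires_backends supersedes the removed requires_datasets / requires_faiss
# helpers; it raises an ImportError naming whichever backend is not installed.
requires_backends(RagRetriever, ["datasets", "faiss"])
```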
finetune_rag.py

@@ -13,8 +13,8 @@ import numpy as np
 import pytorch_lightning as pl
 import torch
 import torch.distributed as dist
-from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator
-from pytorch_lightning.cluster_environments import TorchElasticEnvironment
+import torch.distributed as torch_distrib
+from pytorch_lightning.plugins.training_type import DDPPlugin
 from torch.utils.data import DataLoader
 from transformers import (
@@ -36,7 +36,6 @@ if is_ray_available():
     import ray
     from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever

 from callbacks_rag import (  # noqa: E402 # isort:skipq
     get_checkpoint_callback,
     get_early_stopping_callback,
@@ -74,27 +73,19 @@ class AttrDict(dict):
         self.__dict__ = self


-# In PTL >v1.0, `init_ddp_connection` method in the `LightningModule`
-# is no longer used, and is moved into DDPAccelerator instead.
-# We override DDPAccelerator to add our custom logic for initializing the
-# retriever.
-# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/backends/test_accelerator_connector.py
-class CustomAccel(DDPAccelerator):
-    def __init__(self, trainer=None, **kwargs):
-        # Trainer is set later.
-        super().__init__(trainer, **kwargs)
-
-    def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
-        logger.info("Custom init_ddp_connection.")
-        module = self.trainer.model
-        if self.cluster_environment is None:
-            self.cluster_environment = TorchElasticEnvironment()
-        self.distributed_port = module.hparams.distributed_port
-        os.environ["MASTER_PORT"] = str(self.distributed_port)
-        super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
+class CustomDDP(DDPPlugin):
+    def init_ddp_connection(self, global_rank=None, world_size=None) -> None:
+        module = self.model
+        global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
+        world_size = world_size if world_size is not None else self.cluster_environment.world_size()
+        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
+        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
+        if not torch.distributed.is_initialized():
+            logger.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
+            torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)

         if module.is_rag_model:
+            self.distributed_port = module.hparams.distributed_port
             if module.distributed_retriever == "pytorch":
                 module.model.rag.retriever.init_retrieval(self.distributed_port)
             elif module.distributed_retriever == "ray" and global_rank == 0:
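As a usage note, the plugin above is what carries the retriever initialisation into every DDP worker. A minimal sketch of how such a DDPPlugin subclass is attached to a pytorch-lightning 1.3 Trainer; the GPU count is an assumed example, and in the commit the wiring goes through generic_train rather than a direct Trainer call:

```python
import pytorch_lightning as pl

# Minimal sketch: attach the custom DDP plugin for a multi-GPU run.
# gpus=2 is an assumed value for illustration, not taken from the commit.
trainer = pl.Trainer(
    gpus=2,
    accelerator="ddp",      # PL 1.3 spelling of the old distributed_backend="ddp"
    plugins=[CustomDDP()],  # PL calls CustomDDP.init_ddp_connection once in each process
)
```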
@@ -594,7 +585,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
         checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
         early_stopping_callback=es_callback,
         logger=training_logger,
-        accelerator=CustomAccel() if args.gpus > 1 else None,
+        custom_ddp_plugin=CustomDDP() if args.gpus > 1 else None,
         profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
     )
     pickle_save(model.hparams, model.output_dir / "hparams.pkl")
lightning_base.py

@@ -167,8 +167,8 @@ class BaseTransformer(pl.LightningModule):
         effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
         return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs

-    def setup(self, mode):
-        if mode == "test":
+    def setup(self, stage):
+        if stage == "test":
             self.dataset_size = len(self.test_dataloader().dataset)
         else:
             self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
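Newer pytorch-lightning versions pass the running stage name (e.g. "fit" or "test") to the setup hook, so the parameter is renamed from mode to stage here. A minimal sketch of the hook in an illustrative module; the class name is a placeholder, not from the commit:

```python
import pytorch_lightning as pl


class ExampleModule(pl.LightningModule):  # placeholder name for illustration
    def setup(self, stage):
        # The Trainer supplies the stage, e.g. "fit" before training or "test" before testing.
        if stage == "test":
            print("preparing test dataset size")
        else:
            print("preparing train/val dataloaders")
```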
@@ -341,6 +341,7 @@ def generic_train(
     args: argparse.Namespace,
     early_stopping_callback=None,
     logger=True,  # can pass WandbLogger() here
+    custom_ddp_plugin=None,
     extra_callbacks=[],
     checkpoint_callback=None,
     logging_callback=None,
@@ -370,18 +371,17 @@ def generic_train(
         train_params["amp_level"] = args.fp16_opt_level

     if args.gpus > 1:
-        train_params["distributed_backend"] = "ddp"
+        train_params["accelerator"] = "ddp"

     train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
-    train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
-    train_params["profiler"] = extra_train_kwargs.get("profiler", None)
+    train_params["profiler"] = None  # extra_train_kwargs.get("profiler", None)  # get unwanted logs

     trainer = pl.Trainer.from_argparse_args(
         args,
         weights_summary=None,
-        callbacks=[logging_callback] + extra_callbacks,
+        callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback],
+        plugins=[custom_ddp_plugin],
         logger=logger,
-        checkpoint_callback=checkpoint_callback,
         **train_params,
     )
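Condensed, the PL 1.3 Trainer wiring that this hunk arrives at: the ModelCheckpoint instance is appended to the callbacks list (replacing the removed checkpoint_callback constructor argument) and the training-type plugin travels through plugins. A self-contained sketch with assumed placeholder arguments (empty CLI args, a dummy checkpoint directory, a hand-set use_ddp flag), not the commit's actual values:

```python
import argparse

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins.training_type import DDPPlugin

# Assumed placeholder setup for illustration only.
parser = pl.Trainer.add_argparse_args(argparse.ArgumentParser())
args = parser.parse_args([])
use_ddp = False  # assumed flag; set True for a multi-GPU run

checkpoint_callback = ModelCheckpoint(dirpath="outputs", filename="step-{step}")

trainer = pl.Trainer.from_argparse_args(
    args,
    weights_summary=None,
    callbacks=[checkpoint_callback],                  # checkpointing now rides in the callbacks list
    plugins=[DDPPlugin()] if use_ddp else None,       # or CustomDDP() from finetune_rag.py above
    logger=True,
)
```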
requirements.txt

@@ -3,5 +3,5 @@ datasets >= 1.0.1
 psutil >= 5.7.0
 torch >= 1.4.0
 transformers
-pytorch-lightning==1.0.4
-GitPython
+pytorch-lightning==1.3.1
+GitPython
\ No newline at end of file