Unverified commit 9068fa6c authored by Shamane Siri, committed by GitHub

Rag end2end new (#17650)

* check

* update the RAG-end2end with new PL and RAY

* removed unwanted comments
parent 53496ac5
@@ -15,6 +15,10 @@ This code can be modified to experiment with other research on retrieval augmented
To start training, use the bash script (finetune_rag_ray_end2end.sh) in this folder. This script also includes descriptions on each command-line argument used.
# Latest Update
⚠️ Updated the rag-end2end-retriever to be compatible with PL==1.6.4 and RAY==1.13.0 (the latest versions as of 2022-June-11). A quick local version check is sketched below.
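A minimal, illustrative way to confirm that a local environment matches the versions this update targets; it is not part of the example scripts, and the version strings simply mirror requirements.txt.

```python
# Illustrative sketch: print the installed versions of the two packages this
# update targets (requirements.txt pins pytorch-lightning == 1.6.4, ray >= 1.13.0).
from importlib.metadata import version

for package, tested in [("pytorch-lightning", "1.6.4"), ("ray", "1.13.0")]:
    print(f"{package}: installed {version(package)}, tested against {tested}")
```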
# Note
⚠️ This project should be run with pytorch-lightning==1.3.1, which has a potential security vulnerability.
@@ -22,12 +26,14 @@ To start training, use the bash script (finetune_rag_ray_end2end.sh) in this folder
# Testing
The following two bash scripts can be used to quickly test the implementation.
1. sh ./test_run/test_rag_new_features.sh
- Tests the newly added functions (set_context_encoder and set_context_encoder_tokenizer) related to modeling rag.
- This is sufficient to check the model's ability to use the set functions correctly.
2. sh ./test_run/test_finetune.sh script
1. sh ./test_run/test_finetune.sh script
- Tests the full end-to-end fine-tuning ability with a dummy knowledge-base and dummy training dataset (check the test_dir directory).
- Users can replace the dummy dataset and knowledge-base with their own to do their own finetuning.
- Please read the comments in the test_finetune.sh file.
2. sh ./test_run/test_rag_new_features.sh
- Tests the newly added functions (set_context_encoder and set_context_encoder_tokenizer) related to modeling rag; a usage sketch follows this list.
- This is sufficient to check the model's ability to use the set functions correctly.
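For reference, here is a rough sketch of how these setter functions can be exercised directly. The method names and model identifiers below are assumptions based on the library's RAG API and may differ slightly from the README's shorthand; treat this as an illustration rather than a copy of the test script.

```python
# Rough sketch (assumed API names; see modeling_rag.py / retrieval_rag.py for the
# exact signatures): attach a trainable DPR context encoder and its tokenizer to
# a RAG model, as the end2end training loop does before fine-tuning the retriever.
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    RagRetriever,
    RagTokenForGeneration,
)

retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)

ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
    "facebook/dpr-ctx_encoder-multiset-base"
)

# The README refers to these as set_context_encoder / set_context_encoder_tokenizer.
model.set_context_encoder_for_training(ctx_encoder)
retriever.set_ctx_encoder_tokenizer(ctx_tokenizer)
```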
# Comparison of end2end RAG (including DPR finetuning) VS original-RAG
@@ -41,7 +41,7 @@ def get_checkpoint_callback(output_dir, metric):
monitor=f"val_{metric}",
mode="max",
save_top_k=1,
every_n_val_epochs=1, # works only with PL > 1.3
every_n_epochs=1, # works only with PL > 1.3
)
return checkpoint_callback
@@ -350,6 +350,7 @@ class GenerativeQAModule(BaseTransformer):
concat.save_to_disk(self.config.passages_path) # here we update the main passage file on the disk
logger.info("done updating the dataset")
# To Do (@Aaron) : Useful in the future dynamic memory implementation.
# if you load the index from the disk make sure to update the index file here, otherwise it is ok to update the index file from the worker.
# logger.info("then updating the index")
# shutil.copy(self.custom_config.temp_index, self.config.idex_path)
@@ -360,10 +361,7 @@ class GenerativeQAModule(BaseTransformer):
isEmUpdateBusy = False
isAddIndexBusy = False
self.trainer.accelerator_connector.accelerator.barrier(
"barrier"
) # wait until the index and kb get re-initialized.
self.trainer.strategy.barrier("barrier")
loss_tensors = self._step(batch)
@@ -724,7 +722,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
raise RuntimeError("Please install Ray to use the Ray distributed retriever.")
# Connect to an existing Ray cluster.
try:
ray.init(address=args.ray_address)
ray.init(address=args.ray_address, namespace="rag")
except (ConnectionError, ValueError):
logger.warning(
"Connection to Ray cluster failed. Make sure a Ray"
@@ -5,7 +5,6 @@ from pathlib import Path
from typing import Any, Dict
import pytorch_lightning as pl
from pytorch_lightning.plugins.training_type import DDPPlugin
from pytorch_lightning.utilities import rank_zero_info
from transformers import (
@@ -386,24 +385,22 @@ def generic_train(
train_params = {}
# TODO: remove with PyTorch 1.6 since pl uses native amp
if args.fp16:
train_params["precision"] = 16
train_params["amp_level"] = args.fp16_opt_level
if args.gpus > 1:
train_params["accelerator"] = "ddp"
train_params["accelerator"] = "auto"
train_params["strategy"] = "ddp"
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
# train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
train_params["profiler"] = None # extra_train_kwargs.get("profiler", None)
train_params["profiler"] = None
train_params["devices"] = "auto"
trainer = pl.Trainer.from_argparse_args(
args,
weights_summary=None,
callbacks=[logging_callback] + extra_callbacks + [InitCallback()] + [checkpoint_callback],
logger=logger,
plugins=[DDPPlugin(find_unused_parameters=True)], # this is needed in new pytorch-lightning new version
val_check_interval=1,
num_sanity_val_steps=2,
**train_params,
@@ -412,6 +409,6 @@ def generic_train(
if args.do_train:
trainer.fit(model)
# else:
# print("RAG modeling tests with new set functions successfuly executed!")
else:
print("RAG modeling tests with new set functions successfuly executed!")
return trainer
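For context, a minimal sketch of the PL >= 1.6 style of configuring distributed training that the change above moves to; the argument values are illustrative defaults, not the example's actual settings.

```python
# Illustrative sketch of the PL >= 1.6 Trainer arguments: `strategy` selects DDP,
# while `accelerator` and `devices` are left on "auto" (replacing accelerator="ddp").
import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="auto",  # pick GPU/CPU automatically
    devices="auto",      # use all visible devices
    strategy="ddp",      # distributed data parallel, formerly accelerator="ddp"
    max_epochs=1,
)
```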
faiss-cpu >= 1.7.0
datasets >= 1.6.2
psutil >= 5.7.0
torch >= 1.4.0
pytorch-lightning
faiss-cpu >= 1.7.2
datasets
psutil >= 5.9.1
torch >= 1.11.0
pytorch-lightning == 1.6.4
nvidia-ml-py3 == 7.352.0
ray >= 1.3.0
ray >= 1.13.0
\ No newline at end of file
@@ -44,11 +44,14 @@ python finetune_rag.py \
--num_retrieval_workers 4 \
--index_name custom \
--context_encoder_name facebook/dpr-ctx_encoder-multiset-base \
--index_gpus 1 \
--gpu_order [6,7,8,9,0,1,2,3,5,4] \
--index_gpus 2 \
--gpu_order [2,3,4,5,6,7,8,9,0,1] \
--indexing_freq 5
# Stop the Ray cluster.
ray stop
#CUDA_VISIBLE_DEVICES=2,3,4,5,6,7,8,9,0,1 sh ./test_run/test_finetune.sh
# Make sure --gpu_order matches the CUDA_VISIBLE_DEVICES order above (a quick check is sketched after this script).
\ No newline at end of file
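As noted in the comments above, the device order in CUDA_VISIBLE_DEVICES has to agree with --gpu_order. A rough, illustrative check is sketched below; the gpu_order list is copied by hand from the script rather than parsed from it.

```python
# Illustrative check: confirm CUDA_VISIBLE_DEVICES matches the --gpu_order flag above.
import os

gpu_order = [2, 3, 4, 5, 6, 7, 8, 9, 0, 1]  # mirror of --gpu_order in the script
visible = [int(i) for i in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if i.strip()]
if not visible:
    print("CUDA_VISIBLE_DEVICES is not set.")
elif visible != gpu_order:
    raise SystemExit(f"CUDA_VISIBLE_DEVICES {visible} does not match --gpu_order {gpu_order}")
else:
    print("GPU ordering looks consistent.")
```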