Unverified Commit 9068fa6c authored by Shamane Siri, committed by GitHub

Rag end2end new (#17650)

* check

* update the RAG-end2end with new PL and RAY

* removed unwanted comments
parent 53496ac5
@@ -15,6 +15,10 @@ This code can be modified to experiment with other research on retrieval augmente...
 To start training, use the bash script (finetune_rag_ray_end2end.sh) in this folder. This script also includes descriptions on each command-line argument used.
+# Latest Update
+⚠️ Updated the rag-end2end-retriever to be compatible with PL==1.6.4 and RAY==1.13.0 (the latest versions as of 2022-June-11).
 # Note
 ⚠️ This project should be run with pytorch-lightning==1.3.1 which has a potential security vulnerability
@@ -22,12 +26,14 @@ To start training, use the bash script (finetune_rag_ray_end2end.sh) in this fol...
 # Testing
 The following two bash scripts can be used to quickly test the implementation.
-1. sh ./test_run/test_rag_new_features.sh
-   - Tests the newly added functions (set_context_encoder and set_context_encoder_tokenizer) related to modeling rag.
-   - This is sufficient to check the model's ability to use the set functions correctly.
-2. sh ./test_run/test_finetune.sh script
+1. sh ./test_run/test_finetune.sh script
    - Tests the full end-to-end fine-tuning ability with a dummy knowledge-base and dummy training dataset (check test_dir directory).
    - Users can replace the dummy dataset and knowledge-base with their own to do their own finetuning.
+   - Please read the comments in the test_finetune.sh file.
+2. sh ./test_run/test_rag_new_features.sh
+   - Tests the newly added functions (set_context_encoder and set_context_encoder_tokenizer) related to modeling rag.
+   - This is sufficient to check the model's ability to use the set functions correctly.
 # Comparison of end2end RAG (including DPR finetuning) VS original-RAG
......
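As the README hunk above notes, the dummy knowledge-base used by test_finetune.sh can be swapped for a user-provided one. Below is a minimal sketch of building such a passages dataset; the column names ("title", "text", "embeddings") and the output paths are assumptions based on the usual RAG custom-index layout, not code from this commit.

```python
# Minimal sketch of building a custom knowledge base to replace the dummy one.
# Column names and file paths are illustrative assumptions, not taken from this repo.
import torch
from datasets import Dataset
from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast

ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-multiset-base").eval()

passages = Dataset.from_dict({
    "title": ["Aaron", "Zeus"],
    "text": ["Aaron is a prophet in the Hebrew Bible.", "Zeus is the sky god in Greek mythology."],
})

def embed(batch):
    # Encode (title, text) pairs with the DPR context encoder.
    inputs = ctx_tokenizer(batch["title"], batch["text"], truncation=True, padding="longest", return_tensors="pt")
    with torch.no_grad():
        return {"embeddings": ctx_encoder(**inputs).pooler_output.numpy()}

passages = passages.map(embed, batched=True, batch_size=16)
passages.add_faiss_index(column="embeddings")          # FAISS index over the DPR embeddings
passages.get_index("embeddings").save("my_kb.faiss")   # index file (path is illustrative)
passages.drop_index("embeddings")                      # drop before saving the dataset itself
passages.save_to_disk("my_knowledge_base")             # passages path for finetuning (illustrative)
```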
@@ -41,7 +41,7 @@ def get_checkpoint_callback(output_dir, metric):
         monitor=f"val_{metric}",
         mode="max",
         save_top_k=1,
-        every_n_val_epochs=1,  # works only with PL > 1.3
+        every_n_epochs=1,  # works only with PL > 1.3
     )
     return checkpoint_callback
......
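The change above tracks PyTorch Lightning's rename of ModelCheckpoint's every_n_val_epochs argument to every_n_epochs. A minimal sketch of the updated callback under PL >= 1.6; the "em" metric name and filename pattern are illustrative assumptions:

```python
# Minimal sketch (not the project's exact code): ModelCheckpoint under PL >= 1.6,
# where `every_n_val_epochs` has been renamed to `every_n_epochs`.
from pytorch_lightning.callbacks import ModelCheckpoint

def get_checkpoint_callback_sketch(output_dir, metric="em"):  # "em" is an assumed metric name
    return ModelCheckpoint(
        dirpath=output_dir,
        filename="{epoch}-{val_" + metric + ":.2f}",  # illustrative checkpoint filename pattern
        monitor=f"val_{metric}",
        mode="max",
        save_top_k=1,
        every_n_epochs=1,  # replaces the old every_n_val_epochs argument
    )
```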
@@ -350,6 +350,7 @@ class GenerativeQAModule(BaseTransformer):
             concat.save_to_disk(self.config.passages_path)  # here we update the main passage file on the disk
             logger.info("done updating the dataset")
+            # To Do (@Aaron) : Useful in the future dynamic memory implementation.
             # if you load the index from the disk make sure to update the index file here, otherwise it is ok to update the index file from the worker.
             # logger.info("then updating the index")
             # shutil.copy(self.custom_config.temp_index, self.config.idex_path)
@@ -360,10 +361,7 @@ class GenerativeQAModule(BaseTransformer):
             isEmUpdateBusy = False
             isAddIndexBusy = False
-            self.trainer.accelerator_connector.accelerator.barrier(
-                "barrier"
-            )  # waint untill the index and kb get re-initialized.
+            self.trainer.strategy.barrier("barrier")
         loss_tensors = self._step(batch)
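The replaced barrier call reflects the PL 1.6 API, where collectives live on trainer.strategy rather than on the accelerator reached through accelerator_connector. A minimal sketch of that call, wrapped in a hypothetical helper:

```python
# Minimal sketch: the PL >= 1.6 way to synchronize ranks from a LightningModule.
# `module` is any LightningModule attached to a Trainer; `name` is just a tag for the barrier.
import pytorch_lightning as pl

def wait_for_index_update(module: pl.LightningModule, name: str = "barrier") -> None:
    # Under PL 1.6 the Strategy object owns collective ops; the old
    # trainer.accelerator_connector.accelerator.barrier(...) path no longer exists.
    module.trainer.strategy.barrier(name)
```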
@@ -724,7 +722,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
         raise RuntimeError("Please install Ray to use the Ray distributed retriever.")
     # Connect to an existing Ray cluster.
     try:
-        ray.init(address=args.ray_address)
+        ray.init(address=args.ray_address, namespace="rag")
     except (ConnectionError, ValueError):
         logger.warning(
             "Connection to Ray cluster failed. Make sure a Ray"
......
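Passing namespace="rag" to ray.init matters because the distributed retriever is built from named Ray actors; an explicit namespace lets any process that joins the cluster look those actors up by name. A minimal sketch of that pattern with a made-up actor name (not the project's actual retrieval actor):

```python
# Minimal sketch of named actors in an explicit Ray namespace (Ray >= 1.13 API).
# "rag" mirrors the namespace used in the diff; "retrieval_worker_0" is a made-up name.
import ray

ray.init(namespace="rag")

@ray.remote
class RetrievalWorker:
    def ping(self) -> str:
        return "ok"

# Create the actor with a stable name inside the "rag" namespace...
worker = RetrievalWorker.options(name="retrieval_worker_0").remote()

# ...so any process that joins the same namespace can look it up again by name.
same_worker = ray.get_actor("retrieval_worker_0")
print(ray.get(same_worker.ping.remote()))
```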
@@ -5,7 +5,6 @@ from pathlib import Path
 from typing import Any, Dict
 import pytorch_lightning as pl
-from pytorch_lightning.plugins.training_type import DDPPlugin
 from pytorch_lightning.utilities import rank_zero_info
 from transformers import (
@@ -386,24 +385,22 @@ def generic_train(
     train_params = {}
-    # TODO: remove with PyTorch 1.6 since pl uses native amp
     if args.fp16:
         train_params["precision"] = 16
-        train_params["amp_level"] = args.fp16_opt_level
     if args.gpus > 1:
-        train_params["accelerator"] = "ddp"
+        train_params["accelerator"] = "auto"
+        train_params["strategy"] = "ddp"
     train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
-    # train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
-    train_params["profiler"] = None  # extra_train_kwargs.get("profiler", None)
+    train_params["profiler"] = None
+    train_params["devices"] = "auto"
     trainer = pl.Trainer.from_argparse_args(
         args,
         weights_summary=None,
         callbacks=[logging_callback] + extra_callbacks + [InitCallback()] + [checkpoint_callback],
         logger=logger,
-        plugins=[DDPPlugin(find_unused_parameters=True)],  # this is needed in new pytorch-lightning new version
         val_check_interval=1,
         num_sanity_val_steps=2,
         **train_params,
@@ -412,6 +409,6 @@ def generic_train(
     if args.do_train:
         trainer.fit(model)
-    # else:
-    #     print("RAG modeling tests with new set functions successfuly executed!")
+    else:
+        print("RAG modeling tests with new set functions successfuly executed!")
     return trainer
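For readers comparing the two Trainer configurations above: under PL 1.6, device type, device count, and the distributed strategy are separate arguments, and the removed DDPPlugin is superseded by the strategy argument (the plain "ddp" string in the diff, or a DDPStrategy object when find_unused_parameters is still needed). A minimal, self-contained sketch with illustrative values, not the project's full generic_train:

```python
# Minimal sketch of the PL >= 1.6 Trainer arguments used in the diff.
# `fp16` and `gpus` are illustrative stand-ins for the script's argparse values.
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

def build_trainer(fp16: bool = False, gpus: int = 2) -> pl.Trainer:
    train_params = {}
    if fp16:
        train_params["precision"] = 16  # native AMP; the old amp_level argument is gone
    if gpus > 1:
        # DDPStrategy(find_unused_parameters=True) replaces DDPPlugin(find_unused_parameters=True)
        train_params["strategy"] = DDPStrategy(find_unused_parameters=True)
    return pl.Trainer(
        accelerator="auto",  # pick GPU/CPU automatically
        devices="auto",      # use whatever devices are visible (e.g. via CUDA_VISIBLE_DEVICES)
        max_epochs=1,
        **train_params,
    )
```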
-faiss-cpu >= 1.7.0
-datasets >= 1.6.2
-psutil >= 5.7.0
-torch >= 1.4.0
-pytorch-lightning
+faiss-cpu >= 1.7.2
+datasets
+psutil >= 5.9.1
+torch >= 1.11.0
+pytorch-lightning == 1.6.4
 nvidia-ml-py3 == 7.352.0
-ray >= 1.3.0
+ray >= 1.13.0
\ No newline at end of file
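A throwaway check (not part of the project) to confirm a local environment roughly matches the updated pins:

```python
# Quick sanity check of installed versions against the new requirements; package
# names follow the updated requirements file above.
from importlib.metadata import version

pins = {
    "pytorch-lightning": "1.6.4",  # exact pin in the new requirements
    "ray": "1.13.0",               # minimum version
    "faiss-cpu": "1.7.2",          # minimum version
}
for pkg, wanted in pins.items():
    print(f"{pkg}: installed {version(pkg)} (requirements expect {wanted})")
```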
@@ -44,11 +44,14 @@ python finetune_rag.py \
     --num_retrieval_workers 4 \
     --index_name custom \
     --context_encoder_name facebook/dpr-ctx_encoder-multiset-base \
-    --index_gpus 1 \
-    --gpu_order [6,7,8,9,0,1,2,3,5,4] \
+    --index_gpus 2 \
+    --gpu_order [2,3,4,5,6,7,8,9,0,1] \
     --indexing_freq 5
 # Stop the Ray cluster.
 ray stop
+#CUDA_VISIBLE_DEVICES=2,3,4,5,6,7,8,9,0,1 sh ./test_run/test_finetune.sh
+#Make sure --gpu_order is same.
\ No newline at end of file
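The trailing comments pair --gpu_order with a matching CUDA_VISIBLE_DEVICES export. A hypothetical pre-flight helper (not in the repository) that verifies the two orderings agree before launching:

```python
# Hypothetical pre-flight check (not in the repo): verify that the --gpu_order
# list passed to finetune_rag.py matches the CUDA_VISIBLE_DEVICES ordering.
import os

def gpu_order_matches(gpu_order: str) -> bool:
    visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    ordered = gpu_order.strip("[]").split(",")  # e.g. "[2,3,4,5]" -> ["2", "3", "4", "5"]
    return [g.strip() for g in ordered] == [v.strip() for v in visible.split(",") if v]

if __name__ == "__main__":
    print(gpu_order_matches("[2,3,4,5,6,7,8,9,0,1]"))
```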