Unverified Commit a4b21cdd authored by Amog Kamsetty's avatar Amog Kamsetty Committed by GitHub
Browse files

[RAG] Add Ray implementation for distributed retrieval (#9197)



* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* uncomment

* uncomment

* wip

* updates

* add docstring

* updates

* fix arg

* fixes

* add unit tests

* update readme

* update readme

* update finetune script

* update test

* add test

* add ray to test dependencies

* separate ray and ray tune

* formatting

* shutdown ray at end of test

* fix tests

* formatting

* formatting

* even more formatting

* address comments

* formatting

* add files

* Update examples/research_projects/rag/test_distributed_retriever.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* address comments

* addressing comments
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-21-208.us-west-2.compute.internal>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent f38c4ad3
......@@ -19,3 +19,4 @@ pytest
conllu
sentencepiece != 0.1.92
protobuf
ray
......@@ -50,6 +50,44 @@ python examples/rag/consolidate_rag_checkpoint.py \
```
You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune_rag.py` script.
## Document Retrieval
When running distributed fine-tuning, each training worker needs to retrieve contextual documents
for its input by querying a index loaded into memory. RAG provides two implementations for document retrieval,
one with [`torch.distributed`](https://pytorch.org/docs/stable/distributed.html) communication package and the other
with [`Ray`](https://docs.ray.io/en/master/).
This option can be configured with the `--distributed_retriever` flag which can either be set to `pytorch` or `ray`.
By default this flag is set to `pytorch`.
For the Pytorch implementation, only training worker 0 loads the index into CPU memory, and a gather/scatter pattern is used
to collect the inputs from the other training workers and send back the corresponding document embeddings.
For the Ray implementation, the index is loaded in *separate* process(es). The training workers randomly select which
retriever worker to query. To use Ray for distributed retrieval, you have to set the `--distributed_retriever` arg to `ray`.
To configure the number of retrieval workers (the number of processes that load the index), you can set the `num_retrieval_workers` flag.
Also make sure to start the Ray cluster before running fine-tuning.
```bash
# Start a single-node Ray cluster.
ray start --head
python examples/rag/finetune_rag.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \
--model_type rag_sequence \
--fp16 \
--gpus 8
--distributed_retriever ray \
--num_retrieval_workers 4
# Stop the ray cluster once fine-tuning has finished.
ray stop
```
Using Ray can lead to retrieval speedups on multi-GPU settings since multiple processes load the index rather than
just the rank 0 training worker. Using Ray also allows you to load the index on GPU since the index is loaded on a separate
processes than the model, while with pytorch distributed retrieval, both are loaded in the same process potentially leading to GPU OOM.
# Evaluation
Our evaluation script enables two modes of evaluation (controlled by the `eval_mode` argument): `e2e` - end2end evaluation, returns EM (exact match) and F1 scores calculated for the downstream task and `retrieval` - which returns precision@k of the documents retrieved for provided inputs.
......
......@@ -9,6 +9,7 @@ from transformers.file_utils import is_apex_available
from transformers.testing_utils import (
TestCasePlus,
execute_subprocess_async,
require_ray,
require_torch_gpu,
require_torch_multi_gpu,
)
......@@ -29,7 +30,7 @@ class RagFinetuneExampleTests(TestCasePlus):
with open(os.path.join(data_dir, f"{split}.{field}"), "w") as f:
f.write(content)
def _run_finetune(self, gpus: int):
def _run_finetune(self, gpus: int, distributed_retriever: str = "pytorch"):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
......@@ -66,6 +67,7 @@ class RagFinetuneExampleTests(TestCasePlus):
--gradient_accumulation_steps 1 \
--distributed-port 8787 \
--use_dummy_dataset 1 \
--distributed_retriever {distributed_retriever} \
""".split()
if gpus > 0:
......@@ -94,3 +96,15 @@ class RagFinetuneExampleTests(TestCasePlus):
def test_finetune_multigpu(self):
result = self._run_finetune(gpus=2)
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
@require_torch_gpu
@require_ray
def test_finetune_gpu_ray_retrieval(self):
result = self._run_finetune(gpus=1, distributed_retriever="ray")
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
@require_torch_multi_gpu
@require_ray
def test_finetune_multigpu_ray_retrieval(self):
result = self._run_finetune(gpus=1, distributed_retriever="ray")
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
......@@ -31,14 +31,13 @@ class RagPyTorchDistributedRetriever(RagRetriever):
If specified, use this index instead of the one built using the configuration
"""
_init_retrieval = False
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
super().__init__(
config,
question_encoder_tokenizer=question_encoder_tokenizer,
generator_tokenizer=generator_tokenizer,
index=index,
init_retrieval=False,
)
self.process_group = None
......
import logging
import random
import ray
from transformers import RagConfig, RagRetriever, RagTokenizer
from transformers.file_utils import requires_datasets, requires_faiss
from transformers.models.rag.retrieval_rag import CustomHFIndex
logger = logging.getLogger(__name__)
class RayRetriever:
def __init__(self):
self.initialized = False
def create_rag_retriever(self, config, question_encoder_tokenizer, generator_tokenizer, index):
if not self.initialized:
self.retriever = RagRetriever(
config,
question_encoder_tokenizer=question_encoder_tokenizer,
generator_tokenizer=generator_tokenizer,
index=index,
init_retrieval=False,
)
self.initialized = True
def init_retrieval(self):
self.retriever.index.init_index()
def retrieve(self, question_hidden_states, n_docs):
doc_ids, retrieved_doc_embeds = self.retriever._main_retrieve(question_hidden_states, n_docs)
return doc_ids, retrieved_doc_embeds
class RagRayDistributedRetriever(RagRetriever):
"""
A distributed retriever built on top of the ``Ray`` API, a library
for building distributed applications (https://docs.ray.io/en/master/).
package. During training, all training workers initialize their own
instance of a `RagRayDistributedRetriever`, and each instance of
this distributed retriever shares a common set of Retrieval Ray
Actors (https://docs.ray.io/en/master/walkthrough.html#remote
-classes-actors) that load the index on separate processes. Ray
handles the communication between the `RagRayDistributedRetriever`
instances and the remote Ray actors. If training is done in a
non-distributed setup, the index will simply be loaded in the same
process as the training worker and Ray will not be used.
Args:
config (:class:`~transformers.RagConfig`):
The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
question_encoder_tokenizer (:class:`~transformers.PretrainedTokenizer`):
The tokenizer that was used to tokenize the question.
It is used to decode the question and then use the generator_tokenizer.
generator_tokenizer (:class:`~transformers.PretrainedTokenizer`):
The tokenizer used for the generator part of the RagModel.
retrieval_workers (:obj:`List[ray.ActorClass(RayRetriever)]`): A list of already initialized `RayRetriever` actors.
These actor classes run on remote processes and are responsible for performing the index lookup.
index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
If specified, use this index instead of the one built using the configuration
"""
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, retrieval_workers, index=None):
if index is not None and index.is_initialized() and len(retrieval_workers) > 0:
raise ValueError(
"When using Ray for distributed fine-tuning, "
"you'll need to provide the paths instead, "
"as the dataset and the index are loaded "
"separately. More info in examples/rag/use_own_knowledge_dataset.py "
)
super().__init__(
config,
question_encoder_tokenizer=question_encoder_tokenizer,
generator_tokenizer=generator_tokenizer,
index=index,
init_retrieval=False,
)
self.retrieval_workers = retrieval_workers
if len(self.retrieval_workers) > 0:
ray.get(
[
worker.create_rag_retriever.remote(config, question_encoder_tokenizer, generator_tokenizer, index)
for worker in self.retrieval_workers
]
)
def init_retrieval(self):
"""
Retriever initialization function, needs to be called from the
training process. This function triggers retrieval initialization
for all retrieval actors if using distributed setting, or loads
index into current process if training is not distributed.
"""
logger.info("initializing retrieval")
if len(self.retrieval_workers) > 0:
ray.get([worker.init_retrieval.remote() for worker in self.retrieval_workers])
else:
# Non-distributed training. Load index into this same process.
self.index.init_index()
def retrieve(self, question_hidden_states, n_docs):
"""
Retrieves documents for specified ``question_hidden_states``. If
running training with multiple workers, a random retrieval actor is
selected to perform the index lookup and return the result.
Args:
question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
A batch of query vectors to retrieve with.
n_docs (:obj:`int`):
The number of docs retrieved per query.
Output:
retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`
The retrieval embeddings of the retrieved docs per query.
doc_ids (:obj:`np.ndarray` of shape :obj:`batch_size, n_docs`)
The ids of the documents in the index
doc_dicts (:obj:`List[dict]`):
The retrieved_doc_embeds examples per query.
"""
if len(self.retrieval_workers) > 0:
# Select a random retrieval actor.
random_worker = self.retrieval_workers[random.randint(0, len(self.retrieval_workers) - 1)]
doc_ids, retrieved_doc_embeds = ray.get(random_worker.retrieve.remote(question_hidden_states, n_docs))
else:
doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)
@classmethod
def get_tokenizers(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
return super(RagRayDistributedRetriever, cls).get_tokenizers(retriever_name_or_path, indexed_dataset, **kwargs)
@classmethod
def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
requires_datasets(cls)
requires_faiss(cls)
config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
question_encoder_tokenizer = rag_tokenizer.question_encoder
generator_tokenizer = rag_tokenizer.generator
if indexed_dataset is not None:
config.index_name = "custom"
index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)
else:
index = cls._build_index(config)
return cls(
config,
question_encoder_tokenizer=question_encoder_tokenizer,
generator_tokenizer=generator_tokenizer,
retrieval_workers=actor_handles,
index=index,
)
......@@ -29,6 +29,12 @@ from transformers import (
T5ForConditionalGeneration,
)
from transformers import logging as transformers_logging
from transformers.integrations import is_ray_available
if is_ray_available():
import ray
from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever
from callbacks_rag import ( # noqa: E402 # isort:skipq
......@@ -36,7 +42,8 @@ from callbacks_rag import ( # noqa: E402 # isort:skipq
get_early_stopping_callback,
Seq2SeqLoggingCallback,
)
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
from distributed_pytorch_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
from utils_rag import ( # noqa: E402 # isort:skip
calculate_exact_match,
flatten_list,
......@@ -88,7 +95,12 @@ class CustomAccel(DDPAccelerator):
os.environ["MASTER_PORT"] = str(self.distributed_port)
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
if module.is_rag_model:
module.model.rag.retriever.init_retrieval(self.distributed_port)
if module.distributed_retriever == "pytorch":
module.model.rag.retriever.init_retrieval(self.distributed_port)
elif module.distributed_retriever == "ray" and global_rank == 0:
# For the Ray retriever, only initialize it once when global
# rank is 0.
module.model.rag.retriever.init_retrieval()
class GenerativeQAModule(BaseTransformer):
......@@ -127,7 +139,13 @@ class GenerativeQAModule(BaseTransformer):
config.generator.prefix = hparams.prefix
config.label_smoothing = hparams.label_smoothing
hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
if hparams.distributed_retriever == "pytorch":
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
elif hparams.distributed_retriever == "ray":
# The Ray retriever needs the handles to the retriever actors.
retriever = RagRayDistributedRetriever.from_pretrained(
hparams.model_name_or_path, hparams.actor_handles, config=config
)
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
prefix = config.question_encoder.prefix
else:
......@@ -180,7 +198,12 @@ class GenerativeQAModule(BaseTransformer):
# For single GPU training, init_ddp_connection is not called.
# So we need to initialize the retrievers here.
if hparams.gpus <= 1:
self.model.retriever.init_retrieval(self.distributed_port)
if hparams.distributed_retriever == "ray":
self.model.retriever.init_retrieval()
elif hparams.distributed_retriever == "pytorch":
self.model.retriever.init_retrieval(self.distributed_port)
self.distributed_retriever = hparams.distributed_retriever
def forward(self, input_ids, **kwargs):
return self.model(input_ids, **kwargs)
......@@ -420,6 +443,7 @@ class GenerativeQAModule(BaseTransformer):
type=str,
help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
)
return parser
@staticmethod
......@@ -442,12 +466,58 @@ class GenerativeQAModule(BaseTransformer):
default=None,
help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
)
parser.add_argument(
"--distributed_retriever",
choices=["ray", "pytorch"],
type=str,
default="pytorch",
help="What implementation to use for distributed retriever? If "
"pytorch is selected, the index is loaded on training "
"worker 0, and torch.distributed is used to handle "
"communication between training worker 0, and the other "
"training workers. If ray is selected, the Ray library is "
"used to create load the index on separate processes, "
"and Ray handles the communication between the training "
"workers and the retrieval actors.",
)
parser.add_argument(
"--use_dummy_dataset",
type=bool,
default=False,
help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
)
parser.add_argument(
"--num_retrieval_workers",
type=int,
default=1,
help="The number of retrieval actors to use when Ray is selected"
"for the distributed retriever. Has no effect when "
"distributed_retriever is set to pytorch.",
)
@staticmethod
def add_ray_specific_args(parser):
parser.add_argument(
"--num_retrieval_workers",
type=int,
default=1,
help="The number of retrieval actors to use when Ray is selected"
"for the distributed retriever. Has no effect when "
"distributed_retriever is set to pytorch.",
)
# Ray cluster address.
parser.add_argument(
"--ray-address",
default="auto",
type=str,
help="The address of the Ray cluster to connect to. If not "
"specified, Ray will attempt to automatically detect the "
"cluster. Has no effect if pytorch is used as the distributed "
"retriever.",
)
return parser
......@@ -461,6 +531,46 @@ def main(args=None, model=None) -> GenerativeQAModule:
args = args or parser.parse_args()
Path(args.output_dir).mkdir(exist_ok=True)
named_actors = []
if args.distributed_retriever == "ray" and args.gpus > 1:
if not is_ray_available():
raise RuntimeError("Please install Ray to use the Ray " "distributed retriever.")
# Connect to an existing Ray cluster.
try:
ray.init(address=args.ray_address)
except (ConnectionError, ValueError):
logger.warning(
"Connection to Ray cluster failed. Make sure a Ray"
"cluster is running by either using Ray's cluster "
"launcher (`ray up`) or by manually starting Ray on "
"each node via `ray start --head` for the head node "
"and `ray start --address='<ip address>:6379'` for "
"additional nodes. See "
"https://docs.ray.io/en/master/cluster/index.html "
"for more info."
)
raise
# Create Ray actors only for rank 0.
if ("LOCAL_RANK" not in os.environ or os.environ["LOCAL_RANK"] == 0) and (
"NODE_RANK" not in os.environ or os.environ["NODE_RANK"] == 0
):
remote_cls = ray.remote(RayRetriever)
named_actors = [
remote_cls.options(name="retrieval_worker_{}".format(i)).remote()
for i in range(args.num_retrieval_workers)
]
else:
logger.info(
"Getting named actors for NODE_RANK {}, LOCAL_RANK {}".format(
os.environ["NODE_RANK"], os.environ["LOCAL_RANK"]
)
)
named_actors = [ray.get_actor("retrieval_worker_{}".format(i)) for i in range(args.num_retrieval_workers)]
args.actor_handles = named_actors
assert args.actor_handles == named_actors
if model is None:
model: GenerativeQAModule = GenerativeQAModule(args)
......@@ -471,17 +581,17 @@ def main(args=None, model=None) -> GenerativeQAModule:
or str(args.output_dir).startswith("/tmp")
or str(args.output_dir).startswith("/var")
):
logger = True # don't pollute wandb logs unnecessarily
training_logger = True # don't pollute wandb logs unnecessarily
elif args.logger_name == "wandb":
from pytorch_lightning.loggers import WandbLogger
project = os.environ.get("WANDB_PROJECT", dataset)
logger = WandbLogger(name=model.output_dir.name, project=project)
training_logger = WandbLogger(name=model.output_dir.name, project=project)
elif args.logger_name == "wandb_shared":
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
training_logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
es_callback = (
get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
......@@ -495,8 +605,9 @@ def main(args=None, model=None) -> GenerativeQAModule:
logging_callback=Seq2SeqLoggingCallback(),
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
early_stopping_callback=es_callback,
logger=logger,
logger=training_logger,
accelerator=CustomAccel() if args.gpus > 1 else None,
profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
)
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
......@@ -509,4 +620,19 @@ def main(args=None, model=None) -> GenerativeQAModule:
if __name__ == "__main__":
main()
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
parser = GenerativeQAModule.add_retriever_specific_args(parser)
parser = GenerativeQAModule.add_ray_specific_args(parser)
# Pytorch Lightning Profiler
parser.add_argument(
"--profile",
action="store_true",
help="If True, use pytorch_lightning.profiler.AdvancedProfiler to profile the Trainer.",
)
args = parser.parse_args()
main(args)
......@@ -2,7 +2,7 @@
export PYTHONPATH="../":"${PYTHONPATH}"
# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
# run ./examples/rag/finetune.sh --help to see all the possible options
# run ./examples/rag/finetune_rag.sh --help to see all the possible options
python examples/rag/finetune_rag.py \
--data_dir $DATA_DIR \
......@@ -11,10 +11,10 @@ python examples/rag/finetune_rag.py \
--model_type rag_sequence \
--fp16 \
--gpus 8 \
--profile \
--do_train \
--do_predict \
--n_val -1 \
--val_check_interval 0.25 \
--train_batch_size 8 \
--eval_batch_size 1 \
--max_source_length 128 \
......@@ -31,4 +31,4 @@ python examples/rag/finetune_rag.py \
--learning_rate 3e-05 \
--num_train_epochs 100 \
--warmup_steps 500 \
--gradient_accumulation_steps 1
\ No newline at end of file
--gradient_accumulation_steps 1 \
# Sample script to finetune RAG using Ray for distributed retrieval.
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"
# Start a single-node Ray cluster.
ray start --head
# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
# run ./examples/rag/finetune_rag_ray.sh --help to see all the possible options
python examples/rag/finetune_rag.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \
--model_type rag_sequence \
--fp16 \
--gpus 8 \
--profile \
--do_train \
--do_predict \
--n_val -1 \
--train_batch_size 8 \
--eval_batch_size 1 \
--max_source_length 128 \
--max_target_length 25 \
--val_max_target_length 25 \
--test_max_target_length 25 \
--label_smoothing 0.1 \
--dropout 0.1 \
--attention_dropout 0.1 \
--weight_decay 0.001 \
--adam_epsilon 1e-08 \
--max_grad_norm 0.1 \
--lr_scheduler polynomial \
--learning_rate 3e-05 \
--num_train_epochs 100 \
--warmup_steps 500 \
--gradient_accumulation_steps 1 \
--distributed_retriever ray \
--num_retrieval_workers 4
# Stop the Ray cluster.
ray stop
......@@ -13,15 +13,27 @@ from datasets import Dataset
import faiss
from transformers import BartConfig, BartTokenizer, DPRConfig, DPRQuestionEncoderTokenizer, RagConfig
from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available
from transformers.integrations import is_ray_available
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.models.rag.retrieval_rag import CustomHFIndex
from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
from transformers.testing_utils import require_torch_non_multi_gpu_but_fix_me
from transformers.testing_utils import require_ray, require_torch_non_multi_gpu_but_fix_me
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
if is_torch_available():
from distributed_pytorch_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
else:
RagPyTorchDistributedRetriever = None
if is_ray_available():
import ray # noqa: E402 # isort:skip
from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever # noqa: E402 # isort:skip
else:
ray = None
RagRayDistributedRetriever = None
RayRetriever = None
def require_distributed_retrieval(test_case):
......@@ -32,8 +44,8 @@ def require_distributed_retrieval(test_case):
These tests are skipped when respective libraries are not installed.
"""
if not (is_torch_available() and is_datasets_available() and is_faiss_available() and is_psutil_available()):
test_case = unittest.skip("test requires PyTorch, Datasets, Faiss, psutil")(test_case)
if not (is_datasets_available() and is_faiss_available() and is_psutil_available()):
test_case = unittest.skip("test requires Datasets, Faiss, psutil")(test_case)
return test_case
......@@ -144,7 +156,31 @@ class RagRetrieverTest(TestCase):
retriever.init_retrieval(port)
return retriever
def get_dummy_custom_hf_index_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
def get_dummy_ray_distributed_retriever(self, init_retrieval: bool) -> RagRayDistributedRetriever:
# Have to run in local mode because sys.path modifications at top of
# file are not propogated to remote workers.
# https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
ray.init(local_mode=True)
config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size,
question_encoder=DPRConfig().to_dict(),
generator=BartConfig().to_dict(),
)
remote_cls = ray.remote(RayRetriever)
workers = [remote_cls.remote() for _ in range(1)]
with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
mock_load_dataset.return_value = self.get_dummy_dataset()
retriever = RagRayDistributedRetriever(
config,
question_encoder_tokenizer=self.get_dpr_tokenizer(),
generator_tokenizer=self.get_bart_tokenizer(),
retrieval_workers=workers,
)
if init_retrieval:
retriever.init_retrieval()
return retriever
def get_dummy_custom_hf_index_pytorch_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
dataset = self.get_dummy_dataset()
config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size,
......@@ -175,13 +211,51 @@ class RagRetrieverTest(TestCase):
retriever.init_retrieval(port)
return retriever
@require_torch_non_multi_gpu_but_fix_me
def test_pytorch_distributed_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
def get_dummy_custom_hf_index_ray_retriever(self, init_retrieval: bool, from_disk: bool):
# Have to run in local mode because sys.path modifications at top of
# file are not propogated to remote workers.
# https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
ray.init(local_mode=True)
dataset = self.get_dummy_dataset()
config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size,
question_encoder=DPRConfig().to_dict(),
generator=BartConfig().to_dict(),
index_name="custom",
)
remote_cls = ray.remote(RayRetriever)
workers = [remote_cls.remote() for _ in range(1)]
if from_disk:
config.passages_path = os.path.join(self.tmpdirname, "dataset")
config.index_path = os.path.join(self.tmpdirname, "index.faiss")
dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
dataset.drop_index("embeddings")
dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
del dataset
retriever = RagRayDistributedRetriever(
config,
question_encoder_tokenizer=self.get_dpr_tokenizer(),
generator_tokenizer=self.get_bart_tokenizer(),
retrieval_workers=workers,
index=CustomHFIndex.load_from_disk(
vector_size=config.retrieval_vector_size,
dataset_path=config.passages_path,
index_path=config.index_path,
),
)
else:
retriever = RagRayDistributedRetriever(
config,
question_encoder_tokenizer=self.get_dpr_tokenizer(),
generator_tokenizer=self.get_bart_tokenizer(),
retrieval_workers=workers,
index=CustomHFIndex(config.retrieval_vector_size, dataset),
)
if init_retrieval:
retriever.init_retrieval()
return retriever
def distributed_retriever_check(self, retriever: RagRetriever, hidden_states: np.array, n_docs: int) -> None:
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
......@@ -192,33 +266,76 @@ class RagRetrieverTest(TestCase):
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
@require_torch_non_multi_gpu_but_fix_me
def test_custom_hf_index_retriever_retrieve(self):
def test_pytorch_distributed_retriever_retrieve(self):
n_docs = 1
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
self.distributed_retriever_check(
self.get_dummy_pytorch_distributed_retriever(init_retrieval=True), hidden_states, n_docs
)
@require_torch_non_multi_gpu_but_fix_me
def test_custom_hf_index_pytorch_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=False)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
self.distributed_retriever_check(
self.get_dummy_custom_hf_index_pytorch_retriever(init_retrieval=True, from_disk=False),
hidden_states,
n_docs,
)
@require_torch_non_multi_gpu_but_fix_me
def test_custom_pytorch_distributed_retriever_retrieve_from_disk(self):
n_docs = 1
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=True)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
self.distributed_retriever_check(
self.get_dummy_custom_hf_index_pytorch_retriever(init_retrieval=True, from_disk=True),
hidden_states,
n_docs,
)
@require_ray
def test_ray_distributed_retriever_retrieve(self):
n_docs = 1
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
self.distributed_retriever_check(
self.get_dummy_ray_distributed_retriever(init_retrieval=True), hidden_states, n_docs
)
ray.shutdown()
@require_ray
def test_custom_hf_index_ray_retriever_retrieve(self):
n_docs = 1
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
with self.assertRaises(ValueError):
self.distributed_retriever_check(
self.get_dummy_custom_hf_index_ray_retriever(init_retrieval=True, from_disk=False),
hidden_states,
n_docs,
)
ray.shutdown()
@require_ray
def test_custom_ray_distributed_retriever_retrieve_from_disk(self):
n_docs = 1
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
self.distributed_retriever_check(
self.get_dummy_custom_hf_index_ray_retriever(init_retrieval=True, from_disk=True), hidden_states, n_docs
)
ray.shutdown()
......@@ -219,6 +219,7 @@ from .integrations import ( # isort:skip
is_comet_available,
is_optuna_available,
is_ray_available,
is_ray_tune_available,
is_tensorboard_available,
is_wandb_available,
)
......
......@@ -63,8 +63,16 @@ try:
import ray # noqa: F401
_has_ray = True
try:
# Ray Tune has additional dependencies.
from ray import tune # noqa: F401
_has_ray_tune = True
except (ImportError):
_has_ray_tune = False
except (ImportError):
_has_ray = False
_has_ray_tune = False
try:
from torch.utils.tensorboard import SummaryWriter # noqa: F401
......@@ -127,6 +135,10 @@ def is_ray_available():
return _has_ray
def is_ray_tune_available():
return _has_ray_tune
def is_azureml_available():
return _has_azureml
......@@ -143,7 +155,7 @@ def hp_params(trial):
if is_optuna_available():
if isinstance(trial, optuna.Trial):
return trial.params
if is_ray_available():
if is_ray_tune_available():
if isinstance(trial, dict):
return trial
......@@ -153,7 +165,7 @@ def hp_params(trial):
def default_hp_search_backend():
if is_optuna_available():
return "optuna"
elif is_ray_available():
elif is_ray_tune_available():
return "ray"
......
......@@ -370,9 +370,8 @@ class RagRetriever:
"""
_init_retrieval = True
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None, init_retrieval=True):
self._init_retrieval = init_retrieval
requires_datasets(self)
requires_faiss(self)
super().__init__()
......
......@@ -37,7 +37,7 @@ from .integrations import ( # isort: split
is_fairscale_available,
is_mlflow_available,
is_optuna_available,
is_ray_available,
is_ray_tune_available,
is_tensorboard_available,
is_wandb_available,
run_hp_search_optuna,
......@@ -145,7 +145,7 @@ if is_mlflow_available():
if is_optuna_available():
import optuna
if is_ray_available():
if is_ray_tune_available():
from ray import tune
if is_azureml_available():
......@@ -1062,7 +1062,7 @@ class Trainer:
backend = HPSearchBackend(backend)
if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
if backend == HPSearchBackend.RAY and not is_ray_available():
if backend == HPSearchBackend.RAY and not is_ray_tune_available():
raise RuntimeError(
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
)
......
......@@ -132,9 +132,9 @@ def default_hp_space_optuna(trial) -> Dict[str, float]:
def default_hp_space_ray(trial) -> Dict[str, float]:
from .integrations import is_ray_available
from .integrations import is_ray_tune_available
assert is_ray_available(), "This function needs ray installed: `pip install ray[tune]`"
assert is_ray_tune_available(), "This function needs ray installed: `pip " "install ray[tune]`"
from ray import tune
return {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment