Unverified Commit d944966b authored by Yoshitomo Matsubara's avatar Yoshitomo Matsubara Committed by GitHub

Fix typos in README and bugs in RAG example code for end-to-end evaluation and finetuning (#9355)

* fix a bug in eval_batch_retrieval

* should return parser, as the other staticmethods do

* remove duplicate argument

* these kwargs are no longer accepted (they cause a TypeError in self.generator.generate of modeling_rag.py)

* fixed file paths in README

* moved an arg to add_ray_specific_args
parent c4fd609a
......@@ -23,10 +23,10 @@ test.source
test.target
```
A sample finetuning command (run ` ./examples/rag/finetune_rag.py --help` to list all available options):
A sample finetuning command (run `./examples/research_projects/rag/finetune_rag.py --help` to list all available options):
```bash
python examples/rag/finetune_rag.py \
python examples/research_projects/rag/finetune_rag.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \
......@@ -42,7 +42,7 @@ The `base` models initialize the question encoder with [`facebook/dpr-question_e
If you would like to initialize finetuning with a base model using different question encoder and generator architectures, you can build it with a consolidation script, e.g.:
```
python examples/rag/consolidate_rag_checkpoint.py \
python examples/research_projects/rag/consolidate_rag_checkpoint.py \
--model_type rag_sequence \
--generator_name_or_path facebook/bart-large-cnn \
--question_encoder_name_or_path facebook/dpr-question_encoder-single-nq-base \
......@@ -71,7 +71,7 @@ Also make sure to start the Ray cluster before running fine-tuning.
# Start a single-node Ray cluster.
ray start --head
python examples/rag/finetune_rag.py \
python examples/research_projects/rag/finetune_rag.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \
......@@ -113,14 +113,14 @@ We demonstrate how to evaluate retrieval against DPR evaluation data. You can do
2. Parse the unzipped file using the `parse_dpr_relevance_data.py` script
```bash
mkdir output # or wherever you want to save this
python examples/rag/parse_dpr_relevance_data.py \
python examples/research_projects/rag/parse_dpr_relevance_data.py \
--src_path biencoder-nq-dev.json \
--evaluation_set output/biencoder-nq-dev.questions \
--gold_data_path output/biencoder-nq-dev.pages
```
3. Run evaluation:
```bash
python examples/rag/eval_rag.py \
python examples/research_projects/rag/eval_rag.py \
--model_name_or_path facebook/rag-sequence-nq \
--model_type rag_sequence \
--evaluation_set output/biencoder-nq-dev.questions \
......@@ -131,7 +131,7 @@ We demonstrate how to evaluate retrieval against DPR evaluation data. You can do
```
```bash
# EXPLANATION
python examples/rag/eval_rag.py \
python examples/research_projects/rag/eval_rag.py \
--model_name_or_path facebook/rag-sequence-nq \ # model name or path of the model we're evaluating
--model_type rag_sequence \ # RAG model type (rag_token or rag_sequence)
--evaluation_set output/biencoder-nq-dev.questions \ # an input dataset for evaluation
......@@ -159,7 +159,7 @@ Add `--recalculate` parameter to force the script to perform inference from scra
An example e2e evaluation run could look as follows:
```bash
python examples/rag/eval_rag.py \
python examples/research_projects/rag/eval_rag.py \
--model_name_or_path facebook/rag-sequence-nq \
--model_type rag_sequence \
--evaluation_set path/to/test.source \
......@@ -179,14 +179,14 @@ With `use_custom_knowledge_dataset.py` you can build your own knowledge source,
For instance, if documents are serialized as tab-separated csv files with the columns "title" and "text", one can use `use_own_knowledge_dataset.py` as follows:
```bash
python examples/rag/use_own_knowledge_dataset.py \
python examples/research_projects/rag/use_own_knowledge_dataset.py \
--csv_path path/to/my_csv \
--output_dir path/to/my_knowledge_dataset \
```
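As noted above, the input documents are tab-separated CSV files with "title" and "text" columns. A minimal sketch of producing such a file with the standard library (the column layout follows the README; the header row and the sample documents are illustrative assumptions, not taken from `use_own_knowledge_dataset.py`):

```python
# Sketch: serialize documents as a tab-separated CSV with "title" and "text"
# columns, the input format described for use_own_knowledge_dataset.py.
import csv
import io

docs = [
    {"title": "Aaron", "text": "Aaron is a prophet in the Hebrew Bible."},
    {"title": "Paris", "text": "Paris is the capital of France."},
]

buf = io.StringIO()  # write to path/to/my_csv in real use
writer = csv.DictWriter(buf, fieldnames=["title", "text"], delimiter="\t")
writer.writeheader()
writer.writerows(docs)
csv_text = buf.getvalue()
```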
The created outputs in `path/to/my_knowledge_dataset` can then be used to finetune RAG as follows:
```bash
python examples/rag/finetune_rag.py \
python examples/research_projects/rag/finetune_rag.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \
......
......@@ -130,8 +130,6 @@ def evaluate_batch_e2e(args, rag_model, questions):
early_stopping=False,
num_return_sequences=1,
bad_words_ids=[[0, 0]], # BART likes to repeat BOS tokens, don't allow it to generate more than one
clean_up_tokenization=True,
print_docs=args.print_docs,
)
answers = rag_model.retriever.generator_tokenizer.batch_decode(outputs, skip_special_tokens=True)
......
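The two kwargs removed from the `generate` call above (`clean_up_tokenization` and `print_docs`) are forwarded to `self.generator.generate` in `modeling_rag.py`, which no longer accepts them. A minimal sketch of the failure mode, using a stand-in function rather than the real Transformers API:

```python
# Sketch of why passing unsupported kwargs raises TypeError. `generate` here
# is a hypothetical stand-in for a generation function that, like
# self.generator.generate in modeling_rag.py, rejects unknown keyword args.
def generate(input_ids, num_beams=1, num_return_sequences=1):
    return [input_ids] * num_return_sequences

# Supported kwargs only: works.
out = generate([0, 1, 2], num_beams=4, num_return_sequences=1)

# Unsupported kwargs, as in the old call: TypeError.
try:
    generate([0, 1, 2], clean_up_tokenization=True, print_docs=False)
except TypeError as e:
    err = str(e)  # "...got an unexpected keyword argument..."
```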
......@@ -443,7 +443,6 @@ class GenerativeQAModule(BaseTransformer):
type=str,
help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
)
return parser
@staticmethod
......@@ -486,27 +485,10 @@ class GenerativeQAModule(BaseTransformer):
default=False,
help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
)
parser.add_argument(
"--num_retrieval_workers",
type=int,
default=1,
help="The number of retrieval actors to use when Ray is selected"
"for the distributed retriever. Has no effect when "
"distributed_retriever is set to pytorch.",
)
return parser
@staticmethod
def add_ray_specific_args(parser):
parser.add_argument(
"--num_retrieval_workers",
type=int,
default=1,
help="The number of retrieval actors to use when Ray is selected"
"for the distributed retriever. Has no effect when "
"distributed_retriever is set to pytorch.",
)
# Ray cluster address.
parser.add_argument(
"--ray-address",
......@@ -517,12 +499,18 @@ class GenerativeQAModule(BaseTransformer):
"cluster. Has no effect if pytorch is used as the distributed "
"retriever.",
)
parser.add_argument(
"--num_retrieval_workers",
type=int,
default=1,
help="The number of retrieval actors to use when Ray is selected"
"for the distributed retriever. Has no effect when "
"distributed_retriever is set to pytorch.",
)
return parser
def main(args=None, model=None) -> GenerativeQAModule:
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
......
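The argparse changes above remove a duplicated `--num_retrieval_workers` definition: because `main` chains the static `add_*_args` methods on a single parser, registering the same option string twice raises `argparse.ArgumentError`, which is why the argument now lives only in `add_ray_specific_args`. A quick sketch of the conflict:

```python
# Sketch: argparse rejects a second registration of the same option string,
# which is what happened when two static methods both added this argument.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num_retrieval_workers", type=int, default=1)

try:
    # Duplicate registration, as in the old code path.
    parser.add_argument("--num_retrieval_workers", type=int, default=1)
except argparse.ArgumentError as e:
    err = str(e)  # "conflicting option string: --num_retrieval_workers"
```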