"vscode:/vscode.git/clone" did not exist on "3a028101e91069b51629f5e74096ae78e490022b"
Unverified Commit d944966b authored by Yoshitomo Matsubara's avatar Yoshitomo Matsubara Committed by GitHub
Browse files

Fix typos in README and bugs in RAG example code for end-to-end evaluation and finetuning (#9355)

* fix a bug in eval_batch_retrieval

* should return parser as well as other staticmethod

* remove duplicate argument

* these kwargs are no longer accepted (cause TypeError in self.generator.generate of modeling_rag.py)

* fixed file paths in README

* moved an arg to add_ray_specific_args
parent c4fd609a
......@@ -23,10 +23,10 @@ test.source
test.target
```
A sample finetuning command (run ` ./examples/rag/finetune_rag.py --help` to list all available options):
A sample finetuning command (run ` ./examples/research_projects/rag/finetune_rag.py --help` to list all available options):
```bash
python examples/rag/finetune_rag.py \
python examples/research_projects/rag/finetune_rag.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \
......@@ -42,7 +42,7 @@ The `base` models initialize the question encoder with [`facebook/dpr-question_e
If you would like to initialize finetuning with a base model using different question encoder and generator architectures, you can build it with a consolidation script, e.g.:
```
python examples/rag/consolidate_rag_checkpoint.py \
python examples/research_projects/rag/consolidate_rag_checkpoint.py \
--model_type rag_sequence \
--generator_name_or_path facebook/bart-large-cnn \
--question_encoder_name_or_path facebook/dpr-question_encoder-single-nq-base \
......@@ -71,7 +71,7 @@ Also make sure to start the Ray cluster before running fine-tuning.
# Start a single-node Ray cluster.
ray start --head
python examples/rag/finetune_rag.py \
python examples/research_projects/rag/finetune_rag.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \
......@@ -113,14 +113,14 @@ We demonstrate how to evaluate retrieval against DPR evaluation data. You can do
2. Parse the unziped file using the `parse_dpr_relevance_data.py`
```bash
mkdir output # or wherever you want to save this
python examples/rag/parse_dpr_relevance_data.py \
python examples/research_projects/rag/parse_dpr_relevance_data.py \
--src_path biencoder-nq-dev.json \
--evaluation_set output/biencoder-nq-dev.questions \
--gold_data_path output/biencoder-nq-dev.pages
```
3. Run evaluation:
```bash
python examples/rag/eval_rag.py \
python examples/research_projects/rag/eval_rag.py \
--model_name_or_path facebook/rag-sequence-nq \
--model_type rag_sequence \
--evaluation_set output/biencoder-nq-dev.questions \
......@@ -131,7 +131,7 @@ We demonstrate how to evaluate retrieval against DPR evaluation data. You can do
```
```bash
# EXPLANATION
python examples/rag/eval_rag.py \
python examples/research_projects/rag/eval_rag.py \
--model_name_or_path facebook/rag-sequence-nq \ # model name or path of the model we're evaluating
--model_type rag_sequence \ # RAG model type (rag_token or rag_sequence)
--evaluation_set output/biencoder-nq-dev.questions \ # an input dataset for evaluation
......@@ -159,7 +159,7 @@ Add `--recalculate` parameter to force the script to perform inference from scra
An example e2e evaluation run could look as follows:
```bash
python examples/rag/eval_rag.py \
python examples/research_projects/rag/eval_rag.py \
--model_name_or_path facebook/rag-sequence-nq \
--model_type rag_sequence \
--evaluation_set path/to/test.source \
......@@ -179,14 +179,14 @@ With `use_custom_knowledge_dataset.py` you can build your own knowledge source,
For instance, if documents are serialized as tab-separated csv files with the columns "title" and "text", one can use `use_own_knowledge_dataset.py` as follows:
```bash
python examples/rag/use_own_knowledge_dataset.py \
python examples/research_projects/rag/use_own_knowledge_dataset.py \
--csv_path path/to/my_csv \
--output_dir path/to/my_knowledge_dataset \
```
The created outputs in `path/to/my_knowledge_dataset` can then be used to finetune RAG as follows:
```bash
python examples/rag/finetune_rag.py \
python examples/research_projects/rag/finetune_rag.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \
......
......@@ -130,8 +130,6 @@ def evaluate_batch_e2e(args, rag_model, questions):
early_stopping=False,
num_return_sequences=1,
bad_words_ids=[[0, 0]], # BART likes to repeat BOS tokens, dont allow it to generate more than one
clean_up_tokenization=True,
print_docs=args.print_docs,
)
answers = rag_model.retriever.generator_tokenizer.batch_decode(outputs, skip_special_tokens=True)
......
......@@ -443,7 +443,6 @@ class GenerativeQAModule(BaseTransformer):
type=str,
help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
)
return parser
@staticmethod
......@@ -486,27 +485,10 @@ class GenerativeQAModule(BaseTransformer):
default=False,
help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
)
parser.add_argument(
"--num_retrieval_workers",
type=int,
default=1,
help="The number of retrieval actors to use when Ray is selected"
"for the distributed retriever. Has no effect when "
"distributed_retriever is set to pytorch.",
)
return parser
@staticmethod
def add_ray_specific_args(parser):
parser.add_argument(
"--num_retrieval_workers",
type=int,
default=1,
help="The number of retrieval actors to use when Ray is selected"
"for the distributed retriever. Has no effect when "
"distributed_retriever is set to pytorch.",
)
# Ray cluster address.
parser.add_argument(
"--ray-address",
......@@ -517,12 +499,18 @@ class GenerativeQAModule(BaseTransformer):
"cluster. Has no effect if pytorch is used as the distributed "
"retriever.",
)
parser.add_argument(
"--num_retrieval_workers",
type=int,
default=1,
help="The number of retrieval actors to use when Ray is selected"
"for the distributed retriever. Has no effect when "
"distributed_retriever is set to pytorch.",
)
return parser
def main(args=None, model=None) -> GenerativeQAModule:
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment