Unverified Commit 8062fa63 authored by Quentin Lhoest, committed by GitHub

Fix rag finetuning + add finetuning test (#8585)

* replace init_ddp_connection for index init

* style

* add finetune test

* add test data

* move generate tensors to device

* add test on EM metric

* style

* allow multi process test

* keep gloo process group for retrieval

* add multi-gpu test

* use custom accelerator

* clean test finetune

* minor

* style

* style

* typo

* use python call instead of imported main function

* return_dict fix in modeling_rag

* use float32 in retrieval

* store as float32 as well in the custom knowledge dataset example

* style

* rename to finetune_rag

* style

* update readme

* rename utils and callbacks to utils_rag and callbacks_rag

* fix test

* patrick's comments

* generate dummy data in the finetune test script

* remove dummy data files

* style
parent 63e91f5f
@@ -384,6 +384,8 @@ def generic_train(
         train_params["distributed_backend"] = "ddp"
     train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
+    train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
+    train_params["profiler"] = extra_train_kwargs.get("profiler", None)

     trainer = pl.Trainer.from_argparse_args(
         args,
...
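Editor's note: the hunk above lets callers of `generic_train` forward an `accelerator` and `profiler` through `**extra_train_kwargs`, with missing keys falling back to `None` so `pl.Trainer` keeps its defaults. A minimal, self-contained sketch of that passthrough idiom (the function name is illustrative, not from the repo):

```python
from typing import Any, Dict


def build_train_params(**extra_train_kwargs: Any) -> Dict[str, Any]:
    # Same idiom as the diff: optional trainer settings default to None
    # unless the caller explicitly supplies them.
    train_params: Dict[str, Any] = {}
    train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
    train_params["profiler"] = extra_train_kwargs.get("profiler", None)
    return train_params


print(build_train_params(accelerator="custom_ddp"))
# {'accelerator': 'custom_ddp', 'profiler': None}
```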
@@ -7,8 +7,8 @@ to the retriever to extract relevant context documents. The documents are then p
 Such contextualized inputs are passed to the generator.

 Read more about RAG at https://arxiv.org/abs/2005.11401.

 # Finetuning

 Our finetuning logic is based on scripts from [`examples/seq2seq`](https://github.com/huggingface/transformers/tree/master/examples/seq2seq). We accept training data in the same format as specified there - we expect a directory consisting of 6 text files:
 ```bash
@@ -20,10 +20,10 @@ test.source
 test.target
 ```
-A sample finetuning command (run ` ./examples/rag/finetune.py --help` to list all available options):
+A sample finetuning command (run ` ./examples/rag/finetune_rag.py --help` to list all available options):
 ```bash
-python examples/rag/finetune.py \
+python examples/rag/finetune_rag.py \
     --data_dir $DATA_DIR \
     --output_dir $OUTPUT_DIR \
     --model_name_or_path $MODEL_NAME_OR_PATH \
@@ -45,7 +45,7 @@ python examples/rag/consolidate_rag_checkpoint.py \
     --question_encoder_name_or_path facebook/dpr-question_encoder-single-nq-base \
     --dest path/to/checkpoint
 ```
-You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune.py` script.
+You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune_rag.py` script.

 # Evaluation
@@ -130,3 +130,29 @@ python examples/rag/eval_rag.py \
     --print_predictions \
     --recalculate \ # adding this parameter will force recalculating predictions even if predictions_path already exists
 ```
+
+# Use your own knowledge source
+
+By default, RAG uses the English Wikipedia as a knowledge source, known as the 'wiki_dpr' dataset.
+With `use_own_knowledge_dataset.py` you can build your own knowledge source, *e.g.* for RAG.
+
+For instance, if documents are serialized as tab-separated csv files with the columns "title" and "text", one can use `use_own_knowledge_dataset.py` as follows:
+```bash
+python examples/rag/use_own_knowledge_dataset.py \
+    --csv_path path/to/my_csv \
+    --output_dir path/to/my_knowledge_dataset
+```
+
+The created outputs in `path/to/my_knowledge_dataset` can then be used to finetune RAG as follows:
+```bash
+python examples/rag/finetune_rag.py \
+    --data_dir $DATA_DIR \
+    --output_dir $OUTPUT_DIR \
+    --model_name_or_path $MODEL_NAME_OR_PATH \
+    --model_type rag_sequence \
+    --fp16 \
+    --gpus 8 \
+    --index_name custom \
+    --passages_path path/to/data/my_knowledge_dataset \
+    --index_path path/to/my_knowledge_dataset_hnsw_index.faiss
+```
\ No newline at end of file
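Editor's note: as a companion to the data-format note in the README hunk above ("a directory consisting of 6 text files"), here is a minimal sketch that lays out such a directory. It mirrors the `_create_dummy_data` helper in the test added later in this commit; the path and line counts are illustrative.

```python
import os

data_dir = "path/to/data"  # illustrative; pass this as --data_dir to finetune_rag.py
os.makedirs(data_dir, exist_ok=True)

n_lines = {"train": 12, "val": 2, "test": 2}  # tiny, just enough to exercise the pipeline
for split, n in n_lines.items():
    # One question per line in *.source, the matching answer per line in *.target.
    with open(os.path.join(data_dir, f"{split}.source"), "w") as f:
        f.write("\n".join(["What is love ?"] * n))
    with open(os.path.join(data_dir, f"{split}.target"), "w") as f:
        f.write("\n".join(["life"] * n))
```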
@@ -8,7 +8,7 @@ import torch
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.utilities import rank_zero_only

-from utils import save_json
+from utils_rag import save_json


 def count_trainable_parameters(model):
@@ -38,7 +38,7 @@ def get_checkpoint_callback(output_dir, metric):
         monitor=f"val_{metric}",
         mode="max",
         save_top_k=3,
-        period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
+        period=1,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
     return checkpoint_callback
...
@@ -40,7 +40,6 @@ class RagPyTorchDistributedRetriever(RagRetriever):
             generator_tokenizer=generator_tokenizer,
             index=index,
         )
         self.process_group = None

     def init_retrieval(self, distributed_port: int):
...
"""Finetuning script for RAG models. Adapted from examples.seq2seq.finetune.py""" """Finetuning script for RAG models. Adapted from examples.seq2seq.finetune.py"""
import argparse import argparse
import glob
import logging import logging
import os import os
import sys import sys
import time import time
import warnings
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Tuple
@@ -15,29 +13,31 @@ import numpy as np
 import pytorch_lightning as pl
 import torch
 import torch.distributed as dist
+from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator
+from pytorch_lightning.cluster_environments import TorchElasticEnvironment
 from torch.utils.data import DataLoader

 from transformers import (
     AutoConfig,
     AutoTokenizer,
     BartForConditionalGeneration,
+    BatchEncoding,
     RagConfig,
     RagSequenceForGeneration,
     RagTokenForGeneration,
     RagTokenizer,
     T5ForConditionalGeneration,
-    get_linear_schedule_with_warmup,
 )
 from transformers import logging as transformers_logging

-from callbacks import (  # noqa: E402 # isort:skipq
+from callbacks_rag import (  # noqa: E402 # isort:skipq
     get_checkpoint_callback,
     get_early_stopping_callback,
     Seq2SeqLoggingCallback,
 )
 from distributed_retriever import RagPyTorchDistributedRetriever  # noqa: E402 # isort:skip
-from utils import (  # noqa: E402 # isort:skip
+from utils_rag import (  # noqa: E402 # isort:skip
     calculate_exact_match,
     flatten_list,
     get_git_info,
@@ -67,6 +67,30 @@ class AttrDict(dict):
         self.__dict__ = self


+# In PTL >v1.0, `init_ddp_connection` method in the `LightningModule`
+# is no longer used, and is moved into DDPAccelerator instead.
+# We override DDPAccelerator to add our custom logic for initializing the
+# retriever.
+# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/backends/test_accelerator_connector.py
+class CustomAccel(DDPAccelerator):
+    def __init__(self, trainer=None, **kwargs):
+        # Trainer is set later.
+        super().__init__(trainer, **kwargs)
+
+    def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
+        logger.info("Custom init_ddp_connection.")
+        module = self.trainer.model
+        if self.cluster_environment is None:
+            self.cluster_environment = TorchElasticEnvironment()
+        self.distributed_port = module.hparams.distributed_port
+        os.environ["MASTER_PORT"] = str(self.distributed_port)
+        super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
+        if module.is_rag_model:
+            module.model.rag.retriever.init_retrieval(self.distributed_port)
+
+
 class GenerativeQAModule(BaseTransformer):
     mode = "generative_qa"
     loss_names = ["loss"]
@@ -91,23 +115,24 @@ class GenerativeQAModule(BaseTransformer):
             config = config_class.from_pretrained(hparams.model_name_or_path)

         # set retriever parameters
-        config.index_name = args.index_name or config.index_name
-        config.passages_path = args.passages_path or config.passages_path
-        config.index_path = args.index_path or config.index_path
+        config.index_name = hparams.index_name or config.index_name
+        config.passages_path = hparams.passages_path or config.passages_path
+        config.index_path = hparams.index_path or config.index_path
+        config.use_dummy_dataset = hparams.use_dummy_dataset

         # set extra_model_params for generator configs and load_model
         extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
         if self.is_rag_model:
-            if args.prefix is not None:
-                config.generator.prefix = args.prefix
+            if hparams.prefix is not None:
+                config.generator.prefix = hparams.prefix
             config.label_smoothing = hparams.label_smoothing
             hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
             retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
             model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
             prefix = config.question_encoder.prefix
         else:
-            if args.prefix is not None:
-                config.prefix = args.prefix
+            if hparams.prefix is not None:
+                config.prefix = hparams.prefix
             hparams, config = set_extra_model_params(extra_model_params, hparams, config)
             model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config)
             prefix = config.prefix
@@ -152,11 +177,9 @@ class GenerativeQAModule(BaseTransformer):
         self.num_workers = hparams.num_workers
         self.distributed_port = self.hparams.distributed_port

-    def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
-        logger.info("Custom init_ddp_connection.")
-        os.environ["MASTER_PORT"] = str(self.distributed_port)
-        super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
-        if self.is_rag_model:
-            self.model.retriever.init_retrieval(self.distributed_port)
+        # For single GPU training, init_ddp_connection is not called.
+        # So we need to initialize the retrievers here.
+        if hparams.gpus <= 1:
+            self.model.retriever.init_retrieval(self.distributed_port)

     def forward(self, input_ids, **kwargs):
@@ -270,6 +293,7 @@ class GenerativeQAModule(BaseTransformer):
     def _generative_step(self, batch: dict) -> dict:
         start_time = time.time()
+        batch = BatchEncoding(batch).to(device=self.model.device)
         generated_ids = self.model.generate(
             batch["input_ids"],
             attention_mask=batch["attention_mask"],
@@ -322,17 +346,6 @@ class GenerativeQAModule(BaseTransformer):
     def train_dataloader(self) -> DataLoader:
         dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
-        t_total = (
-            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
-            // self.hparams.accumulate_grad_batches
-            * float(self.hparams.max_epochs)
-        )
-        scheduler = get_linear_schedule_with_warmup(
-            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
-        )
-        if max(scheduler.get_last_lr()) > 0:
-            warnings.warn("All learning rates are 0")
-        self.lr_scheduler = scheduler
         return dataloader

     def val_dataloader(self) -> DataLoader:
@@ -429,10 +442,24 @@ class GenerativeQAModule(BaseTransformer):
             default=None,
             help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
         )
+        parser.add_argument(
+            "--use_dummy_dataset",
+            type=bool,
+            default=False,
+            help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
+        )
         return parser


-def main(args, model=None) -> GenerativeQAModule:
+def main(args=None, model=None) -> GenerativeQAModule:
+    parser = argparse.ArgumentParser()
+    parser = pl.Trainer.add_argparse_args(parser)
+    parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
+    parser = GenerativeQAModule.add_retriever_specific_args(parser)
+    args = args or parser.parse_args()
+
     Path(args.output_dir).mkdir(exist_ok=True)
     if model is None:
         model: GenerativeQAModule = GenerativeQAModule(args)
@@ -461,6 +488,7 @@ def main(args, model=None) -> GenerativeQAModule:
         if args.early_stopping_patience >= 0
         else False
     )

     trainer: pl.Trainer = generic_train(
         model,
         args,
@@ -468,31 +496,17 @@ def main(args, model=None) -> GenerativeQAModule:
         checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
         early_stopping_callback=es_callback,
         logger=logger,
+        accelerator=CustomAccel() if args.gpus > 1 else None,
     )
     pickle_save(model.hparams, model.output_dir / "hparams.pkl")
     if not args.do_predict:
         return model

-    model.hparams.test_checkpoint = ""
-    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
-    if checkpoints:
-        model.hparams.test_checkpoint = checkpoints[-1]
-        trainer.resume_from_checkpoint = checkpoints[-1]  # best checkpoint
-    trainer.logger.log_hyperparams(model.hparams)
-
     # test() without a model tests using the best checkpoint automatically
     trainer.test()
     return model


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser = pl.Trainer.add_argparse_args(parser)
-    parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
-    parser = GenerativeQAModule.add_retriever_specific_args(parser)
-    args = parser.parse_args()
-
-    main(args)
+    main()
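Editor's note: the `main(args=None)` refactor above moves parser construction inside `main`, so the `__main__` block shrinks to a bare `main()` call while tests can still hand in a pre-parsed namespace. A self-contained sketch of that entry-point pattern (the argument names here are illustrative, not the full finetuning CLI):

```python
import argparse


def main(args=None):
    # The function owns its parser, so it works both from the CLI and programmatically.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", required=True)
    parser.add_argument("--output_dir", required=True)
    args = args or parser.parse_args()
    print(f"would train with data_dir={args.data_dir}, output_dir={args.output_dir}")


if __name__ == "__main__":
    # CLI path: python this_script.py --data_dir d --output_dir o
    main()
    # Programmatic path (e.g. from a test):
    # main(argparse.Namespace(data_dir="d", output_dir="o"))
```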
@@ -4,7 +4,7 @@ export PYTHONPATH="../":"${PYTHONPATH}"

 # A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
 # run ./examples/rag/finetune.sh --help to see all the possible options
-python examples/rag/finetune.py \
+python examples/rag/finetune_rag.py \
     --data_dir $DATA_DIR \
     --output_dir $OUTPUT_DIR \
     --model_name_or_path $MODEL_NAME_OR_PATH \
...
import json
import logging
import os
import sys
from pathlib import Path

import finetune_rag

from transformers.file_utils import is_apex_available
from transformers.testing_utils import (
    TestCasePlus,
    execute_subprocess_async,
    require_torch_gpu,
    require_torch_multi_gpu,
)


logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()


class RagFinetuneExampleTests(TestCasePlus):
    def _create_dummy_data(self, data_dir):
        os.makedirs(data_dir, exist_ok=True)
        contents = {"source": "What is love ?", "target": "life"}
        n_lines = {"train": 12, "val": 2, "test": 2}
        for split in ["train", "test", "val"]:
            for field in ["source", "target"]:
                content = "\n".join([contents[field]] * n_lines[split])
                with open(os.path.join(data_dir, f"{split}.{field}"), "w") as f:
                    f.write(content)

    def _run_finetune(self, gpus: int):
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        tmp_dir = self.get_auto_remove_tmp_dir()
        output_dir = os.path.join(tmp_dir, "output")
        data_dir = os.path.join(tmp_dir, "data")
        self._create_dummy_data(data_dir=data_dir)

        testargs = f"""
            --data_dir {data_dir} \
            --output_dir {output_dir} \
            --model_name_or_path facebook/rag-sequence-base \
            --model_type rag_sequence \
            --do_train \
            --do_predict \
            --n_val -1 \
            --val_check_interval 1.0 \
            --train_batch_size 2 \
            --eval_batch_size 1 \
            --max_source_length 25 \
            --max_target_length 25 \
            --val_max_target_length 25 \
            --test_max_target_length 25 \
            --label_smoothing 0.1 \
            --dropout 0.1 \
            --attention_dropout 0.1 \
            --weight_decay 0.001 \
            --adam_epsilon 1e-08 \
            --max_grad_norm 0.1 \
            --lr_scheduler polynomial \
            --learning_rate 3e-04 \
            --num_train_epochs 1 \
            --warmup_steps 4 \
            --gradient_accumulation_steps 1 \
            --distributed-port 8787 \
            --use_dummy_dataset 1 \
            """.split()

        if gpus > 0:
            testargs.append(f"--gpus={gpus}")
            if is_apex_available():
                testargs.append("--fp16")
        else:
            testargs.append("--gpus=0")
            testargs.append("--distributed_backend=ddp_cpu")
            testargs.append("--num_processes=2")

        cmd = [sys.executable, str(Path(finetune_rag.__file__).resolve())] + testargs
        execute_subprocess_async(cmd, env=self.get_env())

        metrics_save_path = os.path.join(output_dir, "metrics.json")
        with open(metrics_save_path) as f:
            result = json.load(f)
        return result

    @require_torch_gpu
    def test_finetune_gpu(self):
        result = self._run_finetune(gpus=1)
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)

    @require_torch_multi_gpu
    def test_finetune_multigpu(self):
        result = self._run_finetune(gpus=2)
        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
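Editor's note: both tests assert on `test_avg_em`, an exact-match score over the generated answers; the repo computes it with `calculate_exact_match` from `utils_rag`. The sketch below only illustrates that kind of metric (normalize, then compare), and is not the repository's implementation:

```python
import re
import string


def normalize(text: str) -> str:
    # Lowercase, strip punctuation, collapse whitespace.
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    return re.sub(r"\s+", " ", text).strip()


def exact_match(predictions, references) -> float:
    # Fraction of predictions that match their reference after normalization.
    matches = [int(normalize(p) == normalize(r)) for p, r in zip(predictions, references)]
    return sum(matches) / max(len(matches), 1)


print(exact_match(["Life!"], ["life"]))  # 1.0
```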
@@ -7,7 +7,7 @@ from tempfile import TemporaryDirectory
 from typing import List, Optional

 import torch
-from datasets import load_dataset
+from datasets import Features, Sequence, Value, load_dataset

 import faiss
 from transformers import (
@@ -82,10 +82,14 @@ def main(
     # And compute the embeddings
     ctx_encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
     ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
+    new_features = Features(
+        {"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
+    )  # optional, save as float32 instead of float64 to save space
     dataset = dataset.map(
         partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
         batched=True,
         batch_size=processing_args.batch_size,
+        features=new_features,
     )

     # And finally save your dataset
...
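Editor's note: with the knowledge dataset now saved using float32 embeddings (matching the retriever-side `dtype="float32"` change further below), a quick way to confirm the stored schema is to reload it. This assumes the dataset was written with `save_to_disk`, and the path is illustrative:

```python
import numpy as np
from datasets import load_from_disk

ds = load_from_disk("path/to/my_knowledge_dataset")  # illustrative path
print(ds.features["embeddings"])  # expected: Sequence(feature=Value(dtype='float32'), ...)

# Same call the retriever now makes; the vectors come back as float32 numpy arrays.
ds.set_format("numpy", columns=["embeddings"], output_all_columns=True, dtype="float32")
print(np.asarray(ds[0]["embeddings"]).dtype)  # float32
```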
@@ -556,7 +556,9 @@ class RagModel(RagPreTrainedModel):
         if encoder_outputs is None:

             if has_to_retrieve:
-                question_enc_outputs = self.question_encoder(input_ids, attention_mask=attention_mask)
+                question_enc_outputs = self.question_encoder(
+                    input_ids, attention_mask=attention_mask, return_dict=True
+                )
                 question_encoder_last_hidden_state = question_enc_outputs[0]  # hidden states of question encoder

                 retriever_outputs = self.retriever(
@@ -616,6 +618,7 @@ class RagModel(RagPreTrainedModel):
             decoder_attention_mask=decoder_attention_mask,
             past_key_values=past_key_values,
             use_cache=use_cache,
+            return_dict=True,
         )

         if not has_to_retrieve:
...
@@ -196,7 +196,7 @@ class HFIndexBase(Index):
         self.dataset = dataset
         self._index_initialized = index_initialized
         self._check_dataset_format(with_index=index_initialized)
-        dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
+        dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True, dtype="float32")

     def _check_dataset_format(self, with_index: bool):
         if not isinstance(self.dataset, Dataset):
...