Commit f75058c7 authored by Rayyyyy

First add.

import logging
import torch
import datasets
from dataclasses import asdict
from transformers import (
HfArgumentParser,
)
from src.retrieval import DenseRetriever
from src.retrieval.metrics import RetrievalMetric
from src.retrieval.trainer import RetrievalTrainer, EarlyExitCallBack
from src.retrieval.args import RetrievalArgs, RetrievalTrainingArgs
from src.retrieval.data import RetrievalDataset, RetrievalDataCollator, SameDatasetTrainDataset, TASK_CONFIG
from src.utils.util import FileLogger, makedirs
logger = logging.getLogger(__name__)
def main():
parser = HfArgumentParser((RetrievalArgs, RetrievalTrainingArgs))
model_args, training_args = parser.parse_args_into_dataclasses()
model_args: RetrievalArgs
training_args: RetrievalTrainingArgs
config = TASK_CONFIG[model_args.version]
instruction = config["instruction"]
model = DenseRetriever(
**asdict(model_args),
cache_dir=model_args.model_cache_dir,
cos_temperature=training_args.cos_temperature,
contrastive_weight=training_args.contrastive_weight,
distill_weight=training_args.distill_weight,
teacher_temperature=training_args.teacher_temperature,
student_temperature=training_args.student_temperature,
negative_cross_device=training_args.negative_cross_device,
stable_distill=training_args.stable_distill,
)
# if model_args.train_data is not None:
# model.to(torch.float32)
if training_args.use_train_config:
model.train_config = config["training"]
tokenizer = model.tokenizer
with training_args.main_process_first():
train_dataset, task_indices_range = RetrievalDataset.prepare_train_dataset(
data_file=model_args.train_data,
cache_dir=model_args.dataset_cache_dir,
add_instruction=model_args.add_instruction,
train_group_size=training_args.train_group_size,
config=config,
use_train_config=training_args.use_train_config,
select_positive=training_args.select_positive,
select_negative=training_args.select_negative,
max_sample_num=training_args.max_sample_num,
teacher_scores_margin=training_args.teacher_scores_margin,
teacher_scores_min=training_args.teacher_scores_min,
stable_distill=training_args.stable_distill,
)
# we should get the evaluation task before specifying instruction
if model_args.eval_data is not None and model_args.add_instruction:
raw_eval_dataset = datasets.load_dataset('json', data_files=model_args.eval_data, split='train', cache_dir=model_args.dataset_cache_dir)
eval_task = raw_eval_dataset[0]["task"]
else:
eval_task = None
eval_dataset = RetrievalDataset.prepare_eval_dataset(
data_file=model_args.eval_data,
cache_dir=model_args.dataset_cache_dir,
instruction=instruction[eval_task] if eval_task is not None else None,
eval_method=training_args.eval_method,
)
corpus = RetrievalDataset.prepare_corpus(
data_file=model_args.corpus,
key_template=model_args.key_template,
cache_dir=model_args.dataset_cache_dir,
instruction=instruction[eval_task] if eval_task is not None else None
)
if training_args.process_index == 0:
# NOTE: this corpus is for computing metrics, where no instruction is given
no_instruction_corpus = RetrievalDataset.prepare_corpus(
data_file=model_args.corpus,
key_template=model_args.key_template,
cache_dir=model_args.dataset_cache_dir,
)
else:
no_instruction_corpus = None
if training_args.inbatch_same_dataset is not None:
assert training_args.dataloader_num_workers == 0, f"Make sure dataloader num_workers is 0 when using inbatch_same_dataset!"
train_dataset = SameDatasetTrainDataset(
train_dataset,
task_indices_range,
batch_size=training_args.per_device_train_batch_size,
seed=training_args.seed,
organize_method=training_args.inbatch_same_dataset,
num_processes=training_args.world_size,
process_index=training_args.process_index,
)
training_args.per_device_train_batch_size = 1
if training_args.early_exit_steps is not None:
callbacks = [EarlyExitCallBack(training_args.early_exit_steps)]
else:
callbacks = []
trainer = RetrievalTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
callbacks=callbacks,
corpus=corpus,
model_args=model_args,
data_collator=RetrievalDataCollator(
tokenizer=tokenizer,
query_max_length=model_args.query_max_length,
key_max_length=model_args.key_max_length,
inbatch_same_dataset=training_args.inbatch_same_dataset
),
compute_metrics=RetrievalMetric.get_metric_fn(
model_args.metrics,
# for collecting labels
eval_data=model_args.eval_data,
cutoffs=model_args.cutoffs,
# for collecting positives and collating retrieval results
save_name=model_args.save_name,
output_dir=training_args.output_dir,
save_to_output=model_args.save_to_output,
# for restoring text from indices when collating results
corpus=no_instruction_corpus,
max_neg_num=model_args.max_neg_num,
# for nq metrics
cache_dir=model_args.dataset_cache_dir,
# for collate_neg
filter_answers=model_args.filter_answers
),
file_logger=FileLogger(makedirs(training_args.log_path))
)
# tie accelerators
model.accelerator = trainer.accelerator
# Training
if train_dataset is not None:
trainer.train()
return
if eval_dataset is not None:
trainer.evaluate()
if __name__ == "__main__":
main()
import os
import json
import logging
import random
import datasets
from tqdm import tqdm
from datetime import timedelta
from accelerate import Accelerator, InitProcessGroupKwargs
from torch.utils.data import DataLoader
from dataclasses import dataclass, field
from collections import defaultdict
from transformers import HfArgumentParser
from src.lm import LM, LMArgs
from src.utils.util import split_file_dir_name_ext, makedirs, save_pickle, load_pickle, remove_eos, DefaultDataCollator, DatasetProcessFn
logger = logging.getLogger(__name__)
@dataclass
class ScoreArgs(LMArgs):
eval_data: str = field(
default=None,
metadata={'help': 'Query jsonl.'}
)
context_max_length: int = field(
default=1024,
metadata={'help': 'Max length for lm.'}
)
key_max_length: int = field(
default=512,
metadata={'help': 'Max length for key.'}
)
lm_batch_size: int = field(
default=4,
metadata={'help': 'Batch size for computing LM scores.'},
)
save_name: str = field(
default="llama2-7b-chat",
metadata={'help': 'Name of the scored file.'}
)
load_score: bool = field(
default=False,
metadata={'help': 'Load scores from the temporary pickle file?'}
)
def process_lm_scoring(tokenizer, key_max_length=512):
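# probe the tokenizer with a dummy input to check whether it automatically adds BOS/EOS special tokens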
test = tokenizer("test", return_special_tokens_mask=True)["special_tokens_mask"]
has_bos = has_eos = False
if test[0] == 1:
has_bos = True
if test[-1] == 1:
has_eos = True
@DatasetProcessFn(augment=True)
def _process(query, answers, query_id, task, pos=None, neg=None, history=None, context_inputs=None, query_inputs=None, answer_inputs=None, score_inputs=None, _index=None, **kwds):
"""Yield each key (pos&neg)"""
if task in ["qa", "convsearch"]:
template = "Knowledge: {key.strip()}\n\nQuestion: {query.strip()}\n\nAnswer: {answer.strip()}"
elif task == "icl":
template = "{key}\n{query}\n{answer}"
elif task == "lrlm":
# template = "{key}{continuation[i]}{context}{query}{answer}"
pass
elif task == "chat":
template = "{key}\nSpeaker 1: {query}\nSpeaker 2: {answer}"
else:
raise NotImplementedError(f"Task type {task} not implemented!")
output = defaultdict(list)
# NOTE: sample 1 answer for scoring if there are multiple
if len(answers) > 1:
answer = random.choice(answers)
else:
answer = answers[0]
if history is not None:
assert task == "chat", f"Found history={history} is not None but task={task} is not 'chat'!"
keys = history
else:
keys = pos + neg
for i, key in enumerate(keys):
# NOTE: do not add special tokens!
if task == "lrlm":
score_input = score_inputs[i]
input_ids = score_input + context_inputs + query_inputs + answer_inputs
attention_mask = [1 for _ in input_ids]
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask
}
labels = input_ids.copy()
answer_length = len(answer_inputs)
labels[:-answer_length] = [-100] * (len(labels) - answer_length)
inputs["labels"] = labels
else:
# truncate key
key = tokenizer.decode(tokenizer.encode(key, add_special_tokens=False, max_length=key_max_length, truncation=True))
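# the template is a deferred f-string: evaluate it here so placeholders like {key}, {query} and {answer} are filled from the local variables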
seq = eval(f"f{repr(template)}")
inputs = tokenizer(seq, return_token_type_ids=False)
if has_eos:
inputs = remove_eos(inputs, tokenizer.eos_token_id)
# find answer length
answer_seq = tokenizer.encode("Answer: " + answer.lstrip(" "), add_special_tokens=False)
answer_length = len(answer_seq) - len(tokenizer.encode("Answer:", add_special_tokens=False))
assert answer_length > 0, f"No answer found in inputs {_index}!"
# take care of padded tokens
labels = inputs["input_ids"].copy()
labels = [x if inputs["attention_mask"][i] == 1 else -100 for i, x in enumerate(labels)]
labels[:-answer_length] = [-100] * (len(labels) - answer_length)
inputs["labels"] = labels
for k, v in inputs.items():
output[k].append(v)
output["query_id"].append(query_id)
return output
return _process
def collate_scores(eval_data, save_name):
"""
Collate the lm scorings based on query_ids.
Append a 'teacher_score' column in the eval_data and save at eval_data.save_name.json.
"""
def collate(query_ids, scores):
# only on main process
eval_data_folder, eval_data_name, eval_data_ext = split_file_dir_name_ext(eval_data)
data_save_path = os.path.join(eval_data_folder, f"{eval_data_name}.scored.{save_name}" + eval_data_ext)
makedirs(data_save_path)
prev_query_id = None
teacher_scores = []
try:
logger.info(f"saving data to {data_save_path}...")
with open(eval_data) as f, open(data_save_path, "w") as g:
for query_id, score in tqdm(zip(query_ids, scores)):
if (query_id != prev_query_id) and (prev_query_id is not None):
sample = json.loads(f.readline().strip())
assert prev_query_id == sample["query_id"], f"Found incompatible query_id from data ({sample['query_id']}) and from eval_preds ({prev_query_id})"
if "history" in sample:
assert len(sample["history"]) == len(teacher_scores), f"Found incompatible key number from data ({len(sample['history'])}) and from eval_preds ({len(teacher_scores)})"
else:
assert len(sample["pos"] + sample["neg"]) == len(teacher_scores), f"Found incompatible key number from data ({len(sample['pos'] + sample['neg'])}) and from eval_preds ({len(teacher_scores)})"
sample["teacher_scores"] = teacher_scores.copy()
if sample["task"] == "lrlm" and "query_inputs" in sample:
del sample["query_inputs"]
del sample["answer_inputs"]
del sample["context_inputs"]
del sample["score_inputs"]
g.write(json.dumps(sample, ensure_ascii=False) + "\n")
teacher_scores.clear()
# accumulate scores of different keys for the same query
# log likelihood
teacher_scores.append(-score)
prev_query_id = query_id
# NOTE: write out the last sample
sample = json.loads(f.readline().strip())
assert prev_query_id == sample["query_id"], f"Found incompatible query_id from data ({sample['query_id']}) and from eval_preds ({prev_query_id})"
if "history" in sample:
assert len(sample["history"]) == len(teacher_scores), f"Found incompatible key number from data ({len(sample['history'])}) and from eval_preds ({len(teacher_scores)})"
else:
assert len(sample["pos"] + sample["neg"]) == len(teacher_scores), f"Found incompatible key number from data ({len(sample['pos'] + sample['neg'])}) and from eval_preds ({len(teacher_scores)})"
sample["teacher_scores"] = teacher_scores.copy()
if sample["task"] == "lrlm" and "query_inputs" in sample:
del sample["query_inputs"]
del sample["answer_inputs"]
del sample["context_inputs"]
del sample["score_inputs"]
g.write(json.dumps(sample, ensure_ascii=False) + "\n")
teacher_scores.clear()
except:
save_path = os.path.join(eval_data_folder, f"{eval_data_name}.{save_name}.pkl")
logger.error(f"Error when trying to save to json file. Save scores to {save_path} instead!")
save_pickle((query_ids, scores), save_path)
raise
return collate
def main():
parser = HfArgumentParser([ScoreArgs])
args, = parser.parse_args_into_dataclasses()
args: ScoreArgs
accelerator = Accelerator(cpu=args.cpu, kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=100000))])
logger.info(f"Loading data from {args.eval_data}...")
llm = LM(
model_name_or_path=args.model_name_or_path,
dtype=args.lm_dtype,
padding_side=args.padding_side,
cache_dir=args.model_cache_dir,
accelerator=accelerator
)
llm.to(accelerator.device)
tokenizer = llm.tokenizer
logging.info(f"Loading data from {args.eval_data}...")
if args.load_score:
eval_data_folder, eval_data_name, eval_data_ext = split_file_dir_name_ext(args.eval_data)
save_path = os.path.join(eval_data_folder, f"{eval_data_name}.{args.save_name}.pkl")
results = load_pickle(save_path)
else:
with accelerator.main_process_first():
# dataset = datasets.load_dataset("json", data_files=args.eval_data, split="train[:100]", cache_dir=args.dataset_cache_dir)
dataset = datasets.load_dataset("json", data_files=args.eval_data, split="train", cache_dir=args.dataset_cache_dir)
dataset = dataset.map(
process_lm_scoring(tokenizer=tokenizer, key_max_length=args.key_max_length),
remove_columns=dataset.column_names,
batched=True,
num_proc=32,
with_indices=True
)
data_collator = DefaultDataCollator(tokenizer=tokenizer, add_position_ids=args.add_position_ids)
dataloader = DataLoader(
dataset,
batch_size=args.lm_batch_size,
collate_fn=data_collator,
pin_memory=True,
)
dataloader = accelerator.prepare(dataloader)
query_ids, scores = llm.compute_nlls(dataloader)
if accelerator.process_index == 0:
collate_scores(args.eval_data, args.save_name)(query_ids, scores)
if __name__ == "__main__":
main()
import logging
import datasets
from dataclasses import asdict
from transformers import (
HfArgumentParser,
)
from src.retrieval import CrossEncoder
from src.retrieval.metrics import RetrievalMetric
from src.retrieval.trainer import RetrievalTrainer, EarlyExitCallBack
from src.retrieval.args import RankerArgs, RetrievalTrainingArgs
from src.retrieval.data import RetrievalDataset, RetrievalDataCollator, SameDatasetTrainDataset, TASK_CONFIG
from src.utils.util import FileLogger, makedirs
logger = logging.getLogger(__name__)
def main():
parser = HfArgumentParser((RankerArgs, RetrievalTrainingArgs))
model_args, training_args = parser.parse_args_into_dataclasses()
model_args: RankerArgs
training_args: RetrievalTrainingArgs
# set to rerank
training_args.eval_method = "rerank"
config = TASK_CONFIG[model_args.version]
instruction = config["instruction"]
if model_args.ranker_method == "cross-encoder":
model = CrossEncoder(
ranker=model_args.ranker,
# NOTE: the fp16 model cannot be trained
# dtype="fp32" if model_args.train_data is not None else model_args.dtype,
dtype=model_args.dtype,
cache_dir=model_args.model_cache_dir,
)
cross = True
else:
raise NotImplementedError(f"Ranker method {model_args.ranker_method} not implemented!")
if training_args.use_train_config:
model.train_config = config["training"]
tokenizer = model.tokenizer
with training_args.main_process_first():
train_dataset, task_indices_range = RetrievalDataset.prepare_train_dataset(
data_file=model_args.train_data,
cache_dir=model_args.dataset_cache_dir,
add_instruction=model_args.add_instruction,
train_group_size=training_args.train_group_size,
config=config,
use_train_config=training_args.use_train_config,
select_positive=training_args.select_positive,
select_negative=training_args.select_negative,
max_sample_num=training_args.max_sample_num,
teacher_scores_margin=training_args.teacher_scores_margin,
teacher_scores_min=training_args.teacher_scores_min,
)
# we should get the evaluation task before specifying instruction
if model_args.eval_data is not None and model_args.add_instruction:
raw_eval_dataset = datasets.load_dataset('json', data_files=model_args.eval_data, split='train', cache_dir=model_args.dataset_cache_dir)
eval_task = raw_eval_dataset[0]["task"]
else:
eval_task = None
eval_dataset = RetrievalDataset.prepare_eval_dataset(
data_file=model_args.eval_data,
cache_dir=model_args.dataset_cache_dir,
instruction=instruction[eval_task] if eval_task is not None else None,
eval_method=training_args.eval_method,
)
corpus = RetrievalDataset.prepare_corpus(
data_file=model_args.corpus,
key_template=model_args.key_template,
cache_dir=model_args.dataset_cache_dir,
instruction=instruction[eval_task] if eval_task is not None else None
)
if training_args.process_index == 0:
# NOTE: this corpus is for computing metrics, where no instruction is given
no_instruction_corpus = RetrievalDataset.prepare_corpus(
data_file=model_args.corpus,
key_template=model_args.key_template,
cache_dir=model_args.dataset_cache_dir,
)
else:
no_instruction_corpus = None
if training_args.inbatch_same_dataset is not None:
assert training_args.dataloader_num_workers == 0, f"Make sure dataloader num_workers is 0 when using inbatch_same_dataset!"
train_dataset = SameDatasetTrainDataset(
train_dataset,
task_indices_range,
batch_size=training_args.per_device_train_batch_size,
seed=training_args.seed,
organize_method=training_args.inbatch_same_dataset,
num_processes=training_args.world_size,
process_index=training_args.process_index,
)
training_args.per_device_train_batch_size = 1
if training_args.early_exit_steps is not None:
callbacks = [EarlyExitCallBack(training_args.early_exit_steps)]
else:
callbacks = []
trainer = RetrievalTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
callbacks=callbacks,
corpus=corpus,
model_args=model_args,
data_collator=RetrievalDataCollator(
tokenizer=tokenizer,
query_max_length=model_args.query_max_length,
key_max_length=model_args.key_max_length,
inbatch_same_dataset=training_args.inbatch_same_dataset,
cross=cross
),
compute_metrics=RetrievalMetric.get_metric_fn(
model_args.metrics,
# for collecting labels
eval_data=model_args.eval_data,
cutoffs=model_args.cutoffs,
# for collecting positives and collating retrieval results
save_name=model_args.save_name,
output_dir=training_args.output_dir,
save_to_output=model_args.save_to_output,
# for restoring text from indices when collating results
corpus=no_instruction_corpus,
max_neg_num=model_args.max_neg_num,
# for nq metrics
cache_dir=model_args.dataset_cache_dir,
# for collate_neg
filter_answers=model_args.filter_answers
),
file_logger=FileLogger(makedirs(training_args.log_path)),
)
# tie accelerators
model.accelerator = trainer.accelerator
# Training
if train_dataset is not None:
trainer.train()
return
if eval_dataset is not None:
trainer.evaluate()
if __name__ == "__main__":
main()
# the instruction and training config version
version="llm-embedder"
# the output folder
output="llm-embedder"
# the data root where you untar the data
data_root="/data/llm-embedder"
torchrun --nproc_per_node=8 run_dense.py --train_data \
llm-embedder:chat/msc/train.json \
llm-embedder:convsearch/qrecc/train.concat.json \
llm-embedder:lrlm/arxiv/train.json \
llm-embedder:lrlm/books3/train.json \
llm-embedder:lrlm/codeparrot/train.json \
llm-embedder:qa/msmarco/train.json \
llm-embedder:qa/nq/train.json \
llm-embedder:tool/toolbench/train.json \
llm-embedder:tool/toolbench/train.json \
llm-embedder:icl/icl/train.json \
--output_dir data/outputs/$output \
--save_steps 10000 \
--max_steps 10000 \
--logging_steps 100 \
--inbatch_same_dataset epoch \
--use_train_config \
--gradient_checkpointing \
--per_device_train_batch_size 100 \
--deepspeed data/deepspeed/stage0.json \
--version $version \
--learning_rate 5e-6 \
--data_root $data_root
for model in "checkpoint-10000"
do
torchrun --nproc_per_node 8 -m evaluation.eval_mmlu --query_encoder data/outputs/$output/$model/encoder --version $version --data_root $data_root
torchrun --nproc_per_node 8 -m evaluation.eval_popqa --query_encoder data/outputs/$output/$model/encoder --version $version --data_root $data_root
torchrun --nproc_per_node 8 -m evaluation.eval_msc --query_encoder data/outputs/$output/$model/encoder --version $version --data_root $data_root
torchrun --nproc_per_node 8 -m evaluation.eval_tool --query_encoder data/outputs/$output/$model/encoder --version $version --data_root $data_root
torchrun --nproc_per_node 8 -m evaluation.eval_lrlm --query_encoder data/outputs/$output/$model/encoder --eval_data llm-embedder:lrlm/books3/test.json --version $version --data_root $data_root
torchrun --nproc_per_node 8 -m evaluation.eval_lrlm --query_encoder data/outputs/$output/$model/encoder --eval_data llm-embedder:lrlm/arxiv/test.json --version $version --data_root $data_root
torchrun --nproc_per_node 8 -m evaluation.eval_lrlm --query_encoder data/outputs/$output/$model/encoder --eval_data llm-embedder:lrlm/codeparrot/test.json --version $version --data_root $data_root
torchrun --nproc_per_node 8 -m evaluation.eval_lrlm --query_encoder data/outputs/$output/$model/encoder --eval_data llm-embedder:lrlm/pg19/test.json --version $version --data_root $data_root
torchrun --nproc_per_node 8 -m evaluation.eval_icl --query_encoder data/outputs/$output/$model/encoder --version $version --data_root $data_root
torchrun --nproc_per_node 8 -m evaluation.eval_qrecc --query_encoder data/outputs/$output/$model/encoder --version $version --data_root $data_root
done
import os
from typing import Optional, List
from dataclasses import dataclass, field
from sentence_transformers import models, SentenceTransformer
from transformers import HfArgumentParser
def convert_ours_ckpt_to_sentence_transformer(src_dir, dest_dir, pooling_method: List[str] = ['cls'], dense_metric: str="cos"):
assert os.path.exists(src_dir), f"Make sure the encoder path {src_dir} is valid on disk!"
assert "decoder" not in pooling_method, f"Pooling method 'decode' cannot be saved as sentence_transformers because it uses the decoder stack to produce sentence embedding."
if dest_dir is None:
dest_dir = src_dir
print(f"loading model from {src_dir} and saving the sentence_transformer model at {dest_dir}...")
word_embedding_model = models.Transformer(src_dir)
modules = [word_embedding_model]
ndim = word_embedding_model.get_word_embedding_dimension()
if "cls" in pooling_method:
pooling_model = models.Pooling(ndim, pooling_mode="cls")
pooling_method.remove("cls")
elif "mean" in pooling_method:
pooling_model = models.Pooling(ndim, pooling_mode="mean")
pooling_method.remove("mean")
else:
raise NotImplementedError(f"Fail to find cls or mean in pooling_method {pooling_method}!")
modules.append(pooling_model)
if "dense" in pooling_method:
modules.append(models.Dense(ndim, ndim, bias=False))
pooling_method.remove("dense")
assert len(pooling_method) == 0, f"Found unused pooling_method {pooling_method}!"
if dense_metric == "cos":
normalize_layer = models.Normalize()
modules.append(normalize_layer)
model = SentenceTransformer(modules=modules, device='cpu')
model.save(dest_dir)
@dataclass
class Args:
encoder: Optional[str] = field(
default=None,
metadata={'help': 'Path to the encoder model.'}
)
output_dir: Optional[str] = field(
default=None,
metadata={'help': 'Path to the output sentence_transformer model.'}
)
pooling_method: List[str] = field(
default_factory=lambda: ["cls"],
metadata={'help': 'Pooling methods to aggregate token embeddings for a sequence embedding. {cls, mean, dense, decoder}'}
)
dense_metric: str = field(
default="cos",
metadata={'help': 'What type of metric for dense retrieval? ip, l2, or cos.'}
)
model_cache_dir: Optional[str] = field(
default=None,
metadata={'help': 'Cache folder for huggingface transformers.'}
)
def __post_init__(self):
convert_ours_ckpt_to_sentence_transformer(self.encoder, self.output_dir, self.pooling_method, self.dense_metric)
if __name__ == "__main__":
parser = HfArgumentParser([Args])
args, = parser.parse_args_into_dataclasses()
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
)
# import transformers
# transformers.logging.set_verbosity_error()
from .args import LMArgs, SRLMArgs, GenerationArgs
from .modeling_lm import LM
from .modeling_srlm import SelfRetrievalLM
from dataclasses import dataclass, field
from typing import Optional, List
from ..retrieval.args import BaseArgs
@dataclass
class LMArgs(BaseArgs):
model_name_or_path: str = field(
default='meta-llama/Llama-2-7b-chat-hf',
metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
)
padding_side: str = field(
default="left",
metadata={'help': 'Tokenizer padding side.'}
)
truncation_side: str = field(
default="right",
metadata={'help': 'Tokenizer truncation side.'}
)
context_max_length: int = field(
default=2048,
metadata={'help': 'Maximum context length for the language model.'},
)
add_position_ids: bool = field(
default=False,
metadata={'help': 'Create position ids based on attention masks? Useful when training left-padded models with absolute position embeddings.'}
)
lm_dtype: str = field(
default="bf16",
metadata={'help': 'Data type for the language model.'}
)
lm_device_map: Optional[str] = field(
default=None,
metadata={'help': 'Device map for loading the model. Set to auto to load across devices.'}
)
lm_batch_size: int = field(
default=2,
metadata={'help': 'Evaluation batch size.'},
)
cpu: bool = field(
default=False,
metadata={'help': 'Use cpu?'}
)
add_llama_inst: bool = field(
default=False,
metadata={'help': 'Add llama2-chat instructions? ([INST] and [/INST])'}
)
@dataclass
class SRLMArgs(LMArgs):
context_max_length: int = field(
default=4096,
metadata={'help': 'How many tokens in total as inputs?'}
)
context_window_size: int = field(
default=2048,
metadata={'help': 'How many tokens the model can process at the same time?'}
)
target_length: int = field(
default=1024,
metadata={'help': 'How many tokens to compute perplexity?'}
)
chunk_size: int = field(
default=128,
metadata={'help': 'How many tokens in a chunk?'}
)
key_num: int = field(
default=1,
metadata={'help': 'How many chunks to retrieve at a time?'}
)
chunk_batch_size: int = field(
default=2,
metadata={'help': 'How many retrieval & generation to execute in parallel?'}
)
add_key_continuation: bool = field(
default=False,
metadata={'help': 'Add continuation as keys?'}
)
retrieval_method: str = field(
default='dense',
metadata={'help': 'How to retrieve?'}
)
order_method: str = field(
default='sequential',
metadata={'help': 'How to order the retrieved chunks? sequential or relevance.'}
)
integrate_method: str = field(
default="concat",
metadata={'help': 'How to integrate retrieved chunks. Replace: replace the most distant chunks. Concat: concatenate at the beginning.'}
)
add_sep: Optional[List[int]] = field(
default=None,
metadata={'help': 'The tokens to add after retrieved chunks. "none" means no sep.'}
)
@dataclass
class GenerationArgs:
do_sample: bool = field(
default=False,
metadata={'help': 'Sample when decoding?'}
)
num_return_sequences: int = field(
default=1,
metadata={'help': 'How many sequences to generate?'}
)
temperature: float = field(
default=1.0,
metadata={'help': 'Temperature for sampling'}
)
top_p: Optional[float] = field(
default=1.0,
metadata={'help': 'Top-p sampling value'}
)
max_new_tokens: Optional[int] = field(
default=32,
metadata={'help': 'Maximum new token number.'}
)
eos_token_id: Optional[int] = field(
default=None,
metadata={'help': 'End of sequence token id.'}
)
_from_model_config: bool = field(
default=False,
metadata={'help': 'Load generation config from model config?'}
)
def __post_init__(self):
if self.temperature == 0:
self.temperature = 1e-8
import torch
import logging
from tqdm import tqdm
from accelerate import Accelerator
from typing import Dict
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig
logger = logging.getLogger(__name__)
class LM(torch.nn.Module):
def __init__(self, model_name_or_path=None, padding_side="left", dtype="bf16", cache_dir="/share/LMs", device_map=None, accelerator: Accelerator=None, generation_args: Dict=None) -> None:
super().__init__()
logger.info(f"loading tokenizer and model from {model_name_or_path}...")
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir, padding_side=padding_side, trust_remote_code=True)
if tokenizer.pad_token is None:
# NOTE: for models like Qwen, there is no pre-defined eos tokens
if tokenizer.eos_token is None:
pad_token = "<|endoftext|>"
else:
pad_token = tokenizer.eos_token
tokenizer.pad_token = pad_token
self.tokenizer = tokenizer
if dtype == "bf16":
dtype = torch.bfloat16
elif dtype == "fp16":
dtype = torch.float16
else:
dtype = torch.float32
self.accelerator = accelerator
try:
self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path, cache_dir=cache_dir, torch_dtype=dtype, trust_remote_code=True, device_map=device_map)
except ValueError:
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, cache_dir=cache_dir, torch_dtype=dtype, trust_remote_code=True, device_map=device_map)
# if device_map is specified, we don't need to move the model to any specific gpu
if device_map is None:
if accelerator is not None:
device = accelerator.device
else:
device = torch.device("cpu")
self.model.to(device)
# update the model's default generation config
if generation_args is not None:
generation_config = self.model.generation_config.to_dict()
generation_config.update(generation_args)
generation_config.update({
"pad_token_id": self.tokenizer.pad_token_id
})
self.model.generation_config = GenerationConfig(**generation_config)
@property
def device(self):
if self.accelerator is not None:
return self.accelerator.device
else:
return torch.device("cpu")
def _move_to_device(self, inputs):
for k, v in inputs.items():
if isinstance(v, torch.Tensor):
inputs[k] = v.to(self.device)
return inputs
@torch.no_grad()
def compute_nlls(self, dataloader):
self.model.eval()
all_query_ids = []
all_nlls = []
for step, inputs in enumerate(tqdm(dataloader, desc='Computing NLLs')):
# move to gpu
inputs = self._move_to_device(inputs)
return_query_id = False
if 'query_id' in inputs:
query_id = inputs.pop("query_id") # batch_size
return_query_id = True
outputs = self.model(**inputs)
if self.model.config.is_encoder_decoder:
shifted_logits = outputs.logits
shifted_labels = inputs["labels"]
else:
shifted_logits = outputs.logits[:, :-1].contiguous() # batch_size, seq_len - 1, vocab_size
shifted_labels = inputs["labels"][:, 1:].contiguous() # batch_size, seq_len - 1, vocab_size
batch_size = shifted_logits.shape[0]
token_loss = torch.nn.functional.cross_entropy(shifted_logits.flatten(0, 1), shifted_labels.view(-1), reduction="none").reshape(batch_size, -1) # batch_size, seq_len - 1
batch_loss = token_loss.sum(-1) # batch_size
valid_token_num = (inputs["labels"] != -100).sum(-1) # batch_size
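# length-normalize: average the summed NLL over the number of valid (non -100) label tokens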
nll = batch_loss / valid_token_num # batch_size
if self.accelerator is not None:
if return_query_id:
query_id = self.accelerator.gather_for_metrics(query_id)
nll = self.accelerator.gather_for_metrics(nll)
all_nlls.extend(nll.tolist())
if return_query_id:
all_query_ids.extend(query_id.tolist())
# print(outputs.loss)
# print(self.tokenizer.batch_decode(inputs["input_ids"]))
# labels = inputs["labels"]
# labels[labels == -100] = 0
# print(self.tokenizer.batch_decode(labels))
# print(all_nlls)
# input()
if return_query_id:
return all_query_ids, all_nlls
return all_nlls
@torch.no_grad()
def generate(self, dataloader, return_new_tokens_only=True, decode=True, **gen_kwargs):
self.model.eval()
all_query_ids = []
all_generations = []
for step, inputs in enumerate(tqdm(dataloader, desc='Generating')):
# move to gpu
inputs = self._move_to_device(inputs)
return_query_id = False
if 'query_id' in inputs:
query_id = inputs.pop("query_id") # batch_size
return_query_id = True
outputs = self.model.generate(**inputs, **gen_kwargs)
if return_new_tokens_only:
if self.model.config.is_encoder_decoder:
if "decoder_input_ids" in inputs:
start_idx = inputs["decoder_input_ids"].shape[1] + 1
else:
start_idx = 1
else:
start_idx = inputs["input_ids"].shape[1]
outputs = outputs[:, start_idx:]
if self.accelerator is not None:
if return_query_id:
query_id = self.accelerator.gather_for_metrics(query_id)
# must be contiguous
outputs = outputs.contiguous()
# FIXME: dim cannot be -1
outputs = self.accelerator.pad_across_processes(outputs, pad_index=self.tokenizer.pad_token_id, dim=1)
outputs = self.accelerator.gather_for_metrics(outputs)
outputs = outputs.tolist()
if decode:
outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
all_generations.extend(outputs)
if return_query_id:
query_id = query_id.tolist()
all_query_ids.extend(query_id)
if return_query_id:
return all_query_ids, all_generations
return all_generations
import torch
import math
import logging
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from accelerate import Accelerator
from dataclasses import dataclass
from typing import Optional, Tuple, List, Dict
from transformers.modeling_utils import ModelOutput
from .modeling_lm import LM
from ..utils.util import save_pickle, load_pickle
logger = logging.getLogger(__name__)
@dataclass
class SRLMOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
class SelfRetrievalLM(LM):
def __init__(self, retriever=None, context_window_size:int=2048, chunk_size:int=64, key_num:int=1, chunk_batch_size:int=2, add_key_continuation=False, retrieval_method="dense", order_method:str="sequential", integrate_method:str="concat", instruction:Dict=None, add_sep:Optional[List[int]]=None, debug_retrieval:bool=False, **kwds) -> None:
super().__init__(**kwds)
self.retriever = retriever
assert context_window_size % chunk_size == 0, f"Make sure the context_window_size ({context_window_size}) is divisible by chunk_size ({chunk_size})!"
self.context_window_size = context_window_size
self.chunk_size = chunk_size
self.chunk_batch_size = chunk_batch_size
self.key_num = key_num
self.add_sep = add_sep
self.add_key_continuation = add_key_continuation
self.retrieval_method = retrieval_method
self.order_method = order_method
self.integrate_method = integrate_method
self.debug_retrieval = debug_retrieval
self.instruction = instruction
if self.add_sep is not None:
logger.warning(f"will add {add_sep} after retrieved chunks!")
self.register_buffer("sep_token_ids", torch.tensor(add_sep), persistent=False)
def _get_retrieved_chunks(self, value_chunks, retrieved_indices):
"""Get the retrieved chunks and their continuations according to retrieved_indices."""
batch_size = value_chunks.shape[0]
chunk_batch_size = retrieved_indices.shape[0] // batch_size
# NOTE: by default, the retrieved_indices are sorted in descending order of relevance
if self.order_method == "sequential":
retrieved_indices = retrieved_indices.sort(-1)[0]
elif self.order_method == "relevance":
retrieved_indices = retrieved_indices.flip(dims=(-1,))
else:
raise NotImplementedError(f"Order strategy {self.order_method} not implemented!")
indices = retrieved_indices.repeat_interleave(2, -1) # batch_size * chunk_batch_size, 2k
indices[:, 1::2] += 1
indices = indices[..., None].expand(batch_size * chunk_batch_size, 2 * self.key_num, self.chunk_size) # batch_size * chunk_batch_size, 2k, chunk_size
# Slice out the retrieved chunk and its continuation from the corpus
retrieved_chunks = value_chunks.repeat_interleave(chunk_batch_size, dim=0).gather(dim=1, index=indices).view(indices.shape[0], self.key_num, 2 * self.chunk_size) # batch_size * chunk_batch_size, k, 2 * chunk_size
if self.add_sep is not None:
retrieved_chunks[..., -len(self.sep_token_ids):] = self.sep_token_ids
retrieved_chunks = retrieved_chunks.flatten(-2, -1)
return retrieved_chunks, retrieved_indices
def _get_retrieved_history(self, history, retrieved_indices):
"""Get the retrieved history according to retrieved_indices."""
batch_size = history.shape[0]
if retrieved_indices is None:
retrieved_history = np.array([""] * (batch_size))
else:
if isinstance(retrieved_indices, torch.Tensor):
retrieved_indices = retrieved_indices.cpu().numpy()
elif isinstance(retrieved_indices, np.ndarray):
pass
# NOTE: by default, the retrieved_indices are sorted in descending order of relevance
if self.order_method == "sequential":
retrieved_indices.sort(axis=-1)
elif self.order_method == "relevance":
retrieved_indices = retrieved_indices[...,::-1]
else:
raise NotImplementedError(f"Order strategy {self.order_method} not implemented!")
# slice out retrieved histories
retrieved_history = np.take_along_axis(history, indices=retrieved_indices, axis=-1)
# FIXME: there may be a better way to concatenate the strings row-wise
retrieved_history = np.array(["\n".join(x) for x in retrieved_history])
# The trailing \n is important
retrieved_history = np.char.add(retrieved_history, ["\n"] * batch_size)
return retrieved_history
def forward(self, **kwds):
if "history" in kwds:
return self.forward_with_history_retrieval(**kwds)
else:
return self.forward_with_chunk_retrieval(**kwds)
def forward_with_history_retrieval(self, query:np.ndarray, history:np.ndarray, answer:np.ndarray, history_mask:torch.Tensor):
batch_size = len(query)
query_with_prompt = np.char.add(["Speaker 1: "] * batch_size, query)
answer_with_prompt = np.char.add(["\nSpeaker 2: "] * batch_size, answer)
# get answer length
answer_length = self.tokenizer(answer.tolist(), padding=True, return_tensors="pt", return_token_type_ids=False, add_special_tokens=False)["attention_mask"].sum(-1, keepdim=True).to(self.device)
history_size = history.shape[1]
if self.retrieval_method == "no":
retrieved_indices = None
elif self.retrieval_method == "random":
retrieved_indices = np.random.randint(0, history_size, (batch_size, self.key_num))
elif self.retrieval_method == "recent":
valid_history_num = history_mask.cpu().numpy().sum(axis=-1)
valid_history_num = np.maximum(valid_history_num, self.key_num)
start_idx = valid_history_num - self.key_num
arange = np.arange(self.key_num)[None, :]
retrieved_indices = arange + start_idx # batch_size, key_num
elif self.retrieval_method == "dense":
# masking the padded history
history_mask = history_mask.to(self.device)
if self.instruction is not None:
queries = np.char.add([self.instruction["query"]] * batch_size, query)
keys = np.char.add([self.instruction["key"]] * batch_size, history.reshape(-1))
else:
queries = query
keys = history.reshape(-1)
history_embedding = self.retriever.encode(keys.tolist()).unflatten(0, (batch_size, history_size)) # B * N, D
context_embedding = self.retriever.encode(queries.tolist()) # B, D
scores = torch.einsum("bnd,bd->bn", history_embedding, context_embedding) # B, N
# mask padded histories
scores = scores.masked_fill(~history_mask, torch.finfo(scores.dtype).min)
_, retrieved_indices = scores.topk(k=self.key_num, dim=-1) # B, K
elif self.retrieval_method == "bm25":
retrieved_indices = np.zeros((batch_size, self.key_num), dtype=np.int32)
for batch_idx in range(batch_size):
bm25 = deepcopy(self.retriever)
bm25.index(history[batch_idx].tolist())
_, indice = bm25.search(query[batch_idx].tolist(), hits=self.key_num)
retrieved_indices[batch_idx] = indice[0]
elif self.retrieval_method == "oracle":
assert self.key_num == 1 and batch_size == 1, f"Retrieval_method 'oracle' is only available when k == 1 and batch_size == 1!"
min_loss = 1e3
min_k = 0
min_outputs = None
for hist_idx in range(history_size):
hist = history[:, hist_idx]
inputs = np.char.add(hist, ["\n"] * batch_size)
inputs = np.char.add(inputs, query_with_prompt)
inputs = np.char.add(inputs, answer_with_prompt)
inputs = self.tokenizer(inputs.tolist(), padding=True, truncation=True, max_length=self.context_window_size, return_tensors="pt", return_token_type_ids=False).to(self.device)
labels = inputs["input_ids"].clone()
arange = torch.arange(labels.shape[1] - 1, -1, -1, device=self.device).expand(labels.shape)
labels_mask = arange >= answer_length
inputs["labels"] = labels.masked_fill(labels_mask, -100)
outputs = self.model(**inputs)
loss = outputs.loss
# print(self.tokenizer.batch_decode(labels.masked_fill(labels_mask, self.tokenizer.pad_token_id)))
# print(inputs["input_ids"])
# print(inputs["labels"])
# save_pickle(inputs.to("cpu"), "debug.pkl")
# print(loss)
# input()
if loss < min_loss:
min_loss = loss
min_k = hist_idx
min_outputs = outputs
if self.debug_retrieval:
print(min_k)
print(f"***Query***\n{query[0].tolist()}")
print(f"***Answer***\n{answer[0].tolist()}")
print(f"***Retrieved***\n{history[0, min_k].tolist()}")
print(outputs.loss)
input()
return min_outputs
else:
raise NotImplementedError(f"Retrieval method {self.retrieval_method} not implemented!")
retrieved_history = self._get_retrieved_history(history, retrieved_indices)
# combine retrieved turns with the current context
inputs = np.char.add(retrieved_history, query_with_prompt)
inputs = np.char.add(inputs, answer_with_prompt)
inputs = self.tokenizer(inputs.tolist(), padding=True, truncation=True, max_length=self.context_window_size, return_tensors="pt", return_token_type_ids=False).to(self.device)
labels = inputs["input_ids"].clone()
arange = torch.arange(labels.shape[1] - 1, -1, -1, device=self.device).expand(labels.shape)
labels_mask = arange >= answer_length
inputs["labels"] = labels.masked_fill(labels_mask, -100)
# print(self.tokenizer.batch_decode(labels.masked_fill(labels_mask, self.tokenizer.pad_token_id)))
outputs = self.model(**inputs)
if self.debug_retrieval:
for i in range(batch_size):
print(f"***Query***\n{query[i].tolist()}")
print(f"***Answer***\n{answer[i].tolist()}")
print(f"***Retrieved***\n{retrieved_history[i].tolist()}")
print(outputs.loss)
input()
return outputs
def forward_with_chunk_retrieval(self, input_ids, attention_mask, labels):
batch_size, inputs_length = input_ids.shape
# in this case, all inputs are visible to the language model, thus no retrieval needed
if self.retrieval_method == "no":
input_ids = input_ids[:, -self.context_window_size:]
attention_mask = attention_mask[:, -self.context_window_size:]
labels = labels[:, -self.context_window_size:]
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
return outputs
# Pad inputs to multiple of chunk_size
num_chunks = math.ceil(inputs_length / self.chunk_size)
# NOTE: take the smaller of the two because some inputs may be shorter than context_window_size even after padding to a multiple of chunk_size
context_window_size = min(num_chunks * self.chunk_size, self.context_window_size)
if inputs_length % self.chunk_size != 0:
pad_length = num_chunks * self.chunk_size - inputs_length
input_ids = torch.cat([input_ids.new_zeros(batch_size, pad_length) + self.tokenizer.pad_token_id, input_ids], dim=-1)
attention_mask = torch.cat([attention_mask.new_zeros(batch_size, pad_length), attention_mask], dim=-1)
labels = torch.cat([labels.new_zeros(batch_size, pad_length) - 100, labels], dim=-1)
inputs_length = input_ids.shape[1]
# Find the start of the target. All retrieval operations start from the chunk preceding the target
is_valid = (labels != -100).float()
target_start_index = is_valid.argmax(-1)
assert (target_start_index == target_start_index[0]).all(), f"Make sure all targets in the batch start from the same token index!"
target_start_index = target_start_index[0].item()
assert target_start_index % self.chunk_size == 0, f"Make sure the target_length ({inputs_length} - {target_start_index} = {inputs_length - target_start_index}) is divisible by chunk_size ({self.chunk_size})!"
# Organize inputs
n_target_chunk = (inputs_length - target_start_index) // self.chunk_size
n_window_chunk = context_window_size // self.chunk_size
input_ids = input_ids.view(batch_size, -1, self.chunk_size)
labels = labels[:, -context_window_size:]
# print(labels)
# Split queries, keys and values
# the chunk preceding the target is the first query
query_chunks = input_ids[:, -n_target_chunk - 1: -1]
if self.integrate_method == "replace":
assert n_window_chunk >= (n_target_chunk + 1 + 2 * self.key_num), f"Make sure there are at least k * 2 + 1 + n_target_chunk = {self.key_num * 2 + 1 + n_target_chunk} chunks (found {context_window_size} / {self.chunk_size} = {n_window_chunk}) that can be replaced with retrieved contents!"
# these tokens will be directly concatenated with retrieved chunks
fixed_context = input_ids[:, -n_window_chunk + 2 * self.key_num:]
# besides previous chunks, the last chunk is also taken as keys because
# we only want to replace the context when there are more relevant ones
key_chunks = input_ids[:, :-n_window_chunk + 1]
if self.add_key_continuation:
continuation_chunks = input_ids[:, 1: -n_window_chunk + 2]
key_chunks = torch.cat([key_chunks, continuation_chunks], dim=-1)
# value chunks extend key chunks by one chunk because we may want to slice out the continuation chunk of the last key
value_chunks = input_ids[:, :-n_window_chunk + 2]
labels_mask_indices_offset = 0
elif self.integrate_method == "concat":
fixed_context = input_ids[:, -n_window_chunk:]
key_chunks = input_ids[:, :-n_window_chunk - 1]
if self.add_key_continuation:
continuation_chunks = input_ids[:, 1: -n_window_chunk]
key_chunks = torch.cat([key_chunks, continuation_chunks], dim=-1)
value_chunks = input_ids[:, :-n_window_chunk]
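# prepend -100 labels for the 2 * key_num retrieved chunks that will be concatenated in front of the context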
labels = torch.cat([labels.new_zeros(batch_size, 2 * self.key_num * self.chunk_size) - 100, labels], dim=-1)
labels_mask_indices_offset = 2 * self.key_num * self.chunk_size
else:
raise NotImplementedError(f"Integration strategy {self.integrate_method} not implemented!")
fixed_context = fixed_context.flatten(-2, -1)
# Prepare labels mask to be used in sub-batch
# Each query chunk will produce a sample, but only its next chunk should be evaluated
n_query_chunk = query_chunks.shape[1]
n_key_chunk = key_chunks.shape[1]
target_chunk_start_idx = n_window_chunk - n_target_chunk
# How many tokens in total until i-th chunk
bias = torch.arange(n_query_chunk, device=input_ids.device) * self.chunk_size
# Inside each chunk, the indices start from 0 to chunk_size - 1
# add target_chunk_start_idx because we want the labels computed
# only for target chunks
arange = torch.arange(self.chunk_size, device=input_ids.device) + target_chunk_start_idx * self.chunk_size
labels_mask_indices = bias[:, None] + arange[None, :]
labels_mask_indices = labels_mask_indices.view(n_query_chunk, self.chunk_size) + labels_mask_indices_offset
if self.retrieval_method == "dense":
# Encode queries and keys
queries = self.tokenizer.batch_decode(query_chunks.flatten(0, 1), skip_special_tokens=True)
keys = self.tokenizer.batch_decode(key_chunks.flatten(0, 1), skip_special_tokens=True)
if self.instruction is not None:
queries = [self.instruction["query"] + q for q in queries]
keys = [self.instruction["key"] + k for k in keys]
# The retriever automatically does truncation and padding
query_embeddings = self.retriever.encode(queries).view(batch_size, n_query_chunk, -1)
key_embeddings = self.retriever.encode(keys).view(batch_size, n_key_chunk, -1)
elif self.retrieval_method == "random":
pass
elif self.retrieval_method == "bm25":
bm25_indexes = []
for i in range(batch_size):
bm25 = deepcopy(self.retriever)
bm25.index(key_chunks[i].tolist())
bm25_indexes.append(bm25)
elif self.retrieval_method == "oracle":
assert self.key_num == 1 and batch_size == 1, f"Retrieval_method 'oracle' is only available when k == 1 and batch_size == 1!"
all_losses = 0
all_valid_tokens = 0
# enumerate all chunks
for i in range(n_query_chunk):
min_k = 0
min_loss = 1e3
min_retrieved_chunks = None
min_input_ids = None
sub_labels = labels # batch_size, n_window_chunk * self.chunk_size
sub_labels_mask = torch.ones_like(sub_labels, dtype=torch.bool)
sub_labels_mask.scatter_(dim=-1, index=labels_mask_indices[None, i].expand(batch_size, -1), value=False)
sub_labels = sub_labels.masked_fill(sub_labels_mask, -100)
# NOTE: the loss is averaged over valid tokens, thus we must store the valid token number for the final computation
valid_tokens = (sub_labels != -100).sum()
for k in range(n_key_chunk):
retrieved_chunks = value_chunks[:, k: k+2] # batch_size, 2, chunk_size
retrieved_chunks = retrieved_chunks.flatten(-2, -1)
if self.add_sep is not None:
retrieved_chunks[..., -len(self.sep_token_ids):] = self.sep_token_ids
sub_input_ids = torch.cat([retrieved_chunks, fixed_context], dim=-1)
sub_attention_mask = (sub_input_ids != self.tokenizer.pad_token_id).long()
outputs = self.model(input_ids=sub_input_ids, attention_mask=sub_attention_mask, labels=sub_labels)
if (sub_labels == -100).all():
# NOTE: in this case, the model will return nan. We correct its behavior by returning 0
loss = 0
else:
loss = outputs.loss
if loss < min_loss:
min_loss = loss
min_k = k
min_retrieved_chunks = retrieved_chunks
min_input_ids = sub_input_ids
if self.debug_retrieval:
print("-"*50)
context = fixed_context.unflatten(-1, (-1, self.chunk_size))
print(min_loss)
print(f"***Indices***\n{min_k}")
print(f"***Query***\n{repr(self.tokenizer.decode(query_chunks[0, i]))}")
print(f"***Target***\n{repr(self.tokenizer.decode(context[0, -n_target_chunk]))}")
print(f"***Retrieved***\n{repr(self.tokenizer.decode(min_retrieved_chunks[0]))}")
print(f"***Inputs***\n{repr(self.tokenizer.decode(min_input_ids[0]))}")
print(f"***Labels***\n{repr(self.tokenizer.decode(sub_labels.masked_fill(sub_labels_mask, self.tokenizer.pad_token_id)[0]))}")
print()
input()
all_losses += min_loss * valid_tokens
all_valid_tokens += valid_tokens
loss = all_losses / all_valid_tokens
return SRLMOutput(loss=loss)
else:
raise NotImplementedError(f"Retrieval method {self.retrieval_method} not implemented!")
# Compute language modeling loss for each target chunk in sub-batch
all_losses = None
all_valid_tokens = 0
for i in range(0, n_query_chunk, self.chunk_batch_size):
j = min(i + self.chunk_batch_size, n_query_chunk)
chunk_batch_size = j - i
if self.retrieval_method == "dense":
query_embedding = query_embeddings[:, i: j] # batch_size, chunk_batch_size, d_embed
rel_score = torch.einsum("bid,bjd->bij", query_embedding, key_embeddings) # batch_size, chunk_batch_size, n_key_chunk
retrieved_indices = rel_score.topk(self.key_num, dim=-1)[1].flatten(0, 1) # batch_size * chunk_batch_size, k
elif self.retrieval_method == "random":
retrieved_indices = torch.randint(0, n_key_chunk, (batch_size * chunk_batch_size, self.key_num), device=input_ids.device)
elif self.retrieval_method == "bm25":
retrieved_indices = torch.zeros(batch_size, chunk_batch_size, self.key_num, dtype=torch.long, device=value_chunks.device)
for batch_idx in range(batch_size):
query_chunk = query_chunks[batch_idx, i: j].tolist()
_, indice = bm25_indexes[batch_idx].search(query_chunk, hits=self.key_num)
retrieved_indices[batch_idx] = torch.from_numpy(indice)
retrieved_indices = retrieved_indices.flatten(0, 1)
# batch_size * chunk_batch_size, k * 2 * chunk_size
retrieved_chunks, retrieved_indices = self._get_retrieved_chunks(value_chunks, retrieved_indices)
# Each sub-batch has its own retrieved contexts
sub_input_ids = torch.cat([retrieved_chunks, fixed_context.repeat_interleave(chunk_batch_size, dim=0)], dim=-1)
sub_attention_mask = (sub_input_ids != self.tokenizer.pad_token_id).long()
# NOTE: here we do not add position_ids to keep the outputs exactly the same as the default behavior
# position_ids = attention_mask.cumsum(-1) - 1
# position_ids.masked_fill_(attention_mask == 0, 0)
# repeat labels across sub-batch
sub_labels = labels.repeat_interleave(chunk_batch_size, dim=0) # batch_size * chunk_batch_size, n_window_chunk * self.chunk_size
sub_labels_mask = torch.ones_like(sub_labels, dtype=torch.bool)
# NOTE: only compute loss for this sub-batch
sub_labels_mask.scatter_(dim=-1, index=labels_mask_indices[None, i: j].expand(batch_size, -1, -1).flatten(0, 1), value=False)
sub_labels = sub_labels.masked_fill(sub_labels_mask, -100)
if self.debug_retrieval:
print("-"*50)
context = fixed_context.unflatten(-1, (-1, self.chunk_size))
indices = retrieved_indices.unflatten(0, (batch_size, chunk_batch_size))
chunks = retrieved_chunks.view(batch_size, chunk_batch_size, self.key_num, 2 * self.chunk_size)
for r in range(chunk_batch_size):
idx = r + i
print(f"***Indices***\n{indices[0, r]}")
print(f"***Query***\n{repr(self.tokenizer.decode(query_chunks[0, idx]))}")
print(f"***Target***\n{repr(self.tokenizer.decode(context[0, -n_target_chunk + idx]))}")
print(f"***Retrieved***\n{repr(self.tokenizer.batch_decode(chunks[0, r]))}")
print(f"***Inputs***\n{repr(self.tokenizer.batch_decode(sub_input_ids))}")
print(f"***Labels***\n{repr(self.tokenizer.batch_decode(sub_labels.masked_fill(sub_labels_mask, self.tokenizer.pad_token_id)))}")
print()
input()
outputs = self.model(input_ids=sub_input_ids, attention_mask=sub_attention_mask, labels=sub_labels)
if (sub_labels == -100).all():
# NOTE: in this case, the model will return nan. We correct its behavior by returning 0
loss = 0
else:
loss = outputs.loss
# NOTE: the loss is averaged over valid tokens, thus we must store the valid token number for the final computation
valid_tokens = (sub_labels != -100).sum()
if all_losses is None:
all_losses = loss * valid_tokens
else:
all_losses += loss * valid_tokens
all_valid_tokens += valid_tokens
loss = all_losses / all_valid_tokens
return SRLMOutput(loss=loss)
@torch.no_grad()
def compute_perplexity(self, dataloader):
"""
Compute perplexity over long inputs
"""
self.model.eval()
all_nlls = []
for step, inputs in enumerate(tqdm(dataloader, desc='Computing Perplexity')):
# if step > 5:
# break
# move to gpu
inputs = self._move_to_device(inputs)
outputs = self(**inputs)
nll = outputs.loss
if self.accelerator is not None:
# mean nlls from all processes
nll = self.accelerator.gather_for_metrics(nll).mean()
all_nlls.append(nll.tolist())
all_nlls = sum(all_nlls) / len(all_nlls)
perplexity = math.exp(all_nlls)
return perplexity
# TODO
# def generate(self, input_ids, attention_mask, **kwds):
# """Generate by chunks"""
# generation_config = self.model.generation_config
# assert generation_config.max_new_tokens is not None, f"Make sure the max_new_tokens parameter in model's generation_config is not None!"
# global_max_new_tokens = generation_config.max_new_tokens
# n_generate_chunk = global_max_new_tokens // self.chunk_size
# batch_size = input_ids.shape[0]
# assert input_ids.shape[1] % self.chunk_size == 0, f"Make sure the generation input length {input_ids.shape[1]} is divisible by chunk size!"
# # 1. Encode
# n_window_chunk = input_ids.shape[1] // self.chunk_size
# # concatenate extra context
# if prev_input_ids is not None:
# assert prev_input_ids.shape[1] % self.chunk_size == 0, f"Make sure the prev input length {prev_input_ids.shape[1]} is divisible by chunk size!"
# input_ids = torch.cat([prev_input_ids, input_ids], dim=-1)
# input_ids = input_ids.view(batch_size, -1, self.chunk_size)
# key_chunks = input_ids[:, :-n_window_chunk + 2 * self.key_num - 1]
# value_chunks = input_ids[:, :-n_window_chunk + 2 * self.key_num]
# fixed_context = input_ids[:, -n_window_chunk + 2 * self.key_num:] # batch_size, n_window_chunk - 2 * k, chunk_size
# n_key_chunk = key_chunks.shape[1]
# keys = self.tokenizer.batch_decode(key_chunks.flatten(0, 1), skip_special_tokens=True)
# key_embeddings = self.encoder(keys).view(batch_size, n_key_chunk, -1)
# # 2. Generate by chunk
# for step in range(n_generate_chunk):
# query_chunk = fixed_context[:, -1:] # batch_size, 1, chunk_size
# query = self.tokenizer.batch_decode(query_chunk.squeeze(1), skip_special_tokens=True)
# query_embedding = self.encoder(query).view(batch_size, 1, -1)
# # Slice out the retrieved chunk and its continuation from the corpus
# retrieved_chunks, retrieved_indices = self._dense_retrieval(query_embedding, key_embeddings, value_chunks)
# if self.debug_retrieval:
# print("-"*50)
# indices = retrieved_indices.unflatten(0, (batch_size, 1))
# chunks = retrieved_chunks.view(batch_size, 1, self.key_num, 2 * self.chunk_size)
# print(f"***Indices***\n{indices[0, 0]}")
# print(f"***Query***\n{repr(self.tokenizer.decode(query_chunk[0, 0]))}")
# print(f"***Retrieved***\n{repr(self.tokenizer.batch_decode(chunks[0, 0]))}")
# print()
# input()
# step_input_ids = torch.cat([retrieved_chunks, fixed_context.flatten(-2, -1)], dim=-1)
# step_attention_mask = (step_input_ids != self.tokenizer.pad_token_id).long()
# # generate chunk_size tokens once
# kwds["max_new_tokens"] = self.chunk_size
# outputs = self.model.generate(input_ids=step_input_ids, attention_mask=step_attention_mask, **kwds) # batch_size, chunk_size
# # slice out the newly-generated tokens
# outputs = outputs[:, step_input_ids.shape[1]:] # batch_size, chunk_size
# assert outputs.shape[-1] == self.chunk_size
# fixed_context = torch.cat([fixed_context, outputs.unsqueeze(1)], dim=1) # batch_size, -, chunk_size
# # 3. Finalize. Set all tokens after the first eos token to pad token
# generated_tokens = torch.cat([input_ids[:, -n_window_chunk: -n_window_chunk + 2 * self.key_num:], fixed_context], dim=1).flatten(-2, -1) # batch_size, (n_window_chunk + n_generate_chunk) * chunk_size
# # is_eos = (generated_tokens == self.tokenizer.eos_token_id).float()
# # has_eos = (generated_tokens == self.tokenizer.eos_token_id).any(-1)
# # eos_start_index = is_eos.argmax(-1)
# # print(generated_tokens)
# # print(eos_start_index, has_eos)
# # for i, idx in enumerate(eos_start_index):
# # if has_eos[i]:
# # generated_tokens[i, idx + 1:] = self.tokenizer.pad_token_id
# return generated_tokens
from .args import RetrievalArgs, RankerArgs
from .modeling_dense import DenseRetriever
from .modeling_bm25 import BM25Retriever, NaiveBM25Retriever
from .modeling_unified import Retriever
from .modeling_ranker import CrossEncoder
from .metrics import RetrievalMetric
from .data import RetrievalDataset, RetrievalDataCollator, TASK_CONFIG
import os
from dataclasses import dataclass, field
from transformers.training_args import TrainingArguments
from typing import Optional, List, Union
@dataclass
class BaseArgs:
model_cache_dir: Optional[str] = field(
default=None,
metadata={'help': 'Default path to save language models.'}
)
dataset_cache_dir: Optional[str] = field(
default=None,
metadata={'help': 'Default path to save huggingface datasets.'}
)
data_root: str = field(
default="/data/llm-embedder",
        metadata={'help': 'The base directory storing all data used for training and evaluation. If specified, make sure all train_data, eval_data, and corpus are paths relative to data_root!'},
)
train_data: Optional[List[str]] = field(
default=None,
metadata={'help': 'Training json file or glob to match a list of files.'},
)
eval_data: Optional[str] = field(
default=None,
metadata={'help': 'Evaluation json file.'},
)
corpus: str = field(
default=None,
metadata={'help': 'Corpus jsonl file.'}
)
key_template: str = field(
default="{title} {text}",
metadata={'help': 'How to concatenate columns in the corpus to form one key?'}
)
metrics: List[str] = field(
default_factory=lambda: ["mrr", "recall", "ndcg"],
metadata={'help': 'List of metrics'}
)
cutoffs: List[int] = field(
default_factory=lambda: [1, 5, 10, 100],
metadata={'help': 'Cutoffs to evaluate retrieval metrics.'}
)
filter_answers: bool = field(
default=False,
metadata={'help': 'Remove negatives that contain the desired answer when collating negatives?'}
)
max_neg_num: int = field(
default=100,
metadata={'help': 'Maximum negative number to mine.'}
)
load_result: bool = field(
default=False,
metadata={'help': 'Load retrieval results directly?'}
)
save_result: bool = field(
default=True,
metadata={'help': 'Save retrieval results?'}
)
save_name: Optional[str] = field(
default=None,
metadata={'help': 'Name suffix of the json file when saving the collated retrieval results.'}
)
save_to_output: bool = field(
default=False,
metadata={'help': 'Save the result/key/negative to output_dir? If not true, they will be saved next to the eval_data.'}
)
def resolve_path(self, path):
"""Resolve any path starting with 'llm-embedder:' to relative path against data_root."""
pattern = "llm-embedder:"
# resolve relative data paths when necessary
if isinstance(path, list):
for i, x in enumerate(path):
if x.startswith(pattern):
path[i] = os.path.join(self.data_root, x.replace(pattern, ""))
else:
if path.startswith(pattern):
path = os.path.join(self.data_root, path.replace(pattern, ""))
return path
def __post_init__(self):
if self.train_data is not None:
self.train_data = self.resolve_path(self.train_data)
if self.eval_data is not None:
self.eval_data = self.resolve_path(self.eval_data)
if self.corpus is not None:
self.corpus = self.resolve_path(self.corpus)
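    # Illustrative example (hypothetical file name, for reference only): with the default
    # data_root="/data/llm-embedder",
    #     BaseArgs(corpus="llm-embedder:qa/msmarco/corpus.jsonl").corpus
    #     # -> "/data/llm-embedder/qa/msmarco/corpus.jsonl"
    # while paths without the "llm-embedder:" prefix are left untouched.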
@dataclass
class DenseRetrievalArgs(BaseArgs):
query_encoder: str = field(
default="BAAI/bge-base-en",
metadata={'help': 'Path to encoder model or model identifier from huggingface.co/models.'}
)
key_encoder: str = field(
default="BAAI/bge-base-en",
metadata={'help': 'Path to encoder model or model identifier from huggingface.co/models.'}
)
add_instruction: bool = field(
default=True,
metadata={'help': 'Add instruction for each task?'}
)
version: str = field(
default="bge",
metadata={'help': 'Version for configs.'}
)
query_max_length: int = field(
default=256,
metadata={'help': 'Max query length.'}
)
key_max_length: int = field(
default=256,
metadata={'help': 'Max key length.'}
)
truncation_side: str = field(
default="right",
metadata={'help': 'Which side to truncate?'}
)
pooling_method: List[str] = field(
default_factory=lambda: ["cls"],
metadata={'help': 'Pooling methods to aggregate token embeddings for a sequence embedding. {cls, mean, dense, decoder}'}
)
tie_encoders: bool = field(
default=True,
        metadata={'help': 'Tie query encoder and key encoder? If True, the query_encoder is used for both.'}
)
dense_metric: str = field(
default="cos",
metadata={'help': 'What type of metric for dense retrieval? ip, l2, or cos.'}
)
faiss_index_factory: str = field(
default="Flat",
metadata={'help': 'Index factory string for faiss.'}
)
hits: int = field(
default=200,
metadata={'help': 'How many keys to retrieve?'}
)
batch_size: int = field(
default=1000,
metadata={'help': 'Batch size for indexing and retrieval.'}
)
load_encode: bool = field(
default=False,
metadata={'help': 'Load cached embeddings?'}
)
save_encode: bool = field(
default=False,
metadata={'help': 'Save embeddings?'}
)
load_index: bool = field(
default=False,
metadata={'help': 'Load cached index?'}
)
save_index: bool = field(
default=False,
metadata={'help': 'Save index?'}
)
embedding_name: str = field(
default="embeddings",
        metadata={'help': 'Name used when saving embeddings (also used as the faiss index name).'}
)
dtype: str = field(
default="fp16",
metadata={'help': 'Data type for retriever.'}
)
cpu: bool = field(
default=False,
metadata={'help': 'Use cpu?'}
)
@dataclass
class BM25Args(BaseArgs):
anserini_dir: str = field(
default='/share/peitian/Apps/anserini',
metadata={'help': 'Anserini installation directory.'}
)
k1: float = field(
default=0.82,
metadata={'help': 'BM25 k1.'}
)
b: float = field(
default=0.68,
metadata={'help': 'BM25 b.'}
)
storeDocvectors: bool = field(
default=False,
metadata={'help': 'Store document vector? Useful when you want to inspect the word-level statistics (tf-idf) after index construction.'}
)
hits: int = field(
default=200,
metadata={'help': 'How many keys to retrieve?'}
)
language: str = field(
default="en",
metadata={'help': 'Language.'}
)
threads: int = field(
default=32,
metadata={'help': 'Indexing/Searching thread number.'}
)
load_index: bool = field(
default=False,
metadata={'help': 'Load index?'}
)
load_collection: bool = field(
default=False,
metadata={'help': 'Load collection?'}
)
@dataclass
class RankerArgs(BaseArgs):
ranker: str = field(
default="BAAI/bge-base-en",
metadata={'help': 'Ranker name or path.'}
)
ranker_method: str = field(
default="cross-encoder",
metadata={'help': 'What kind of ranker to use? {cross: cross encoder}'}
)
dtype: str = field(
default="fp16",
metadata={'help': 'Data type for ranker.'}
)
query_max_length: int = field(
default=256,
metadata={'help': 'Max query length.'}
)
key_max_length: int = field(
default=256,
metadata={'help': 'Max key length.'}
)
add_instruction: bool = field(
default=False,
metadata={'help': 'Add instruction for each task?'}
)
version: str = field(
default="bge",
metadata={'help': 'Version for configs.'}
)
hits: Optional[int] = field(
default=None,
metadata={'help': 'How many top reranked keys to keep?'}
)
batch_size: int = field(
default=4,
metadata={'help': 'Batch size for indexing and retrieval.'}
)
cpu: bool = field(
default=False,
metadata={'help': 'Use cpu?'}
)
@dataclass
class RetrievalArgs(DenseRetrievalArgs, BM25Args):
retrieval_method: str = field(
default="dense",
metadata={'help': 'How to retrieve? {dense, bm25, random, no}'}
)
@dataclass
class RetrievalTrainingArgs(TrainingArguments):
output_dir: str = field(
default='data/outputs/',
metadata={'help': 'The output directory where the model predictions and checkpoints will be written.'},
)
eval_method: str = field(
default="retrieval",
metadata={'help': 'How to evaluate?'},
)
use_train_config: bool = field(
default=False,
metadata={'help': 'Use training config from TASK_CONFIG to override arguments?'}
)
inbatch_same_dataset: Optional[str] = field(
default=None,
        metadata={'help': 'Whether and how to group samples from the same dataset/task in each batch (across devices). {random, epoch, epoch-random, epoch-static}'}
)
negative_cross_device: bool = field(
default=True,
metadata={'help': 'Gather negatives from all devices when distributed training?'}
)
cos_temperature: float = field(
default=0.01,
metadata={'help': 'Temperature used for cosine dense metric.'}
)
    teacher_temperature: float = field(
        default=1.,
        metadata={'help': 'Temperature applied to the teacher scores in distillation.'}
    )
    student_temperature: float = field(
        default=1.,
        metadata={'help': 'Temperature applied to the student scores in distillation.'}
    )
contrastive_weight: float = field(
default=0.2,
metadata={'help': 'Weight for contrastive loss.'}
)
distill_weight: float = field(
default=1.0,
metadata={'help': 'Weight for distillation loss.'}
)
stable_distill: bool = field(
default=False,
        metadata={'help': 'Sort keys by teacher score in descending order for stable distillation?'}
)
max_sample_num: Optional[int] = field(
default=None,
metadata={'help': 'How many samples at most for training dataset?'}
)
train_group_size: int = field(
default=8,
        metadata={'help': 'How many keys per query in a training group (1 positive + train_group_size - 1 negatives)?'}
)
select_positive: str = field(
default="first",
metadata={'help': 'How to select the positive key from a set of positives?'}
)
select_negative: str = field(
default="random",
metadata={'help': 'How to select the negative keys from a set of negatives?'}
)
teacher_scores_margin: Optional[float] = field(
default=None,
metadata={'help': 'Minimum margin in teacher_scores. The samples with smaller margin will be removed from training.'}
)
teacher_scores_min: Optional[float] = field(
default=None,
metadata={'help': 'Minimum teacher_scores. The samples whose biggest score is lower than this will be removed from training.'}
)
per_device_train_batch_size: int = field(
default=16,
metadata={'help': 'Train batch size'},
)
learning_rate: float = field(
default=5e-6,
metadata={'help': 'Learning rate.'},
)
warmup_ratio: float = field(
default=0.1,
metadata={'help': 'Warmup ratio for linear scheduler.'},
)
weight_decay: float = field(
default=0.01,
metadata={'help': 'Weight decay in AdamW.'},
)
fp16: bool = field(
default=True,
metadata={'help': 'Use fp16 training?'}
)
ddp_find_unused_parameters: bool = field(
default=False,
metadata={'help': 'Find unused parameters in torch DDP?'},
)
remove_unused_columns: bool = field(
default=False,
metadata={'help': 'Remove columns that are not registered in the forward function of the model?'},
)
evaluation_strategy: str = field(
default='steps',
metadata={'help': 'Evaluation strategy'},
)
save_steps: int = field(
default=2000,
metadata={'help': 'Saving frequency.'},
)
logging_steps: int = field(
default=100,
metadata={'help': 'Logging frequency according to logging strategy.'},
)
early_exit_steps: Optional[int] = field(
default=None,
metadata={'help': 'After how many steps to exit training loop.'},
)
report_to: str = field(
default="none", metadata={"help": "The list of integrations to report the results and logs to."}
)
log_path: str = field(
default="data/results/performance.log",
        metadata={'help': 'Path of the log file to record evaluation results.'}
)
    # NOTE: newer versions of transformers forbid modifying the config after initialization; we bypass this restriction
def __setattr__(self, name, value):
super(TrainingArguments, self).__setattr__(name, value)
def __post_init__(self):
super().__post_init__()
# for convenience
# self.eval_steps = self.save_steps
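# Illustrative sketch (hypothetical values, for reference only): the arguments can also be
# constructed programmatically instead of being parsed from the command line, e.g.
#     training_args = RetrievalTrainingArgs(output_dir="data/outputs/test", inbatch_same_dataset="epoch")
#     training_args.cos_temperature   # -> 0.01 (default)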
import math
import torch
import random
import datasets
import numpy as np
from glob import glob
from string import Formatter
from typing import Optional, Tuple, Union, List, Callable, Dict, Any, Mapping
from copy import deepcopy
from dataclasses import dataclass
from collections import defaultdict
from transformers.tokenization_utils import PreTrainedTokenizer
from ..utils.util import get_max_length_in_nested_lists, pad_nested_lists, split_file_dir_name_ext, DatasetProcessFn
class RetrievalDataset:
    @staticmethod
    def get_train_process_fn(train_group_size=8, select_positive="first", select_negative="random", teacher_scores_margin=None, teacher_scores_min=None, stable_distill=False, instruction=None):
@DatasetProcessFn()
def _process(query:str, task:str, pos:List[str]=None, neg:List[str]=None, history:List[str]=None, teacher_scores:Optional[List[float]]=None, **kwds):
output = {}
keys = []
if history is not None:
pos = []
neg = history
# filter based on teacher scores
if teacher_scores is not None:
assert len(teacher_scores) == len(pos) + len(neg), f"Found incompatible teacher_score size ({len(teacher_scores)}) and positive size ({len(pos)}) negative size ({len(neg)})"
if teacher_scores_min is not None:
max_score = max(teacher_scores)
if max_score < teacher_scores_min:
return None
if teacher_scores_margin is not None:
max_score = max(teacher_scores)
min_score = min(teacher_scores)
if max_score - min_score < teacher_scores_margin:
return None
pos_num = len(pos)
if select_positive == "random":
assert pos_num > 0, f"Select positive strategy 'random' is only available when there is a given positive!"
pos_idx = random.choice(range(pos_num))
pos = pos[pos_idx]
elif teacher_scores is not None and select_positive == "teacher":
pos_idx = max(enumerate(teacher_scores), key=lambda x: x[1])[0]
if pos_idx < pos_num:
pos = pos[pos_idx]
else:
# pos is selected from neg, thus we remove it from neg
pos = neg.pop(pos_idx - pos_num)
elif teacher_scores is not None and select_positive == "teacher-pos":
assert pos_num > 0, f"Select positive strategy 'teacher-pos' is only available when there are teacher_scores and positives!"
pos_scores = teacher_scores[:pos_num]
pos_idx = max(enumerate(pos_scores), key=lambda x: x[1])[0]
pos = pos[pos_idx]
else:
# NOTE: default to select the first positive
assert pos_num > 0, f"Select positive strategy 'first' is only available when there is a given positive!"
pos_idx = 0
pos = pos[0]
if teacher_scores is not None:
if pos_idx >= pos_num:
# only makes sense when select_positive==teacher
# remove the selected score
pos_score = teacher_scores.pop(pos_idx)
else:
pos_score = teacher_scores[pos_idx]
# remove teacher scores of unused positives
neg_scores = teacher_scores[pos_num:]
return_teacher_scores = [pos_score]
keys.append(pos)
if len(neg) == 0:
return None
elif len(neg) < train_group_size - 1:
num = math.ceil((train_group_size - 1) / len(neg))
neg = neg * num
if teacher_scores is not None:
neg_scores = neg_scores * num
if teacher_scores is not None and select_negative == "teacher-":
neg_indices = [i for i, _ in sorted(enumerate(neg_scores), key=lambda x: x[1])[:train_group_size - 1]]
elif teacher_scores is not None and select_negative == "teacher+":
neg_indices = [i for i, _ in sorted(enumerate(neg_scores), key=lambda x: x[1], reverse=True)[:train_group_size - 1]]
elif select_negative == "first":
neg_indices = list(range(len(neg)))[:train_group_size - 1]
else:
# NOTE: default to select random negatives
neg_indices = random.sample(range(len(neg)), train_group_size - 1)
for neg_idx in neg_indices:
keys.append(neg[neg_idx])
if teacher_scores is not None:
return_teacher_scores.append(neg_scores[neg_idx])
if instruction is not None:
query = instruction["query"] + query
keys = [instruction["key"] + key for key in keys]
output = {
"query": query,
"key": keys,
"task": task,
}
if teacher_scores is not None:
output["teacher_scores"] = return_teacher_scores
if stable_distill:
# when using stable_distill, we must sort teacher_scores descendingly
neg_score = output["teacher_scores"][1:]
neg = output["key"][1:]
pairs = sorted(list(zip(neg, neg_score)), key=lambda x: x[1], reverse=True)
neg = [pair[0] for pair in pairs]
neg_score = [pair[1] for pair in pairs]
output["key"][1:] = neg
output["teacher_scores"][1:] = neg_score
return output
return _process
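    # Illustrative sketch (hypothetical record, for reference only) of what _process expects and
    # returns for a single training sample:
    #     input:  {"query": "q", "task": "qa", "pos": ["p1"], "neg": ["n1", "n2", ...],
    #              "teacher_scores": [score(p1), score(n1), score(n2), ...]}
    #     output: {"query": instruction["query"] + "q",
    #              "key": [instruction["key"] + selected positive, then train_group_size - 1 negatives],
    #              "task": "qa", "teacher_scores": [scores aligned with "key"]}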
    @staticmethod
    def prepare_train_dataset(data_file=None, cache_dir=None, config=None, train_group_size=8, select_positive="first", select_negative="random", max_sample_num=None, teacher_scores_margin=None, teacher_scores_min=None, stable_distill=False, add_instruction=False, instruction=None, use_train_config=False):
if data_file is None:
return None, None
if isinstance(data_file, str):
if "*" in data_file:
data_file = glob(data_file)
else:
data_file = [data_file]
train_datasets = []
offset = 0
dataset_indices_range = {}
dataset_dup = defaultdict(int)
for path in data_file:
temp_dataset = datasets.load_dataset('json', data_files=path, split='train', cache_dir=cache_dir)
task = temp_dataset[0]["task"]
directory, _, _ = split_file_dir_name_ext(path)
dataset_name = directory.name
if add_instruction:
instruction = config["instruction"][task]
if use_train_config:
train_config = config["training"][task]
select_positive = train_config["select_positive"]
select_negative = train_config["select_negative"]
max_sample_num = train_config["max_sample_num"]
teacher_scores_margin = train_config["teacher_scores_margin"]
teacher_scores_min = train_config["teacher_scores_min"]
stable_distill = train_config["stable_distill"]
process_fn = RetrievalDataset.get_train_process_fn(
train_group_size,
select_positive=select_positive,
select_negative=select_negative,
teacher_scores_margin=teacher_scores_margin,
teacher_scores_min=teacher_scores_min,
stable_distill=stable_distill,
instruction=instruction
)
# map to filter
temp_dataset = temp_dataset.map(process_fn, batched=True, num_proc=32, remove_columns=temp_dataset.column_names)
# limit sample number
if max_sample_num is not None and len(temp_dataset) > max_sample_num:
temp_dataset = temp_dataset.train_test_split(max_sample_num, shuffle=False)["test"]
train_datasets.append(temp_dataset)
if dataset_name in dataset_indices_range:
                # NOTE: we allow duplicated datasets to balance the proportions of different datasets
dataset_dup[dataset_name] += 1
dataset_indices_range[f"{dataset_name}_{dataset_dup[dataset_name]}"] = (offset, offset + len(temp_dataset))
else:
dataset_indices_range[dataset_name] = (offset, offset + len(temp_dataset))
offset += len(temp_dataset)
dataset = datasets.concatenate_datasets(train_datasets)
return dataset, dataset_indices_range
@staticmethod
def prepare_eval_dataset(data_file=None, cache_dir=None, instruction=None, eval_method="retrieve"):
if data_file is None:
return None
@DatasetProcessFn()
def _process(query:str, query_id:Optional[int]=None, key:Optional[List[str]]=None, key_index: Optional[List[int]]=None, pos: Optional[List[Union[int, str]]]=None, neg: Optional[List[str]]=None, pos_index:Optional[List[int]]=None, neg_index: Optional[List[int]]=None, _index=None, **kwds):
if instruction is not None:
query = instruction["query"] + query
if query_id is None:
assert _index is not None
query_id = _index
output = {
"query": query,
"query_id": query_id,
"task": task,
}
if eval_method == "rerank":
# if there is a column named key, it must be the candidates to rerank
if key is not None:
if key_index is not None:
output["key_index"] = key_index
else:
# NOTE: there must be key_index when reranking
output["key_index"] = list(range(len(key)))
# otherwise, default
elif pos is not None and neg is not None:
key = pos + neg
if pos_index is not None:
output["key_index"] = pos_index + neg_index
else:
# NOTE: there must be key_index when reranking
output["key_index"] = list(range(len(key)))
else:
raise ValueError(f"Expected either pos/neg or key in the file {data_file}!")
if instruction is not None:
output["key"] = [instruction["key"] + k for k in key]
else:
output["key"] = key
return output
dataset = datasets.load_dataset('json', data_files=data_file, split='train', cache_dir=cache_dir)
if "task" in dataset:
task = dataset[0]["task"]
else:
task = "nan"
dataset = dataset.map(_process, num_proc=32, batched=True, remove_columns=dataset.column_names, with_indices=True)
return dataset
@staticmethod
def prepare_corpus(data_file, key_template:str, cache_dir=None, instruction=None):
"""Concatenate desired keys by key_template"""
if data_file is None:
return None
keys = Formatter().parse(key_template)
field_names = [x[1] for x in keys if x[1] is not None]
@DatasetProcessFn()
def _process(**kwds):
inputs = {name: kwds[name] for name in field_names}
content = key_template.format(**inputs)
if instruction is not None:
content = instruction["key"] + content
return {'content': content}
dataset = datasets.load_dataset('json', data_files=data_file, split="train", cache_dir=cache_dir)
dataset.set_transform(_process)
return dataset
class SameDatasetTrainDataset(torch.utils.data.Dataset):
"""Dataset to yield a batch of data at one time. All samples in the same batch comes from the same task.
Args:
organize_method:
random:
epoch:
epoch-random:
epoch-static
"""
def __init__(self, dataset, dataset_indices_range, batch_size, seed, organize_method, process_index=0, num_processes=1):
self.dataset = dataset
self.batch_size = batch_size
self.organize_method = organize_method
self.process_index = process_index
self.num_processes = num_processes
self.dataset_indices_range = dataset_indices_range
self.deterministic_generator = np.random.default_rng(seed)
        # different devices must sample different data batches
self.nondeterministic_generator = np.random.default_rng(seed + process_index)
# shuffle the indices
if "random" in self.organize_method:
self.sample_range = [np.arange(*x) for x in self.dataset_indices_range.values()]
for x in self.sample_range:
                # NOTE: we must make sure every process uses the same shuffling order
self.deterministic_generator.shuffle(x)
def create_epoch(self):
epoch = []
for k, x in self.dataset_indices_range.items():
dataset_range = np.arange(*x)
            # NOTE: we must make sure every process uses the same shuffling order
self.deterministic_generator.shuffle(dataset_range)
            num_batches, remainder = divmod(len(dataset_range), self.batch_size * self.num_processes)
            # Truncate
            if remainder != 0:
dataset_range = dataset_range[:num_batches * self.batch_size * self.num_processes]
batches = dataset_range.reshape(num_batches, self.batch_size * self.num_processes).tolist()
for i in range(len(batches)):
batches[i] = (k, batches[i])
epoch.extend(batches)
# shuffle among datasets, also make sure different processes share the same shuffling results
self.deterministic_generator.shuffle(epoch)
self.epoch = epoch
self.step = 0
self.steps_per_epoch = len(epoch)
def __getitem__(self, idx):
if self.organize_method == "random":
sample_prob = [len(x) / len(self.dataset) for x in self.sample_range]
dataset_name = self.deterministic_generator.choice(range(len(self.sample_range)), size=1, p=sample_prob)[0]
sample_range = self.sample_range[dataset_name]
batch_indices = self.nondeterministic_generator.choice(sample_range, size=self.batch_size, replace=False)
batch_data = self.dataset[batch_indices.tolist()]
elif self.organize_method == "epoch":
if not hasattr(self, "epoch") or self.step > self.steps_per_epoch - 1:
self.create_epoch()
dataset_name, batch_indices = self.epoch[self.step]
batch_indices = batch_indices[self.process_index * self.batch_size: (self.process_index + 1) * self.batch_size]
batch_data = self.dataset[batch_indices]
self.step += 1
elif self.organize_method == "epoch-static":
if not hasattr(self, "epoch"):
# the data within each batch is static once created
self.create_epoch()
if self.step > self.steps_per_epoch - 1:
self.deterministic_generator.shuffle(self.epoch)
self.step = 0
dataset_name, batch_indices = self.epoch[self.step]
batch_indices = batch_indices[self.process_index * self.batch_size: (self.process_index + 1) * self.batch_size]
batch_data = self.dataset[batch_indices]
self.step += 1
elif self.organize_method == "epoch-random":
sample_scope = [len(x) for x in self.sample_range]
sample_prob = [x / sum(sample_scope) for x in sample_scope]
dataset_name = self.deterministic_generator.choice(range(len(self.sample_range)), size=1, p=sample_prob)[0]
sample_range = self.sample_range[dataset_name]
# sequential sample (the indices are already shuffled)
batch_indices = sample_range[self.process_index * self.batch_size: (self.process_index + 1) * self.batch_size]
batch_data = self.dataset[batch_indices.tolist()]
# update indices
remaining_indices = sample_range[self.num_processes * self.batch_size:]
if len(remaining_indices) < self.batch_size * self.num_processes:
remaining_indices = np.array([])
self.sample_range[dataset_name] = remaining_indices
# restore all indices if they are all sampled
if all(len(x) == 0 for x in self.sample_range):
self.sample_range = [np.arange(*x) for x in self.dataset_indices_range.values()]
for x in self.sample_range:
self.deterministic_generator.shuffle(x)
else:
raise NotImplementedError(f"Organize method {self.organize_method} is not implemented for SameTaskTrainDataset!")
return batch_data
def __len__(self):
return len(self.dataset) // self.batch_size
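    # Illustrative note (assumption, not stated in the original code): each __getitem__ call already
    # returns a complete batch for the current process, so the surrounding DataLoader is expected to
    # use a per-device batch size of 1 together with RetrievalDataCollator(inbatch_same_dataset=True),
    # which simply unwraps the pre-grouped element.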
@dataclass
class RetrievalDataCollator:
"""
"""
tokenizer: PreTrainedTokenizer = None
query_max_length: int = 256
key_max_length: int = 256
inbatch_same_dataset: bool = False
cross: bool = False
def __call__(self, batch_elem):
first_elem = batch_elem[0]
return_batch = {}
for k, v in first_elem.items():
if self.inbatch_same_dataset:
# here the data have already been grouped
batch_value = batch_elem[0][k]
else:
batch_value = [elem[k] for elem in batch_elem]
# collate training/evaluating
if k == "query":
query = batch_value
# NOTE: we do not need the individual query and key when requiring cross data
if self.cross:
continue
batch_value = self.tokenizer(
batch_value,
padding=True,
truncation=True,
max_length=self.query_max_length,
return_tensors="pt",
)
elif k == "key":
# in case the keys are of different sizes for different queries when reranking
max_length = get_max_length_in_nested_lists(batch_value)
batch_value, key_mask = pad_nested_lists(batch_value, max_length, "", "right")
batch_value = sum(batch_value, [])
key = batch_value
# key_mask assigns 1 to valid keys and 0 to padded keys
return_batch["key_mask"] = torch.tensor(key_mask)
# NOTE: we do not need the individual query and key when requiring cross data
if self.cross:
continue
batch_value = self.tokenizer(
batch_value,
padding=True,
truncation=True,
max_length=self.key_max_length,
return_tensors="pt",
)
elif k == "key_index":
max_length = get_max_length_in_nested_lists(batch_value)
batch_value, _ = pad_nested_lists(batch_value, max_length, -1, "right")
batch_value = torch.tensor(batch_value)
elif k == "content":
# collate corpus
batch_value = self.tokenizer(
batch_value,
padding=True,
truncation=True,
max_length=self.key_max_length,
return_tensors="pt",
)
elif k == "task":
assert all(v == batch_value[0] for v in batch_value), f"Make sure all samples are of the same task in a batch!"
batch_value = batch_value[0]
elif all(v is None for v in batch_value):
# in case that some data have teacher_scores but others do not
batch_value = None
else:
batch_value = torch.tensor(batch_value)
return_batch[k] = batch_value
if self.cross:
query_num = len(query)
key_num = len(key)
assert key_num % query_num == 0
group_size = key_num // query_num
new_query = []
for i in range(key_num):
new_query.append(query[i // group_size])
return_batch["cross"] = self.tokenizer(
new_query, key,
padding=True,
truncation=True,
max_length=self.key_max_length + self.query_max_length,
return_tensors="pt"
)
return_batch["batch_size"] = len(query)
return return_batch
TASK_CONFIG = {
"llm-embedder": {
"instruction": {
"qa": {
"query": "Represent this query for retrieving relevant documents: ",
"key": "Represent this document for retrieval: ",
},
"convsearch": {
"query": "Encode this query and context for searching relevant passages: ",
"key": "Encode this passage for retrieval: ",
},
"chat": {
"query": "Embed this dialogue to find useful historical dialogues: ",
"key": "Embed this historical dialogue for retrieval: ",
},
"lrlm": {
"query": "Embed this text chunk for finding useful historical chunks: ",
"key": "Embed this historical text chunk for retrieval: ",
},
"icl": {
"query": "Convert this example into vector to look for useful examples: ",
"key": "Convert this example into vector for retrieval: ",
},
"tool": {
"query": "Transform this user request for fetching helpful tool descriptions: ",
"key": "Transform this tool description for retrieval: "
},
},
"training": {
"qa": {
"select_positive": "first",
"select_negative": "random",
"max_sample_num": None,
"teacher_scores_margin": None,
"teacher_scores_min": None,
"contrastive_weight": 0,
"stable_distill": True,
},
"convsearch": {
"select_positive": "first",
"select_negative": "random",
"max_sample_num": None,
"teacher_scores_margin": None,
"teacher_scores_min": None,
"distill_weight": 0,
"stable_distill": False,
},
"chat": {
"select_positive": "teacher",
"select_negative": "random",
"max_sample_num": None,
"teacher_scores_margin": None,
"teacher_scores_min": None,
"distill_weight": 1.0,
"contrastive_weight": 0,
"teacher_temperature": 0.1,
"stable_distill": False,
},
"lrlm": {
"select_positive": "teacher",
"select_negative": "random",
"max_sample_num": 10000,
"teacher_scores_margin": 0.1,
"teacher_scores_min": None,
"distill_weight": 1.0,
"contrastive_weight": 0,
"teacher_temperature": 0.1,
"stable_distill": False,
},
"icl": {
"select_positive": "random",
"select_negative": "random",
"max_sample_num": None,
"teacher_scores_margin": None,
"teacher_scores_min": None,
"contrastive_weight": 0,
"stable_distill": True,
},
"tool": {
"select_positive": "first",
"select_negative": "random",
"max_sample_num": None,
"teacher_scores_margin": None,
"teacher_scores_min": None,
"distill_weight": 0,
"stable_distill": False,
},
}
},
"bge": {
"instruction": defaultdict(lambda: {"query": "Represent this sentence for searching relevant passages: ", "key": ""})
},
"e5": {
"instruction": defaultdict(lambda: {"query": "query: ", "key": "passage: "})
},
"instructor": {
"instruction": {
"qa": {
"query": "Represent the query for retrieving supporting documents: ",
"key": "Represent the document for retrieval: ",
},
"convsearch": {
"query": "Represent the query and context for retrieving supporting passages: ",
"key": "Represent the passage for retrieval: ",
},
"chat": {
"query": "Represent the dialogue for retrieving useful historical dialogues: ",
"key": "Represent the historical dialogue for retrieval: ",
},
"lrlm": {
"query": "Represent the text chunk for retrieving useful historical chunks: ",
"key": "Represent the historical text chunk for retrieval: ",
},
"icl": {
"query": "Represent the example for retrieving duplicate examples: ",
"key": "Represent the example for retrieval: ",
},
"tool": {
"query": "Represent the user request for retrieving duplicate examples: ",
"key": "Represent the tool description for retrieval: "
},
},
}
}
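# Illustrative example (hypothetical texts, for reference only): how an instruction from
# TASK_CONFIG is prepended for the "llm-embedder" version and the "qa" task.
#     instruction = TASK_CONFIG["llm-embedder"]["instruction"]["qa"]
#     query = instruction["query"] + "who composed mack the knife"
#     key = instruction["key"] + "Mack the Knife is a song composed by Kurt Weill ..."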
import os
import datasets
import regex
import unicodedata
import numpy as np
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from tqdm import tqdm
class SimpleTokenizer:
ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
NON_WS = r'[^\p{Z}\p{C}]'
def __init__(self, **kwargs):
"""
Args:
annotators: None or empty set (only tokenizes).
"""
self._regexp = regex.compile(
'(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
)
def tokenize(self, text, uncase=False):
tokens = []
matches = [m for m in self._regexp.finditer(text)]
for i in range(len(matches)):
# Get text
token = matches[i].group()
# Format data
if uncase:
tokens.append(token.lower())
else:
tokens.append(token)
return tokens
def _normalize(text):
return unicodedata.normalize('NFD', text)
def has_answer(answers, text, tokenizer) -> bool:
"""Check if a document contains an answer string.
"""
text = _normalize(text)
# Answer is a list of possible strings
text = tokenizer.tokenize(text, uncase=True)
for answer in answers:
answer = _normalize(answer)
answer = tokenizer.tokenize(answer, uncase=True)
for i in range(0, len(text) - len(answer) + 1):
if answer == text[i: i + len(answer)]:
return True
return False
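# Illustrative check (hypothetical strings, for reference only):
#     has_answer(["Kurt Weill"], "The song was composed by Kurt Weill in 1928.", SimpleTokenizer())
#     # -> True, because the normalized, lowercased answer tokens appear contiguously in the text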
class EvalDataset(Dataset):
def __init__(self, retrieval_result, eval_dataset, corpus):
self.corpus = corpus
self.eval_dataset = eval_dataset
self.retrieval_result = retrieval_result
self.tokenizer = SimpleTokenizer()
def __getitem__(self, qidx):
res = self.retrieval_result[qidx]
hits = []
for i, tidx in enumerate(res):
if tidx == -1:
hits.append(False)
else:
hits.append(has_answer(self.eval_dataset[qidx]["answers"], self.corpus[tidx]["content"], self.tokenizer))
return hits
def __len__(self):
return len(self.retrieval_result)
def evaluate_nq(retrieval_result: dict, eval_data: datasets.Dataset, corpus: datasets.Dataset, num_workers=16, batch_size=16, cache_dir=None):
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if isinstance(eval_data, str):
eval_dataset = datasets.load_dataset("json", data_files=eval_data, split="train", cache_dir=cache_dir)
elif isinstance(eval_data, datasets.Dataset):
eval_dataset = eval_data
else:
raise ValueError(f"Expected eval_data of type str/Dataset, found {type(eval_data)}!")
if isinstance(corpus, str):
corpus = datasets.load_dataset("json", data_files=corpus, split="train", cache_dir=cache_dir)
elif isinstance(corpus, datasets.Dataset):
pass
else:
raise ValueError(f"Expected corpus of type str/Dataset, found {type(corpus)}!")
dataset = EvalDataset(retrieval_result, eval_dataset=eval_dataset, corpus=corpus)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=lambda x: x)
final_scores = []
for scores in tqdm(dataloader, total=len(dataloader), ncols=100, desc="Computing Metrics"):
final_scores.extend(scores)
relaxed_hits = np.zeros(max(*[len(x) for x in retrieval_result.values()], 100))
for question_hits in final_scores:
best_hit = next((i for i, x in enumerate(question_hits) if x), None)
if best_hit is not None:
relaxed_hits[best_hit:] += 1
relaxed_recall = relaxed_hits / len(retrieval_result)
return {
"recall@1": round(relaxed_recall[0], 4),
"recall@5": round(relaxed_recall[4], 4),
"recall@10": round(relaxed_recall[9], 4),
"recall@20": round(relaxed_recall[19], 4),
"recall@100": round(relaxed_recall[99], 4)
}
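# Illustrative note (assumption, not stated in the original code): evaluate_nq measures answer recall,
# i.e. a query counts toward recall@k as soon as any of its top-k retrieved passages contains one of
# the gold answer strings, regardless of the annotated positive indices.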
import os
import json
import logging
import inspect
import numpy as np
from tqdm import tqdm
from .evalnq import evaluate_nq
from ..utils.util import makedirs, split_file_dir_name_ext
logger = logging.getLogger(__name__)
class RetrievalMetric:
"""Class for computing metrics and some post-processings."""
@classmethod
def get_metric_fn(cls, metric_names, **kwds):
assert isinstance(metric_names, list) or isinstance(metric_names, tuple), "You must pass metric_names in a list or tuple!"
all_metrics = {}
# get all methods
all_implemented_fns = [x[0] for x in inspect.getmembers(cls, predicate=inspect.isfunction) if not x[0].startswith("_")]
def compute_metrics(*args, **kwargs):
for metric_name in metric_names:
# call corresponding method
if metric_name in all_implemented_fns:
metric_fn = getattr(cls, metric_name)
metric = metric_fn(**kwds)(*args, **kwargs)
# NOTE: some metric_fn are only used for post-processing and saving results, which return None by default
if metric is not None:
all_metrics.update(metric)
else:
raise NotImplementedError(f"Metric {metric_name} not implemented!")
return all_metrics
return compute_metrics
@staticmethod
def _get_save_path(eval_data, output_dir=None, field="result", save_name=None):
"""
if output_dir is None:
-> {eval_data_dir}/{eval_data_name}.{field}.{save_name}.{eval_data_ext}
else:
-> {output_dir}/{eval_data_name}.{field}.{save_name}.{eval_data_ext}
"""
eval_data_dir, eval_data_name, eval_data_ext = split_file_dir_name_ext(eval_data)
if output_dir is None:
output_dir = eval_data_dir
fields = [eval_data_name, field]
if save_name is not None:
fields.append(save_name)
save_path = os.path.join(output_dir, ".".join(fields) + eval_data_ext)
makedirs(save_path)
return save_path
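    # Illustrative example (hypothetical paths, for reference only):
    #     RetrievalMetric._get_save_path("data/qa/test.json", field="key", save_name="bge")
    #     # -> "data/qa/test.key.bge.json" (saved next to eval_data because output_dir is None)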
@staticmethod
def _save_result(query_ids, preds, result_path, scores=None):
if query_ids is None and preds is None:
logger.warning("No query_ids and preds provided for _save_result, skipping!")
return
with open(result_path, "w") as f:
for i, (query_id, pred) in enumerate(zip(query_ids, preds)):
res = {
"query_id": query_id,
"pred": pred,
}
if scores is not None:
res["score"] = scores[i]
f.write(json.dumps(res, ensure_ascii=False) + "\n")
@staticmethod
def _load_result(result_path):
logger.info(f"loading retrieval results from {result_path}...")
all_query_ids = []
all_preds = []
all_scores = None
with open(result_path) as f:
for line in f:
item = json.loads(line.strip())
all_query_ids.append(item["query_id"])
all_preds.append(item["pred"])
if "scores" in item:
if all_scores is None:
all_scores = []
all_scores.append(item["scores"])
if all_scores is not None:
return all_query_ids, all_preds, all_scores
else:
return all_query_ids, all_preds
@staticmethod
def _clean_pred(pred, score=None):
if isinstance(pred, np.ndarray):
valid_pos = pred > -1
pred = pred[valid_pos].tolist()
if score is not None:
score = score[valid_pos].tolist()
else:
valid_pos = [i for i, x in enumerate(pred) if x > -1]
pred = [pred[i] for i in valid_pos]
if score is not None:
score = [score[i] for i in valid_pos]
if score is not None:
return pred, score
else:
return pred
@staticmethod
def _prepare_label(eval_data):
labels = {}
with open(eval_data) as f:
for i, line in enumerate(f):
item = json.loads(line)
if "query_id" in item:
query_id = item["query_id"]
else:
query_id = i
# get the indices of the positives w.r.t. the corpus
label = item.get("pos_index", None)
labels[query_id] = label
return labels
@staticmethod
def mrr(eval_data=None, cutoffs=[10], **kwds):
metric_name = inspect.currentframe().f_code.co_name
if eval_data is not None:
data_labels = RetrievalMetric._prepare_label(eval_data)
def compute_metric(query_ids, preds, labels=None, **kwargs):
if labels is None:
labels = data_labels
if len(preds) != len(labels):
logger.warning(f"There are {len(preds)} queries in predictions while {len(labels)} queries in labels!")
mrrs = np.zeros(len(cutoffs))
for query_id, pred in zip(query_ids, preds):
label = labels[query_id]
pred = RetrievalMetric._clean_pred(pred)
jump = False
for i, x in enumerate(pred, 1):
if x == -1:
break
if x in label:
for k, cutoff in enumerate(cutoffs):
if i <= cutoff:
mrrs[k] += 1 / i
jump = True
if jump:
break
mrrs /= len(preds)
metric = {}
for i, cutoff in enumerate(cutoffs):
mrr = mrrs[i]
metric[f"{metric_name}@{cutoff}"] = mrr
return metric
return compute_metric
@staticmethod
def recall(eval_data=None, cutoffs=[10], **kwds):
metric_name = inspect.currentframe().f_code.co_name
if eval_data is not None:
data_labels = RetrievalMetric._prepare_label(eval_data)
def compute_metric(query_ids, preds, labels=None, **kwargs):
if labels is None:
labels = data_labels
if len(preds) != len(labels):
logger.warning(f"There are {len(preds)} queries in predictions while {len(labels)} queries in labels!")
recalls = np.zeros(len(cutoffs))
for query_id, pred in zip(query_ids, preds):
label = labels[query_id]
pred = RetrievalMetric._clean_pred(pred)
for k, cutoff in enumerate(cutoffs):
recall = np.intersect1d(label, pred[:cutoff])
recalls[k] += len(recall) / len(label)
recalls /= len(preds)
metric = {}
for i, cutoff in enumerate(cutoffs):
recall = recalls[i]
metric[f"{metric_name}@{cutoff}"] = recall
return metric
return compute_metric
@staticmethod
def ndcg(eval_data=None, cutoffs=[10], **kwds):
metric_name = inspect.currentframe().f_code.co_name
if eval_data is not None:
data_labels = RetrievalMetric._prepare_label(eval_data)
def compute_metric(query_ids, preds, labels=None, **kwargs):
if labels is None:
labels = data_labels
if len(preds) != len(labels):
logger.warning(f"There are {len(preds)} queries in predictions while {len(labels)} queries in labels!")
ndcgs = np.zeros(len(cutoffs))
for query_id, pred in zip(query_ids, preds):
label = labels[query_id]
pred = RetrievalMetric._clean_pred(pred)
ndcg = np.zeros(len(cutoffs))
idcg = np.zeros(len(cutoffs))
for i, x in enumerate(pred, 1):
if x in label:
for k, cutoff in enumerate(cutoffs):
if i <= cutoff:
ndcg[k] += 1 / np.log2(i + 1)
for j, y in enumerate(label, 1):
for k, cutoff in enumerate(cutoffs):
if j <= cutoff:
idcg[k] += 1 / np.log2(j + 1)
ndcgs += ndcg / idcg
ndcgs /= len(preds)
metric = {}
for i, cutoff in enumerate(cutoffs):
ndcg = ndcgs[i]
metric[f"{metric_name}@{cutoff}"] = ndcg
return metric
return compute_metric
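    # Illustrative worked example (hypothetical values, for reference only): with labels [3, 7]
    # and prediction [5, 3, 7] (relevant keys at ranks 2 and 3):
    #     mrr@5    = 1/2                                          (first relevant key at rank 2)
    #     recall@5 = 2/2 = 1.0                                    (both labels within the cutoff)
    #     ndcg@5   = (1/log2(3) + 1/log2(4)) / (1/log2(2) + 1/log2(3)) ≈ 0.693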
@staticmethod
def nq(eval_data, corpus, cache_dir=None, **kwds):
def compute_metric(query_ids, preds, **kwargs):
# collect retrieval result
retrieval_result = {}
for i, pred in enumerate(preds):
retrieval_result[i] = RetrievalMetric._clean_pred(pred)
metrics = evaluate_nq(retrieval_result, eval_data=eval_data, corpus=corpus, cache_dir=cache_dir)
return metrics
return compute_metric
@staticmethod
def collate_key(eval_data, save_name, corpus, output_dir=None, save_to_output=False, **kwds):
"""
Collate retrieval results for evaluation.
Append a 'keys' column in the eval_data where each key is a piece of retrieved text;
Delete 'pos' and 'neg' column.
If output_dir is None, save at {eval_data}.keys.{save_name}.json
Else, save at {output_dir}.keys.{save_name}.json
"""
def collate(query_ids, preds, **kwargs):
query_id_2_pred = {}
for query_id, pred in zip(query_ids, preds):
pred = RetrievalMetric._clean_pred(pred)
query_id_2_pred[query_id] = pred
del query_ids
del preds
if save_to_output and output_dir is not None:
save_path = RetrievalMetric._get_save_path(eval_data, output_dir, field="key", save_name=save_name)
else:
save_path = RetrievalMetric._get_save_path(eval_data, None, field="key", save_name=save_name)
logger.info(f"saving key to {save_path}...")
with open(eval_data) as f, open(save_path, "w") as g:
for line in tqdm(f, desc="Collating key"):
item = json.loads(line)
query_id = item["query_id"]
# NOTE: some queries may not correspond to any keys (especially in case of BM25), just skip them
if query_id not in query_id_2_pred:
item["key"] = []
item["key_index"] = []
else:
pred = query_id_2_pred[query_id]
item["key"] = corpus[pred]["content"]
item["key_index"] = pred
# delete pos, neg, and teacher scores because they do not comply with new keys
# if "pos" in item:
# del item["pos"]
# if "neg" in item:
# del item["neg"]
# if "pos_index" in item:
# del item["pos_index"]
# if "neg_index" in item:
# del item["neg_index"]
# if "teacher_scores" in item:
# del item["teacher_scores"]
g.write(json.dumps(item, ensure_ascii=False) + "\n")
return collate
@staticmethod
def collate_neg(eval_data, save_name, corpus, max_neg_num=100, filter_answers=False, output_dir=None, save_to_output=False, **kwds):
"""
Collate retrieval results for training.
Append 'pos' and 'neg' columns in the eval_data where each element is a piece of retrieved text;
Save at {output_dir}.neg.{save_name}.json
"""
def collate(query_ids, preds, **kwargs):
query_id_2_pred = {}
for query_id, pred in zip(query_ids, preds):
pred = RetrievalMetric._clean_pred(pred)
query_id_2_pred[query_id] = pred
del query_ids
del preds
if save_to_output and output_dir is not None:
save_path = RetrievalMetric._get_save_path(eval_data, output_dir, field="neg", save_name=save_name)
else:
save_path = RetrievalMetric._get_save_path(eval_data, None, field="neg", save_name=save_name)
logger.info(f"saving {max_neg_num} negatives to {save_path}...")
with open(eval_data) as f, open(save_path, "w") as g:
for line in tqdm(f, desc="Collating Negatives"):
item = json.loads(line)
query_id = item["query_id"]
# NOTE: some queries may not correspond to any negatives (especially in case of BM25), just skip them
if query_id not in query_id_2_pred:
continue
pred = query_id_2_pred[query_id]
if "pos" in item:
pos = set(item["pos"])
else:
                        # sometimes there is no pre-defined pos; instead, the pos will be selected from neg based on teacher scores
pos = []
# first filter out positive documents
if "pos_index" in item:
pos_index = item["pos_index"]
pred = [i for i in pred if i != pos_index]
neg = corpus[pred]["content"]
# remove key that is the same as pos
                    # NOTE: here we do not use pos_index to distinguish pos and neg, because different pos_index may correspond to the same content due to duplication in the corpus
if filter_answers:
answers = item.get("answers", [])
valid_index = [i for i, x in enumerate(neg) if (x not in pos) and (not any(a.lower() in x.lower() for a in answers))]
else:
valid_index = [i for i, x in enumerate(neg) if x not in pos]
valid_index = valid_index[:max_neg_num]
neg = [neg[i] for i in valid_index]
neg_index = [pred[i] for i in valid_index]
item["neg"] = neg
item["neg_index"] = neg_index
# remove teacher scores because they are for previous pos and neg
if "teacher_scores" in item:
del item["teacher_scores"]
g.write(json.dumps(item, ensure_ascii=False) + "\n")
return collate
@staticmethod
def collate_score(eval_data, save_name, output_dir=None, save_to_output=False, **kwds):
"""
Collate scores generated by the reranking model.
Append 'teacher_scores' column in the eval_data where each element is the score of 'pos' unioned 'neg';
If output_dir is None, save at {eval_data}.score.{save_name}.json
Else, save at {output_dir}.score.{save_name}.json
"""
def collate(query_ids, preds, scores, **kwargs):
query_id_2_pred = {}
for query_id, pred, score in zip(query_ids, preds, scores):
pred, score = RetrievalMetric._clean_pred(pred, score)
query_id_2_pred[query_id] = (pred, score)
del query_ids
del preds
del scores
if save_to_output and output_dir is not None:
save_path = RetrievalMetric._get_save_path(eval_data, output_dir, field="scored", save_name=save_name)
else:
save_path = RetrievalMetric._get_save_path(eval_data, None, field="scored", save_name=save_name)
logger.info(f"saving scores to {save_path}...")
with open(eval_data) as f, open(save_path, "w") as g:
for line in tqdm(f, desc="Collating Scores"):
item = json.loads(line)
query_id = item["query_id"]
pred, score = query_id_2_pred[query_id]
# NOTE: there must be key_index
if "pos_index" in item:
key_index = item["pos_index"] + item["neg_index"]
elif "key_index" in item:
key_index = item["key_index"]
else:
key_index = list(range(len(pred)))
key_index_2_score = {k: s for k, s in zip(pred, score)}
teacher_scores = [key_index_2_score[ki] for ki in key_index]
item["teacher_scores"] = teacher_scores
g.write(json.dumps(item, ensure_ascii=False) + "\n")
return collate
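# Illustrative usage (hypothetical path, for reference only):
#     compute_metrics = RetrievalMetric.get_metric_fn(["mrr", "recall"], eval_data="data/qa/test.json", cutoffs=[1, 10])
#     metrics = compute_metrics(query_ids, preds)
#     # -> {"mrr@1": ..., "mrr@10": ..., "recall@1": ..., "recall@10": ...}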
import os
import json
import subprocess
import datasets
import numpy as np
from typing import List, Optional, Union
from tqdm import tqdm
from collections import defaultdict
from src.utils.util import clear_dir, split_file_dir_name_ext
class BM25Retriever:
def __init__(self, anserini_dir, k1=0.9, b=0.4, **kwds) -> None:
self.anserini_dir = anserini_dir
self.k1 = k1
self.b = b
def _prepare_collection(self, corpus:datasets.Dataset, collection_dir, max_docs_per_file=1000000):
clear_dir(collection_dir)
file_index = 0
for i, doc in enumerate(tqdm(corpus, desc="Preparing Anserini Collection")):
text = doc["content"]
if i % max_docs_per_file == 0:
if i > 0:
output_jsonl_file.close()
output_path = os.path.join(collection_dir, 'docs{:02d}.json'.format(file_index))
output_jsonl_file = open(output_path, 'w', encoding='utf-8', newline='\n')
file_index += 1
output_dict = {'id': i, 'contents': text}
output_jsonl_file.write(json.dumps(output_dict) + '\n')
output_jsonl_file.close()
def _prepare_query(self, eval_data:Union[str, datasets.Dataset], query_dir:str, max_queries_per_file=10000):
clear_dir(query_dir)
query_ids = []
queries = []
if isinstance(eval_data, str):
with open(eval_data) as f:
for line in tqdm(f, desc="Preparing Anserini Queries"):
# NOTE: repr query because it may contain newline character
item = json.loads(line)
query = repr(item["query"])[1:-1]
# filter out empty query
if len(query.strip()):
query_ids.append(item["query_id"])
queries.append(query)
elif isinstance(eval_data, datasets.Dataset):
for item in tqdm(eval_data, desc="Preparing Anserini Queries"):
# NOTE: repr query because it may contain newline character
query = repr(item["query"])[1:-1]
# filter out empty query
if len(query.strip()):
query_ids.append(item["query_id"])
queries.append(query)
else:
raise ValueError(f"Expected eval_data to be instance of str or datasets.Dataset, got {type(eval_data)}!")
        # split a large query file into smaller shards because Anserini cannot deal with very large query files efficiently
        if len(queries) > max_queries_per_file:
for idx, (qid, query) in enumerate(zip(query_ids, queries)):
if idx % max_queries_per_file == 0:
if idx > 0:
g.close()
g = open(os.path.join(query_dir, f"queries.{str(idx // max_queries_per_file)}.tsv"), "w")
g.write("\t".join([str(qid), query]) + "\n")
g.close()
else:
query_path = os.path.join(query_dir, "queries.tsv")
with open(query_path, "w") as f:
for qid, qcontent in zip(query_ids, queries):
f.write("\t".join([str(qid), qcontent]) + "\n")
query_paths = []
for query_path in os.listdir(query_dir):
query_paths.append(os.path.join(query_dir, query_path))
return query_paths
def _prepare_result(self, result_path):
retrieval_result = defaultdict(list)
with open(result_path) as f:
for line in tqdm(f, desc="Collecting Retrieval Results"):
fields = line.strip().split("\t")
qid = int(fields[0])
tidx = int(fields[1])
retrieval_result[qid].append(tidx)
return retrieval_result
def index(self, corpus:Optional[datasets.Dataset]=None, output_dir:str="./bm25", threads:int=32, language:str="en", storeDocvectors:bool=False, load_collection:bool=False, load_index:bool=False, **kwds):
index_dir = os.path.join(output_dir, "index")
collection_dir = os.path.join(output_dir, "collection")
self.output_dir = output_dir
self.language = language
if not load_collection and not load_index:
self._prepare_collection(corpus, collection_dir)
if not load_index:
clear_dir(index_dir)
args = [
f"sh {self.anserini_dir}/target/appassembler/bin/IndexCollection -collection JsonCollection -generator DefaultLuceneDocumentGenerator",
f"-input {collection_dir} -index {index_dir} -threads {threads} -language {language}",
"-storeDocvectors" if storeDocvectors else ""
]
subprocess.run(" ".join(args), shell=True)
def search(self, eval_data:Union[str, datasets.Dataset], output_dir:Optional[str]=None, k1:Optional[float]=None, b:Optional[float]=None, hits:int=100, threads:int=32, parallelism:int=4, language:Optional[str]=None, max_queries_per_file:int=10000, **kwds):
if k1 is None:
k1 = self.k1
if b is None:
b = self.b
if output_dir is None and not hasattr(self, "output_dir"):
raise ValueError(f"Make sure there is an index by either calling .index() or specifying an existing index with index_dir=xxx!")
elif output_dir is None:
output_dir = self.output_dir
if language is None:
language = self.language
index_dir = os.path.join(output_dir, "index")
query_dir = os.path.join(output_dir, "query")
retrieval_result = {}
query_paths = self._prepare_query(eval_data, query_dir, max_queries_per_file)
for path in tqdm(query_paths, desc="Searching"):
tmp_result_path = path+".tmp"
args = [
f"sh {self.anserini_dir}/target/appassembler/bin/SearchCollection -topicreader TsvString -format msmarco",
f"-index {index_dir} -topics {path} -output {tmp_result_path} -bm25 -bm25.k1 {k1} -bm25.b {b}",
f"-hits {hits} -threads {threads} -parallelism {parallelism} -language {language}"
]
subprocess.run(" ".join(args), shell=True)
res = self._prepare_result(tmp_result_path)
retrieval_result.update(res)
os.remove(tmp_result_path)
return list(retrieval_result.keys()), list(retrieval_result.values())
class NaiveBM25Retriever:
def __init__(self, k1:float=0.9, b:float=0.4, **kwds) -> None:
self.k1 = k1
self.b = b
def index(self, corpus: List[Union[str, List[int]]], verbose: bool=False, stop_tokens: Optional[set]=None):
"""Build in-memory BM25 index."""
if stop_tokens is None:
            stop_tokens = set()
dfs = defaultdict(int)
tfs = []
inverted_lists = defaultdict(list)
doc_lengths = np.zeros(len(corpus), dtype=np.float32)
if verbose:
iterator = tqdm(corpus, desc="Indexing")
else:
iterator = corpus
for i, doc in enumerate(iterator):
if isinstance(doc, str):
doc = doc.split(" ")
# TODO: stem
df = {}
tf = defaultdict(int)
for token in doc:
if token not in stop_tokens:
tf[token] += 1
df[token] = 1
tfs.append(dict(tf))
for token in df:
dfs[token] += 1
# store the doc offset in the inverted lists of the corresponding token
inverted_lists[token].append(i)
doc_lengths[i] = len(doc)
self.dfs = dict(dfs)
self.tfs = tfs
self.doc_length = doc_lengths
self.inverted_lists = {k: np.array(v) for k, v in inverted_lists.items()}
self.N = len(corpus)
def search(self, queries: Union[str, List[int], List[str], List[List[int]]], hits: int=100, k1: Optional[float]=None, b: Optional[float]=None, verbose: bool=False):
"""Search over the BM25 index."""
if k1 is None:
k1 = self.k1
if b is None:
b = self.b
hits = min(self.N, hits)
global_scores = np.zeros(self.N, dtype=np.float32)
if isinstance(queries, str):
queries = [queries]
elif isinstance(queries, list) and isinstance(queries[0], int):
queries = [queries]
all_scores = np.zeros((len(queries), hits), dtype=np.float32)
all_indices = np.zeros((len(queries), hits), dtype=np.int64)
if verbose:
iterator = tqdm(queries, desc="Searching")
else:
iterator = queries
        for i, query in enumerate(iterator):
            # NOTE: reset the accumulated scores for every query
            global_scores = np.zeros(self.N, dtype=np.float32)
            if isinstance(query, str):
                query = query.split(" ")
# TODO: stem
for token in query:
if token in self.inverted_lists:
candidates = self.inverted_lists[token]
else:
continue
tfs = np.array([self.tfs[candidate][token] for candidate in candidates], dtype=np.float32)
df = self.dfs[token]
idf = np.log((self.N - df + 0.5) / (df + 0.5) + 1)
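                # BM25 term contribution as implemented below:
                #     idf(t) * (k1 + 1) * tf / (tf + k1 * (1 - b + b * doc_length))
                # NOTE (observation, not stated in the original comments): doc_length is used raw here,
                # i.e. it is not normalized by the average document length as in the textbook BM25 formula.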
candidate_scores = idf * (k1 + 1) * tfs / (tfs + k1 * (1 - b + b * self.doc_length[candidates]))
global_scores[candidates] += candidate_scores
indice = np.argpartition(-global_scores, hits - 1)[:hits]
score = global_scores[indice]
sorted_idx = np.argsort(score)[::-1]
indice = indice[sorted_idx]
score = score[sorted_idx]
invalid_pos = score == 0
indice[invalid_pos] = -1
score[invalid_pos] = -float('inf')
all_scores[i] = score
all_indices[i] = indice
return all_scores, all_indices
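# Illustrative usage (hypothetical documents, for reference only):
#     retriever = NaiveBM25Retriever(k1=0.82, b=0.68)
#     retriever.index(["the cat sat on the mat", "dogs chase cats", "bm25 is a ranking function"])
#     scores, indices = retriever.search("cat on the mat", hits=2)   # indices[0] ranks document 0 first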
import os
import torch
import faiss
import numpy as np
import torch.nn.functional as F
import torch.distributed as dist
from accelerate import Accelerator
from torch.utils.data import DataLoader
from datasets import Dataset
from transformers import AutoModel, AutoTokenizer
from transformers.utils import logging
from typing import List, Mapping, Optional, Tuple, Union
from tqdm import tqdm
from .data import RetrievalDataCollator
from ..utils.util import Sequential_Sampler, makedirs, do_nothing
logger = logging.get_logger(__name__)
class DenseRetriever(torch.nn.Module):
def __init__(self, query_encoder:str='BAAI/bge-base-en', key_encoder:str='BAAI/bge-base-en', pooling_method:List[str]=["cls"], dense_metric:str="cos", query_max_length:int=512, key_max_length:int=512, tie_encoders:bool=True, truncation_side:str="right", dtype:str="fp16", cache_dir:Optional[str]=None, cos_temperature:float=0.01, contrastive_weight:float=0.2, distill_weight:float=1.0, teacher_temperature:float=1.0, student_temperature:float=1.0, negative_cross_device:bool=True, stable_distill:bool=False, accelerator:Accelerator=None, **kwds) -> None:
super().__init__()
self.accelerator = accelerator
self.tie_encoders = tie_encoders
self.pooling_method = pooling_method
self.dense_metric = dense_metric
self.query_max_length = query_max_length
self.key_max_length = key_max_length
self.cos_temperature = cos_temperature
self.contrastive_weight = contrastive_weight
self.distill_weight = distill_weight
self.teacher_temperature = teacher_temperature
self.student_temperature = student_temperature
self.negative_cross_device = negative_cross_device and dist.is_initialized()
self.stable_distill = stable_distill
logger.info(f"Loading tokenizer and model from {query_encoder}...")
self.tokenizer = AutoTokenizer.from_pretrained(query_encoder, cache_dir=cache_dir, truncation_side=truncation_side)
if dtype == "bf16":
dtype = torch.bfloat16
elif dtype == "fp16":
dtype = torch.float16
else:
dtype = torch.float32
self.query_encoder_name = query_encoder
self.key_encoder_name = key_encoder
if tie_encoders:
encoder = AutoModel.from_pretrained(query_encoder, cache_dir=cache_dir, torch_dtype=dtype).to(self.device)
self.query_encoder = encoder
self.key_encoder = encoder
else:
self.query_encoder = AutoModel.from_pretrained(query_encoder, cache_dir=cache_dir, torch_dtype=dtype).to(self.device)
self.key_encoder = AutoModel.from_pretrained(key_encoder, cache_dir=cache_dir, torch_dtype=dtype).to(self.device)
self.ndim = self.query_encoder.config.hidden_size
self._index = None
self._post_init()
self.eval()
def _post_init(self):
"""
1. remove pooler to avoid DDP errors;
2. remove decoder when necessary
"""
if hasattr(self.query_encoder, "pooler"):
self.query_encoder.pooler = None
if hasattr(self.key_encoder, "pooler"):
self.key_encoder.pooler = None
if "dense" in self.pooling_method:
self.dense_pooler = torch.nn.Linear(self.ndim, self.ndim, bias=False).to(device=self.device, dtype=self.query_encoder.dtype)
try:
state_dict = torch.load(os.path.join(self.query_encoder_name, "dense_pooler.bin"), map_location=self.device)
self.dense_pooler.load_state_dict(state_dict)
except Exception:
logger.warning(f"Could not load dense pooler weights from {self.query_encoder_name}, initializing them randomly!")
def gradient_checkpointing_enable(self):
self.query_encoder.gradient_checkpointing_enable()
self.key_encoder.gradient_checkpointing_enable()
@property
def device(self):
if self.accelerator is not None:
return self.accelerator.device
else:
return torch.device("cpu")
def _gather_tensors(self, local_tensor):
"""
Gather tensors from all gpus on each process.
Args:
local_tensor: the tensor that needs to be gathered
Returns:
concatenation of local_tensor in each process
"""
if local_tensor is None:
return None
all_tensors = [torch.empty_like(local_tensor)
for _ in range(self.accelerator.num_processes)]
dist.all_gather(all_tensors, local_tensor.contiguous())
all_tensors[self.accelerator.process_index] = local_tensor
return torch.cat(all_tensors, dim=0)
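# NOTE (added comment): torch.distributed.all_gather returns tensors that are
# detached from the autograd graph, so the local slot is overwritten with the
# original local_tensor above to keep gradients flowing for this process's
# shard while still using the other processes' embeddings as extra negatives.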
def _save_to_memmap(self, path: str, shape: tuple, array: np.ndarray, start: int, batch_size: int = 100000):
"""
Save to numpy array to memmap file.
"""
if self.accelerator.process_index == 0:
if os.path.exists(path):
os.remove(path)
else:
makedirs(path)
memmap = np.memmap(
path,
shape=shape,
mode="w+",
dtype=array.dtype
)
del memmap
self.accelerator.wait_for_everyone()
logger.info(f"saving array at {path}...")
memmap = np.memmap(
path,
shape=shape,
mode="r+",
dtype=array.dtype
)
array_length = array.shape[0]
# write the array into the memmap in batches
end = start + array_length
if array_length > batch_size:
for i in tqdm(range(0, array_length, batch_size), leave=False, ncols=100):
start_idx = start + i
end_idx = min(start_idx + batch_size, end)
memmap[start_idx: end_idx] = array[i: i + (end_idx - start_idx)]
else:
memmap[start: end] = array
self.accelerator.wait_for_everyone()
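# NOTE (added comment): process 0 pre-allocates the full memmap file, then every
# process reopens it in "r+" mode and writes only its own slice
# [start, start + len(array)), so the shards produced by Sequential_Sampler end
# up contiguous in a single file.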
def _prepare(self, inputs: Union[str, List[str], Mapping], field="key"):
"""Convert inputs into tokenized input_ids"""
if isinstance(inputs, str) or (isinstance(inputs, list) and isinstance(inputs[0], str)):
if field == "key":
inputs = self.tokenizer(
inputs, return_tensors="pt", padding=True, truncation=True, max_length=self.key_max_length)
inputs = inputs.to(self.device)
elif field == "query":
inputs = self.tokenizer(
inputs, return_tensors="pt", padding=True, truncation=True, max_length=self.query_max_length)
inputs = inputs.to(self.device)
else:
raise NotImplementedError
elif isinstance(inputs, Mapping) and "input_ids" in inputs:
if field == "key":
for k, v in inputs.items():
inputs[k] = v[:, :self.key_max_length].to(self.device)
elif field == "query":
for k, v in inputs.items():
inputs[k] = v[:, :self.query_max_length].to(self.device)
else:
raise NotImplementedError
else:
raise ValueError(f"Expected inputs of type str, list[str], or dict, got {type(inputs)}!")
return inputs
def _pool(self, embeddings, attention_mask):
if "mean" in self.pooling_method:
embeddings = embeddings.masked_fill(
~attention_mask[..., None].bool(), 0.0)
embedding = embeddings.sum(
dim=1) / attention_mask.sum(dim=1, keepdim=True)
elif "cls" in self.pooling_method:
embedding = embeddings[:, 0]
elif "decoder" in self.pooling_method:
embedding = embeddings[:, 0]
else:
raise NotImplementedError(
f"Pooling_method {self.pooling_method} not implemented!")
if "dense" in self.pooling_method:
embedding = self.dense_pooler(embedding)
return embedding
def encode(self, inputs: Union[str, List[str], Mapping], field:str="key", with_grad:bool=False):
"""Encode inputs into embeddings
Args:
inputs: can be string, list of strings, or BatchEncoding results from tokenizer
Returns:
Tensor: [batch_size, d_embed]
"""
if with_grad:
ctx_manager = do_nothing
else:
ctx_manager = torch.no_grad
with ctx_manager():
inputs = self._prepare(inputs, field=field)
if field == "key":
encoder = self.key_encoder
elif field == "query":
encoder = self.query_encoder
else:
raise ValueError(f"Field {field} not implemented!")
if hasattr(encoder, "decoder"):
# AAR uses T5 decoder to produce embedding
if "decoder" in self.pooling_method:
input_ids = inputs['input_ids']
bos_token_id = encoder.config.decoder_start_token_id
decoder_input_ids = input_ids.new_zeros(input_ids.shape[0], 1) + bos_token_id
embeddings = encoder(**inputs, decoder_input_ids=decoder_input_ids).last_hidden_state # B, 1, D
else:
# only use the encoder part
encoder = encoder.encoder
embeddings = encoder(**inputs).last_hidden_state # B, L, D
else:
embeddings = encoder(**inputs).last_hidden_state # B, L, D
embedding = self._pool(embeddings, inputs["attention_mask"])
if self.dense_metric == "cos":
embedding = F.normalize(embedding, p=2, dim=1)
return embedding
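# NOTE (added comment): when dense_metric == "cos" the embeddings returned by
# encode() are L2-normalized, so a plain dot product between a query embedding
# and a key embedding already equals their cosine similarity, e.g.
#     sim = (self.encode(q, field="query") * self.encode(k)).sum(-1)  # values in [-1, 1]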
def _compute_loss(self, query_embedding, key_embedding, teacher_scores):
if teacher_scores is not None and self.distill_weight > 0:
do_distill = True
if self.stable_distill:
teacher_targets = F.softmax(teacher_scores, dim=-1) # B N
if self.negative_cross_device:
# gather with grad
query_embeddings = self._gather_tensors(query_embedding)
key_embeddings = self._gather_tensors(key_embedding)
teacher_targets = self._gather_tensors(teacher_targets)
else:
query_embeddings = query_embedding
key_embeddings = key_embedding
scores = query_embeddings.matmul(key_embeddings.transpose(-1, -2)) # B, B * N
if self.dense_metric == "cos":
scores = scores / self.cos_temperature
labels = torch.arange(query_embeddings.shape[0], device=self.device)
labels = labels * (key_embeddings.shape[0] // query_embeddings.shape[0])
# labels = torch.zeros(query_embeddings.shape[0], device=self.device, dtype=torch.long)
# scores =
distill_loss = 0
group_size = key_embeddings.shape[0] // query_embeddings.shape[0]
mask = torch.zeros_like(scores)
for i in range(group_size):
temp_target = labels + i
temp_scores = scores + mask
loss = F.cross_entropy(temp_scores, temp_target, reduction="none") # B
distill_loss = distill_loss + torch.mean(teacher_targets[:, i] * loss)
mask = torch.scatter(mask, dim=-1, index=temp_target.unsqueeze(-1), value=torch.finfo(scores.dtype).min)
else:
student_query = query_embedding.unsqueeze(1) # B, 1, D
student_key = key_embedding.view(student_query.shape[0], -1, student_query.shape[-1]) # B, N, D
student_scores = student_query.matmul(student_key.transpose(-1, -2)).squeeze(1) # B, N
if self.dense_metric == "cos":
student_scores = student_scores / self.cos_temperature
student_scores = F.log_softmax(student_scores / self.student_temperature, dim=-1)
teacher_scores = F.softmax(teacher_scores / self.teacher_temperature, dim=-1)
distill_loss = F.kl_div(student_scores, teacher_scores, reduction="batchmean")
else:
do_distill = False
if self.contrastive_weight > 0:
if self.negative_cross_device:
# gather with grad
query_embedding = self._gather_tensors(query_embedding)
key_embedding = self._gather_tensors(key_embedding)
scores = query_embedding.matmul(key_embedding.transpose(-1, -2)) # B, B * N
if self.dense_metric == "cos":
scores = scores / self.cos_temperature
# in batch negative
labels = torch.arange(query_embedding.shape[0], device=self.device)
labels = labels * (key_embedding.shape[0] // query_embedding.shape[0])
contrastive_loss = F.cross_entropy(scores, labels)
do_contrastive = True
else:
do_contrastive = False
if do_distill and do_contrastive:
loss = contrastive_loss * self.contrastive_weight + distill_loss * self.distill_weight
# if self.accelerator.process_index == 0:
# print(f"distill: {distill_loss * self.distill_weight} contra: {contrastive_loss * self.contrastive_weight} sumup: {loss} contra_weight: {self.contrastive_weight} distill_weight: {self.distill_weight}\n")
elif do_distill:
loss = distill_loss
elif do_contrastive:
loss = contrastive_loss
else:
raise ValueError(f"Neither distill or contrastive learning is enabled!")
return loss
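# NOTE (added comment): illustration of the in-batch-negative labels above,
# with hypothetical sizes. If batch_size B = 2 and each query comes with
# group_size N = 3 keys laid out as [q0_pos, q0_neg1, q0_neg2, q1_pos, ...],
# then scores has shape (B, B * N) and
#     labels = torch.arange(2) * 3  ->  tensor([0, 3])
# points every query at the column of its own positive key; all other columns,
# including the other queries' keys, act as negatives.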
def _refresh_config(self, task):
if hasattr(self, "train_config"):
# at the first iteration, set default value
if not hasattr(self, "_contrastive_weight"):
self._contrastive_weight = self.contrastive_weight
self._distill_weight = self.distill_weight
self._teacher_temperature = self.teacher_temperature
self._student_temperature = self.student_temperature
self._stable_distill = self.stable_distill
train_config = self.train_config[task]
# when there is no setting in the train config, fall back to the default config
self.contrastive_weight = train_config.get("contrastive_weight", self._contrastive_weight)
self.distill_weight = train_config.get("distill_weight", self._distill_weight)
self.teacher_temperature = train_config.get("teacher_temperature", self._teacher_temperature)
self.student_temperature = train_config.get("student_temperature", self._student_temperature)
self.stable_distill = train_config.get("stable_distill", self._stable_distill)
def forward(self, query, key, task, teacher_scores=None, **kwds):
self._refresh_config(task)
# batch_size * (1 + nneg), ndim
key_embedding = self.encode(key, with_grad=True)
query_embedding = self.encode(query, field="query", with_grad=True) # batch_size, ndim
# for debug
# print(f"************************\n{self.accelerator.process_index}: {query['input_ids'].shape}\n {self.tokenizer.decode(query['input_ids'][0])}\n{self.contrastive_weight}\n{self.distill_weight}\n{teacher_scores[0]}")
loss = self._compute_loss(query_embedding, key_embedding, teacher_scores)
# adapted to huggingface trainer
return {"loss": loss}
@torch.no_grad()
def index(self, corpus: Dataset, output_dir="data/outputs", embedding_name=None, index_factory:str="Flat", save_index=False, load_encode=False, save_encode=False, load_index=False, batch_size=500, metric=None, **kwds):
os.makedirs(output_dir, exist_ok=True)
if embedding_name is None:
embedding_name = "embeddings"
if metric is None:
metric = self.dense_metric
encode_path = os.path.join(output_dir, f"{embedding_name}.memmap")
index_path = os.path.join(output_dir, f"{embedding_name}.{index_factory}.{self.accelerator.process_index}-{self.accelerator.num_processes}.faiss")
sampler = Sequential_Sampler(len(corpus), self.accelerator.num_processes, self.accelerator.process_index)
self._corpus_offset = sampler.start
if load_encode:
encoded_corpus = np.memmap(
encode_path,
mode="r",
dtype=np.float32
).reshape(len(corpus), self.ndim)[sampler.start: sampler.end]
else:
# use multiple workers to speed up encoding
dataloader = DataLoader(
corpus,
batch_size=batch_size,
collate_fn=RetrievalDataCollator(
query_max_length=self.query_max_length,
key_max_length=self.key_max_length,
tokenizer=self.tokenizer,
),
sampler=sampler,
pin_memory=True,
num_workers=8,
)
offset = 0
encoded_corpus = np.zeros((len(sampler), self.ndim), dtype=np.float32)
for step, inputs in enumerate(tqdm(dataloader, desc="Indexing")):
embeddings = self.encode(inputs["content"]) # batch_size, ndim
# NOTE: we cannot use non_blocking here, otherwise nothing can be saved
encoded_corpus[offset: offset + embeddings.shape[0]] = embeddings.cpu().numpy()
offset += embeddings.shape[0]
# if step > 10:
# break
if save_encode:
self._save_to_memmap(
encode_path,
shape=(len(corpus), self.ndim),
array=encoded_corpus,
start=sampler.start
)
index = FaissIndex(self.device)
if load_index:
index.load(index_path)
else:
index.build(encoded_corpus, index_factory, metric)
if save_index:
index.save(index_path)
self._index = index
self.accelerator.wait_for_everyone()
return encoded_corpus
@torch.no_grad()
def search(self, inputs: Union[str, List[str], Mapping], hits:int=10, **kwds):
assert self._index is not None, "Make sure there is an indexed corpus!"
all_scores = []
all_indices = []
embeddings = self.encode(inputs, field="query").cpu().numpy().astype(np.float32, order="C")
batch_scores, batch_indices = self._index.search(embeddings, hits)
# offset
batch_indices += self._corpus_offset
# gather and merge results from all processes
# move to cpu for faster sorting and merging
if self.accelerator.num_processes > 1:
batch_scores = torch.as_tensor(batch_scores, device=self.device)
batch_indices = torch.as_tensor(batch_indices, device=self.device)
gathered_batch_scores = self.accelerator.gather(batch_scores).unflatten(0, (self.accelerator.num_processes, -1)).tolist()
gathered_batch_indices = self.accelerator.gather(batch_indices).unflatten(0, (self.accelerator.num_processes, -1)).tolist()
else:
gathered_batch_scores = batch_scores[None, ...].tolist()
gathered_batch_indices = batch_indices[None, ...].tolist()
for batch_idx in range(batch_scores.shape[0]):
score = sum([gathered_batch_scores[i][batch_idx] for i in range(self.accelerator.num_processes)], [])
indice = sum([gathered_batch_indices[i][batch_idx] for i in range(self.accelerator.num_processes)], [])
# take care of -1s, which may be returned by faiss
pair = sorted(zip(score, indice), key=lambda x: x[0] if x[1] >= 0 else -float('inf'), reverse=True)[:hits]
all_scores.append([x[0] for x in pair])
all_indices.append([x[1] for x in pair])
all_scores = np.array(all_scores, dtype=np.float32)
all_indices = np.array(all_indices)
return all_scores, all_indices
@torch.no_grad()
def rerank(self, query, key, key_mask=None, **kwds):
query_embeddings = self.encode(query, field="query")
key_embeddings = self.encode(key)
key_embeddings = key_embeddings.unflatten(0, (query_embeddings.shape[0], -1)) # batch_size, key_num, embedding_dim
score = torch.einsum("bnd,bd->bn", key_embeddings, query_embeddings) # batch_size, key_num
# mask padded candidates
if key_mask is not None:
score = score.masked_fill(~key_mask.bool(), torch.finfo(key_embeddings.dtype).min)
score, indice = score.sort(dim=-1, descending=True)
# NOTE: set the indice to -1 so that this prediction is ignored when computing metrics
indice[score == torch.finfo(score.dtype).min] = -1
return score, indice
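# NOTE (added comment): rerank() assumes `key` holds the candidates of all
# queries flattened in query-major order, i.e. batch_size * key_num texts,
# which is why the key embeddings are unflattened to
# (batch_size, key_num, embedding_dim) before the einsum scoring.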
def save_pretrained(self, output_dir: str, *args, **kwargs):
if self.tie_encoders:
self.tokenizer.save_pretrained(
os.path.join(output_dir, "encoder"))
self.query_encoder.save_pretrained(
os.path.join(output_dir, "encoder"))
if hasattr(self, "dense_pooler"):
torch.save(self.dense_pooler.state_dict(), os.path.join(output_dir, "encoder", "dense_pooler.bin"))
else:
self.tokenizer.save_pretrained(
os.path.join(output_dir, "query_encoder"))
self.query_encoder.save_pretrained(
os.path.join(output_dir, "query_encoder"))
# NOTE: only a single tokenizer is loaded in __init__, so save it for the key encoder as well
self.tokenizer.save_pretrained(
os.path.join(output_dir, "key_encoder"))
self.key_encoder.save_pretrained(
os.path.join(output_dir, "key_encoder"))
if hasattr(self, "dense_pooler"):
torch.save(self.dense_pooler.state_dict(), os.path.join(output_dir, "query_encoder", "dense_pooler.bin"))
class FaissIndex:
def __init__(self, device) -> None:
if isinstance(device, torch.device):
if device.index is None:
device = "cpu"
else:
device = device.index
self.device = device
def build(self, encoded_corpus, index_factory, metric):
if metric == "l2":
metric = faiss.METRIC_L2
elif metric in ["ip", "cos"]:
metric = faiss.METRIC_INNER_PRODUCT
else:
raise NotImplementedError(f"Metric {metric} not implemented!")
index = faiss.index_factory(encoded_corpus.shape[1], index_factory, metric)
if self.device != "cpu":
co = faiss.GpuClonerOptions()
co.useFloat16 = True
# logger.info("using fp16 on GPU...")
index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), self.device, index, co)
logger.info("training index...")
index.train(encoded_corpus)
logger.info("adding embeddings...")
index.add(encoded_corpus)
self.index = index
return index
def load(self, index_path):
logger.info(f"loading index from {index_path}...")
index = faiss.read_index(index_path)
if self.device != "cpu":
co = faiss.GpuClonerOptions()
co.useFloat16 = True
index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), self.device, index, co)
self.index = index
return index
def save(self, index_path):
logger.info(f"saving index at {index_path}...")
if isinstance(self.index, faiss.GpuIndex):
index = faiss.index_gpu_to_cpu(self.index)
else:
index = self.index
faiss.write_index(index, index_path)
def search(self, query, hits):
return self.index.search(query, k=hits)
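# --- Hedged usage sketch (added for illustration, not part of the original
# module). The checkpoint name, the corpus texts, and the "content" column are
# assumptions; the corpus layout mirrors the indexing loop above, which reads
# inputs["content"]. ---
if __name__ == "__main__":
    from accelerate import Accelerator
    from datasets import Dataset

    retriever = DenseRetriever(
        query_encoder="BAAI/bge-base-en",  # default checkpoint from __init__
        dtype="fp32",                      # fp32 so the sketch also runs on CPU
        accelerator=Accelerator(),
    )
    corpus = Dataset.from_dict({"content": [
        "Paris is the capital of France.",
        "The mitochondrion is the powerhouse of the cell.",
    ]})
    retriever.index(corpus, output_dir="data/outputs", batch_size=2)
    scores, indices = retriever.search(["capital of France"], hits=2)
    print(scores, indices)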
import os
import torch
import torch.nn as nn
from accelerate import Accelerator
from transformers.utils import logging
from transformers import AutoTokenizer, AutoModelForSequenceClassification
logger = logging.get_logger(__name__)
class CrossEncoder(torch.nn.Module):
def __init__(self, ranker, dtype:str="fp16", cache_dir=None, accelerator:Accelerator=None) -> None:
super().__init__()
logger.info(f"Loading tokenizer and model from {ranker}...")
self.tokenizer = AutoTokenizer.from_pretrained(ranker, cache_dir=cache_dir)
if dtype == "bf16":
dtype = torch.bfloat16
elif dtype == "fp16":
dtype = torch.float16
else:
dtype = torch.float32
if accelerator is not None:
device = accelerator.device
else:
device = torch.device("cpu")
self.ranker = AutoModelForSequenceClassification.from_pretrained(ranker, num_labels=1, cache_dir=cache_dir, torch_dtype=dtype).to(device)
def gradient_checkpointing_enable(self):
self.ranker.gradient_checkpointing_enable()
def forward(self, cross, batch_size, **kwds):
output = self.ranker(**cross)
scores = output.logits.view(batch_size, -1)
loss = nn.functional.cross_entropy(scores, scores.new_zeros(scores.shape[0], dtype=torch.long))
return {"loss": loss}
@torch.no_grad()
def rerank(self, cross, batch_size, key_mask=None, hits=None, **kwds):
output = self.ranker(**cross)
score = output.logits.view(batch_size, -1)
# mask padded candidates
if key_mask is not None:
score = score.masked_fill(~key_mask.bool(), torch.finfo(score.dtype).min)
score, indice = score.sort(dim=-1, descending=True)
if hits is not None:
score = score[:, :hits]
indice = indice[:, :hits]
# NOTE: set the indice to -1 so that this prediction is ignored when computing metrics
indice[score == torch.finfo(score.dtype).min] = -1
return score, indice
def save_pretrained(self, output_dir: str, *args, **kwargs):
self.tokenizer.save_pretrained(
os.path.join(output_dir, "ranker"))
self.ranker.save_pretrained(
os.path.join(output_dir, "ranker"))
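# --- Hedged usage sketch (added for illustration, not part of the original
# module). The checkpoint name is a placeholder; any sequence-classification
# reranker that emits a single relevance logit should fit the same pattern. ---
if __name__ == "__main__":
    ce = CrossEncoder("BAAI/bge-reranker-base", dtype="fp32")
    pairs = [
        ["what is the capital of France?", "Paris is the capital of France."],
        ["what is the capital of France?", "The Nile is a river in Africa."],
    ]
    cross = ce.tokenizer(pairs, padding=True, truncation=True, return_tensors="pt")
    score, indice = ce.rerank(cross, batch_size=1)  # 1 query, 2 candidates
    print(score, indice)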
import torch
import random
import logging
from tqdm import tqdm
from .modeling_dense import DenseRetriever
from .modeling_bm25 import BM25Retriever, NaiveBM25Retriever
logger = logging.getLogger(__name__)
class Retriever:
"""A wrapper for different retrieval_methods."""
def __init__(self, retrieval_method: str="dense", **kwds) -> None:
self.retrieval_method = retrieval_method
self.accelerator = kwds["accelerator"]
if retrieval_method == "dense":
self.retriever = DenseRetriever(**kwds)
elif retrieval_method == "bm25":
if self.accelerator.process_index == 0:
self.retriever = BM25Retriever(**kwds)
else:
self.retriever = None
elif retrieval_method == "naive-bm25":
self.retriever = NaiveBM25Retriever(**kwds)
else:
logger.warning(f"Found unimplemented retrieval_method [{retrieval_method}], will return None as query_ids and preds.")
self.retriever = None
def to(self, *args, **kwds):
if hasattr(self.retriever, "to"):
self.retriever.to(*args, **kwds)
return self
def encode(self, *args, **kwds):
if self.retriever is not None and hasattr(self.retriever, "encode"):
return self.retriever.encode(*args, **kwds)
else:
raise NotImplementedError
def index(self, corpus, **kwds):
self.corpus_size = len(corpus)
if self.retriever is not None and hasattr(self.retriever, "index"):
self.retriever.index(corpus, **kwds)
self.accelerator.wait_for_everyone()
def search(self, eval_dataset, **kwds):
if self.retrieval_method == "dense":
query_ids = []
preds = [] # num_samples, hits
# every process get the same queries while searching different shards
dataloader = torch.utils.data.DataLoader(
eval_dataset,
batch_size=kwds.get("batch_size", 1000),
pin_memory=True,
num_workers=2,
)
for step, inputs in enumerate(tqdm(dataloader, desc="Searching")):
query_id = inputs.pop("query_id")
# the indices are already gathered, merged, and sorted inside search function
score, indice = self.retriever.search(inputs["query"], **kwds) # batch_size, hits
query_ids.extend(query_id.tolist())
preds.extend(indice.tolist())
elif self.retrieval_method == "bm25" and self.retriever is not None:
query_ids, preds = self.retriever.search(eval_data=eval_dataset, **kwds)
elif self.retrieval_method == "random":
query_ids = []
preds = []
sample_range = range(self.corpus_size)
for sample in eval_dataset:
query_ids.append(sample["query_id"])
preds.append(random.sample(sample_range, kwds["hits"]))
elif self.retrieval_method == "naive-bm25":
raise NotImplementedError(f"Retrieval with naive-bm25 and dataset is not implemented!")
else:
query_ids = None
preds = None
self.accelerator.wait_for_everyone()
return query_ids, preds
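# NOTE (added comment): for the "dense" path, search() expects every example in
# eval_dataset to provide a "query_id" and a "query" field (see the dataloader
# loop above); the wrapper only forwards index()/search() to the underlying
# retriever and synchronizes processes around those calls.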