Commit f75058c7 authored by Rayyyyy's avatar Rayyyyy
Browse files

First add.

parents
Pipeline #1411 canceled with stages
"""
# 1. Output Search Results with BM25
python bm25_baseline.py
# 2. Print and Save Evaluation Results
python step2-eval_sparse_mkqa.py \
--encoder bm25 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--search_result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--eval_result_save_dir ./eval_results \
--metrics recall@20 recall@100 \
--threads 32
"""
import os
import sys
import datasets
from tqdm import tqdm
sys.path.append("..")
from utils.normalize_text import normalize
def generate_corpus(corpus_save_path: str):
    """Build the (lower-cased, normalized) NQ corpus and save it as JSONL.

    Skips generation only when a non-empty corpus file already exists —
    an empty file (e.g. from an interrupted run) is regenerated, matching
    the check used in `generate_queries`.
    """
    if os.path.exists(corpus_save_path) and os.path.getsize(corpus_save_path) > 0:
        print("Corpus already exists. Skip generating ...")
        return
    corpus = datasets.load_dataset('BeIR/nq', 'corpus')['corpus']
    corpus_list = []
    for data in tqdm(corpus, desc="Generating corpus"):
        _id = str(data['_id'])
        # Pyserini's JsonCollection expects `id` and `contents` fields.
        content = f"{data['title']}\n{data['text']}".lower()
        content = normalize(content)
        corpus_list.append({"id": _id, "contents": content})
    corpus = datasets.Dataset.from_list(corpus_list)
    corpus.to_json(corpus_save_path, force_ascii=False)
def generate_queries(qa_data_dir: str, lang: str, queries_save_dir: str):
    """Dump `{lang}` queries to a Pyserini-style TSV topics file: `<id>\\t<question>`.

    Skips work if a non-empty TSV already exists.
    """
    queries_save_path = os.path.join(queries_save_dir, f"{lang}.tsv")
    if os.path.exists(queries_save_path) and os.path.getsize(queries_save_path) > 0:
        return
    queries_path = os.path.join(qa_data_dir, f"{lang}.jsonl")
    queries = datasets.load_dataset('json', data_files=queries_path)['train']
    with open(queries_save_path, 'w', encoding='utf-8') as f:
        for data in queries:
            _id = str(data['id'])
            # Collapse internal tabs/newlines: a literal tab or newline in a
            # question would corrupt the two-column TSV topics format.
            query = ' '.join(str(data['question']).split())
            f.write(f"{_id}\t{query}\n")
def index(corpus_save_dir: str, index_save_dir: str):
    """Build a Lucene index over the JSONL corpus with Pyserini."""
    flags = " ".join([
        "--collection JsonCollection",
        f"--input {corpus_save_dir}",
        f"--index {index_save_dir}",
        "--generator DefaultLuceneDocumentGenerator",
        "--threads 1",
        "--storePositions --storeDocvectors --storeRaw",
    ])
    os.system(f"python -m pyserini.index.lucene {flags}")
def search(index_save_dir: str, queries_save_dir: str, lang: str, result_save_path: str):
    """Run BM25 retrieval for one language's TSV topics via Pyserini."""
    queries_save_path = os.path.join(queries_save_dir, f"{lang}.tsv")
    # Note: Use `--lang {lang}` will cause the performance degradation, since the query and corpus are in different languages.
    flags = " ".join([
        f"--index {index_save_dir}",
        f"--topics {queries_save_path}",
        f"--output {result_save_path}",
        "--bm25",
        "--hits 1000",
        "--batch-size 128",
        "--threads 16",
    ])
    os.system(f"python -m pyserini.search.lucene {flags}")
def main():
    """End-to-end BM25 baseline: generate corpus -> index -> queries -> search."""
    bm25_dir = './bm25_baseline'
    qa_data_dir = '../qa_data'
    result_save_dir = os.path.join('./search_results', 'bm25')
    # `exist_ok=True` replaces the racy exists()-then-makedirs() pattern.
    os.makedirs(result_save_dir, exist_ok=True)
    corpus_save_dir = os.path.join(bm25_dir, 'corpus')
    os.makedirs(corpus_save_dir, exist_ok=True)
    corpus_save_path = os.path.join(corpus_save_dir, 'corpus.jsonl')
    generate_corpus(corpus_save_path)
    index_save_dir = os.path.join(bm25_dir, 'index')
    os.makedirs(index_save_dir, exist_ok=True)
    index(corpus_save_dir, index_save_dir)
    queries_save_dir = os.path.join(bm25_dir, 'queries')
    os.makedirs(queries_save_dir, exist_ok=True)
    languages = ['ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
    for lang in languages:
        generate_queries(qa_data_dir, lang, queries_save_dir)
        result_save_path = os.path.join(result_save_dir, f'{lang}.txt')
        search(index_save_dir, queries_save_dir, lang, result_save_path)


if __name__ == '__main__':
    main()
"""
# 1. Output Search Results with BM25
python bm25_baseline_same_tokenizer.py
# 2. Print and Save Evaluation Results
python step2-eval_sparse_mkqa.py \
--encoder bm25_same_tokenizer \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--search_result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--eval_result_save_dir ./eval_results \
--metrics recall@20 recall@100 \
--threads 32
"""
import os
import sys
import datasets
from tqdm import tqdm
from transformers import AutoTokenizer
sys.path.append("..")
from utils.normalize_text import normalize
# Module-level tokenizer shared by the `_map_func_*` helpers below.
# `use_fast=False` forces the slow (SentencePiece-based) tokenizer —
# presumably for exact tokenization parity with BGE-M3; TODO confirm.
tokenizer = AutoTokenizer.from_pretrained(
    'BAAI/bge-m3',
    use_fast=False,
)
def _map_func_corpus(examples):
    """Batched `datasets.map` function: replace each document's text with its
    space-joined BGE-M3 token ids (BOS/EOS stripped)."""
    encoded = tokenizer(
        examples['contents'],
        padding=False,
        truncation=True,
        max_length=512
    )
    contents = []
    for ids in encoded['input_ids']:
        # Drop the leading BOS and trailing EOS, stringify the rest.
        contents.append(" ".join(str(token_id) for token_id in ids[1:-1]))
    return {'id': examples['id'], 'contents': contents}
def _map_func_query(examples):
    """Batched `datasets.map` function: replace each question with its
    space-joined BGE-M3 token ids (BOS/EOS stripped)."""
    encoded = tokenizer(
        examples['question'],
        padding=False,
        truncation=True,
        max_length=512
    )
    questions = [
        " ".join(str(token_id) for token_id in ids[1:-1])  # strip BOS/EOS
        for ids in encoded['input_ids']
    ]
    return {'id': examples['id'], 'question': questions}
def generate_corpus(corpus_save_path: str):
    """Build the normalized NQ corpus, pre-tokenize it with the BGE-M3
    tokenizer, and save it as JSONL for Pyserini indexing."""
    if os.path.exists(corpus_save_path):
        print("Corpus already exists. Skip generating ...")
        return
    raw = datasets.load_dataset('BeIR/nq', 'corpus')['corpus']
    records = []
    for entry in tqdm(raw, desc="Generating corpus"):
        text = normalize(f"{entry['title']}\n{entry['text']}".lower())
        records.append({"id": str(entry['_id']), "contents": text})
    corpus = datasets.Dataset.from_list(records)
    # Convert raw text into space-joined token ids so corpus and queries
    # share exactly one tokenizer.
    corpus = corpus.map(_map_func_corpus, batched=True, num_proc=48)
    corpus.to_json(corpus_save_path, force_ascii=False)
def generate_queries(qa_data_dir: str, lang: str, queries_save_dir: str):
    """Write `{lang}` queries, pre-tokenized into token-id strings, as a
    Pyserini TSV topics file. Skips work if a non-empty TSV exists."""
    queries_save_path = os.path.join(queries_save_dir, f"{lang}.tsv")
    if os.path.exists(queries_save_path) and os.path.getsize(queries_save_path) > 0:
        return
    queries_path = os.path.join(qa_data_dir, f"{lang}.jsonl")
    queries = datasets.load_dataset('json', data_files=queries_path)['train']
    # Questions become space-joined token ids (same tokenizer as the corpus).
    queries = queries.map(_map_func_query, batched=True, num_proc=48)
    with open(queries_save_path, 'w', encoding='utf-8') as f:
        f.writelines(
            f"{str(item['id'])}\t{item['question']}\n" for item in queries
        )
def index(corpus_save_dir: str, index_save_dir: str):
    """Build a Lucene index over the pre-tokenized JSONL corpus."""
    flags = " ".join([
        "--collection JsonCollection",
        f"--input {corpus_save_dir}",
        f"--index {index_save_dir}",
        "--generator DefaultLuceneDocumentGenerator",
        "--threads 1",
        "--storePositions --storeDocvectors --storeRaw",
    ])
    os.system(f"python3 -m pyserini.index.lucene {flags}")
def search(index_save_dir: str, queries_save_dir: str, lang: str, result_save_path: str):
    """Run BM25 retrieval for one language's token-id topics via Pyserini."""
    queries_save_path = os.path.join(queries_save_dir, f"{lang}.tsv")
    flags = " ".join([
        f"--index {index_save_dir}",
        f"--topics {queries_save_path}",
        f"--output {result_save_path}",
        "--bm25",
        "--hits 1000",
        "--batch-size 128",
        "--threads 16",
    ])
    os.system(f"python3 -m pyserini.search.lucene {flags}")
def main():
    """BM25 baseline with shared (BGE-M3) tokenization for corpus and queries."""
    qa_data_dir = '../qa_data'
    bm25_dir = './bm25_baseline_same_tokenizer'
    result_save_dir = os.path.join('./search_results', 'bm25_same_tokenizer')
    corpus_save_dir = os.path.join(bm25_dir, 'corpus')
    index_save_dir = os.path.join(bm25_dir, 'index')
    queries_save_dir = os.path.join(bm25_dir, 'queries')
    # Make sure every output directory exists before any step runs.
    for directory in (result_save_dir, corpus_save_dir, index_save_dir, queries_save_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)
    corpus_save_path = os.path.join(corpus_save_dir, 'corpus.jsonl')
    generate_corpus(corpus_save_path)
    index(corpus_save_dir, index_save_dir)
    languages = ['ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
    for lang in languages:
        generate_queries(qa_data_dir, lang, queries_save_dir)
        result_save_path = os.path.join(result_save_dir, f'{lang}.txt')
        search(index_save_dir, queries_save_dir, lang, result_save_path)


if __name__ == '__main__':
    main()
"""
python step0-encode_query-and-corpus.py \
--encoder BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--qa_data_dir ../qa_data \
--save_dir ./encoded_query-and-corpus \
--max_query_length 512 \
--max_passage_length 512 \
--batch_size 1024 \
--pooling_method cls \
--normalize_embeddings True
"""
import os
import sys
import json
import datasets
import numpy as np
from tqdm import tqdm
from pprint import pprint
from FlagEmbedding import BGEM3FlagModel
from dataclasses import dataclass, field
from transformers import HfArgumentParser
sys.path.append("..")
from utils.normalize_text import normalize
@dataclass
class ModelArgs:
    """CLI arguments describing the encoder model used for sparse encoding."""
    encoder: str = field(
        default="BAAI/bge-m3",
        metadata={'help': 'Name or path of encoder'}
    )
    pooling_method: str = field(
        default='cls',
        # Typo fix in user-facing help: "Avaliable" -> "Available".
        metadata={'help': "Pooling method. Available methods: 'cls', 'mean'"}
    )
    normalize_embeddings: bool = field(
        default=True,
        metadata={'help': "Normalize embeddings or not"}
    )
    fp16: bool = field(
        default=True,
        metadata={'help': 'Use fp16 in inference?'}
    )
@dataclass
class EvalArgs:
    """CLI arguments controlling what is encoded and where it is saved."""
    languages: str = field(
        default="en",
        # Typo fix in user-facing help: "Avaliable" -> "Available".
        metadata={'help': 'Languages to evaluate. Available languages: en ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw',
                  "nargs": "+"}
    )
    qa_data_dir: str = field(
        default='../qa_data',
        metadata={'help': 'Dir to qa data.'}
    )
    save_dir: str = field(
        default='./encoded_query-and-corpus',
        metadata={'help': 'Dir to save encoded query and corpus. Encoded query and corpus will be saved to `save_dir/{encoder_name}/{lang}/query_embd.tsv` and `save_dir/{encoder_name}/corpus/corpus_embd.jsonl`, individually.'}
    )
    max_query_length: int = field(
        default=512,
        metadata={'help': 'Max query length.'}
    )
    max_passage_length: int = field(
        default=512,
        metadata={'help': 'Max passage length.'}
    )
    batch_size: int = field(
        default=256,
        metadata={'help': 'Inference batch size.'}
    )
    overwrite: bool = field(
        default=False,
        metadata={'help': 'Whether to overwrite embedding'}
    )
def get_model(model_args: ModelArgs):
    """Instantiate a BGE-M3 flag model from the parsed CLI arguments."""
    return BGEM3FlagModel(
        model_name_or_path=model_args.encoder,
        pooling_method=model_args.pooling_method,
        normalize_embeddings=model_args.normalize_embeddings,
        use_fp16=model_args.fp16
    )
def check_languages(languages):
    """Validate language codes; a bare string is promoted to a one-item list.

    Returns:
        The validated list of language codes.

    Raises:
        ValueError: if any code is not in the supported set.
    """
    if isinstance(languages, str):
        languages = [languages]
    # Spelling fixed ("avaliable" -> "available") in both the local name and
    # the user-facing error message.
    available_languages = ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
    for lang in languages:
        if lang not in available_languages:
            raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
    return languages
def parse_corpus(corpus: datasets.Dataset):
    """Lower-case + normalize BeIR/nq passages into an `id`/`content` dataset."""
    records = [
        {
            "id": str(entry['_id']),
            "content": normalize(f"{entry['title']}\n{entry['text']}".lower()),
        }
        for entry in tqdm(corpus, desc="Generating corpus")
    ]
    return datasets.Dataset.from_list(records)
def get_queries(qa_data_dir: str, lang: str):
    """Load `{lang}.jsonl` QA data as an `id`/`content` query dataset.

    Raises:
        FileNotFoundError: if the per-language jsonl file is missing.
    """
    topics_path = os.path.join(qa_data_dir, f"{lang}.jsonl")
    if not os.path.exists(topics_path):
        raise FileNotFoundError(f"{topics_path} not found")
    dataset = datasets.load_dataset('json', data_files=topics_path)['train']
    records = [
        {'id': str(entry['id']), 'content': entry['question']}
        for entry in dataset
    ]
    return datasets.Dataset.from_list(records)
def encode_and_save_corpus(corpus_save_path: str, model: BGEM3FlagModel, corpus: datasets.Dataset, max_passage_length: int=512, batch_size: int=256):
    """Encode passages into BGE-M3 sparse (lexical) weights and save them as
    Pyserini impact-vector JSONL records (`id` / empty `contents` / `vector`)."""
    docids = list(corpus["id"])
    lexical_weights = model.encode(
        corpus["content"],
        batch_size=batch_size,
        max_length=max_passage_length,
        return_dense=False,
        return_sparse=True,
        return_colbert_vecs=False
    )['lexical_weights']
    records = []
    for docid, weights in zip(docids, lexical_weights):
        # Quantize float lexical weights to integer impact scores.
        impacts = {token: int(np.ceil(weight * 100)) for token, weight in weights.items()}
        records.append({'id': docid, 'contents': '', 'vector': impacts})
    with open(corpus_save_path, 'w', encoding='utf-8') as f:
        for record in tqdm(records, desc="Saving encoded corpus"):
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
def encode_and_save_queries(queries_save_path: str, model: BGEM3FlagModel, queries: datasets.Dataset, max_query_length: int=512, batch_size: int=256):
    """Encode queries into BGE-M3 sparse weights and save them in Pyserini's
    impact topic format: each token id repeated `impact` times."""
    qids = list(queries["id"])
    lexical_weights = model.encode(
        queries["content"],
        batch_size=batch_size,
        max_length=max_query_length,
        return_dense=False,
        return_sparse=True,
        return_colbert_vecs=False
    )['lexical_weights']
    encoded_lines = []
    for qid, weights in zip(qids, lexical_weights):
        tokens = []
        for token, weight in weights.items():
            # Quantized impact == number of repetitions of the token id.
            tokens.extend([str(token)] * int(np.ceil(weight * 100)))
        # Fall back to a dummy token so a topic line is never empty.
        topic_str = " ".join(tokens) if tokens else "0"
        encoded_lines.append(f"{str(qid)}\t{topic_str}")
    with open(queries_save_path, 'w', encoding='utf-8') as f:
        for line in tqdm(encoded_lines, desc="Saving encoded queries"):
            f.write(line + '\n')
def main():
    """Encode the NQ corpus and each language's queries into BGE-M3 sparse
    vectors, saving them in Pyserini impact-index / topic formats."""
    parser = HfArgumentParser([ModelArgs, EvalArgs])
    model_args, eval_args = parser.parse_args_into_dataclasses()
    model_args: ModelArgs
    eval_args: EvalArgs
    languages = check_languages(eval_args.languages)
    # languages.reverse()
    # Strip a trailing '/' so os.path.basename() yields the model name, not ''.
    if model_args.encoder[-1] == '/':
        model_args.encoder = model_args.encoder[:-1]
    model = get_model(model_args=model_args)
    encoder = model_args.encoder
    # For intermediate checkpoints, fold the checkpoint id into the save-dir
    # name (e.g. `run/checkpoint-100` -> `run_checkpoint-100`).
    if os.path.basename(encoder).startswith('checkpoint-'):
        encoder = os.path.dirname(encoder) + '_' + os.path.basename(encoder)
    print("==================================================")
    print("Start generating embedding with model:")
    print(model_args.encoder)
    print('Generating corpus embedding ...')
    corpus_save_dir = os.path.join(eval_args.save_dir, os.path.basename(encoder), 'corpus')
    if not os.path.exists(corpus_save_dir):
        os.makedirs(corpus_save_dir)
    corpus_save_path = os.path.join(corpus_save_dir, 'corpus_embd.jsonl')
    # A non-empty embedding file is treated as complete unless --overwrite.
    if os.path.exists(corpus_save_path) and os.path.getsize(corpus_save_path) > 0 and not eval_args.overwrite:
        print(f'Corpus embedding already exists. Skip...')
    else:
        corpus = datasets.load_dataset("BeIR/nq", 'corpus')['corpus']
        corpus = parse_corpus(corpus=corpus)
        encode_and_save_corpus(
            corpus_save_path=corpus_save_path,
            model=model,
            corpus=corpus,
            max_passage_length=eval_args.max_passage_length,
            batch_size=eval_args.batch_size
        )
    print('Generate query embedding of following languages: ', languages)
    for lang in languages:
        print("**************************************************")
        save_dir = os.path.join(eval_args.save_dir, os.path.basename(encoder), lang)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        queries_save_path = os.path.join(save_dir, 'query_embd.tsv')
        # Per-language skip: existing embeddings are reused unless --overwrite.
        if os.path.exists(queries_save_path) and not eval_args.overwrite:
            print(f'Query embedding of {lang} already exists. Skip...')
            continue
        print(f"Start generating query embedding of {lang} ...")
        queries = get_queries(eval_args.qa_data_dir, lang)
        encode_and_save_queries(
            queries_save_path=queries_save_path,
            model=model,
            queries=queries,
            max_query_length=eval_args.max_query_length,
            batch_size=eval_args.batch_size
        )
    print("==================================================")
    print("Finish generating embeddings with following model:")
    pprint(model_args.encoder)
if __name__ == "__main__":
    main()
"""
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--encoded_query_and_corpus_save_dir ./encoded_query-and-corpus \
--result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--threads 16 \
--hits 1000
"""
import os
import datasets
from tqdm import tqdm
from pprint import pprint
from dataclasses import dataclass, field
from transformers import HfArgumentParser
@dataclass
class ModelArgs:
    """CLI arguments identifying the encoder whose embeddings are searched."""
    encoder: str = field(
        default="BAAI/bge-m3",
        metadata={'help': 'Name or path of encoder'}
    )
@dataclass
class EvalArgs:
    """CLI arguments for indexing the encoded corpus and running searches."""
    languages: str = field(
        default="en",
        # Typo fix in user-facing help: "Avaliable" -> "Available".
        metadata={'help': 'Languages to evaluate. Available languages: en ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw',
                  "nargs": "+"}
    )
    encoded_query_and_corpus_save_dir: str = field(
        default='./encoded_query-and-corpus',
        metadata={'help': 'Dir to save encoded queries and corpus. Encoded queries and corpus are saved in `save_dir/{encoder_name}/{lang}/query_embd.tsv` and `save_dir/{encoder_name}/corpus/corpus_embd.jsonl`, individually.'}
    )
    result_save_dir: str = field(
        default='./search_results',
        metadata={'help': 'Dir to saving results. Search results will be saved to `result_save_dir/{encoder_name}/{lang}.txt`'}
    )
    qa_data_dir: str = field(
        default='../qa_data',
        metadata={'help': 'Dir to qa data.'}
    )
    batch_size: int = field(
        default=32,
        metadata={'help': 'Batch size to use during search'}
    )
    threads: int = field(
        default=1,
        metadata={'help': 'Maximum threads to use during search'}
    )
    hits: int = field(
        default=1000,
        metadata={'help': 'Number of hits'}
    )
    overwrite: bool = field(
        default=False,
        metadata={'help': 'Whether to overwrite embedding'}
    )
def check_languages(languages):
    """Validate language codes; a bare string is promoted to a one-item list.

    Returns:
        The validated list of language codes.

    Raises:
        ValueError: if any code is not in the supported set.
    """
    if isinstance(languages, str):
        languages = [languages]
    # Spelling fixed ("avaliable" -> "available") in both the local name and
    # the user-facing error message.
    available_languages = ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
    for lang in languages:
        if lang not in available_languages:
            raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
    return languages
def generate_index(corpus_embd_dir: str, index_save_dir: str, threads: int=12):
    """Build a Lucene impact index over pre-encoded sparse (impact) vectors."""
    flags = " ".join([
        "--language en",
        "--collection JsonVectorCollection",
        f"--input {corpus_embd_dir}",
        f"--index {index_save_dir}",
        "--generator DefaultLuceneDocumentGenerator",
        f"--threads {threads}",
        "--impact --pretokenized --optimize",
    ])
    os.system(f"python -m pyserini.index.lucene {flags}")
def search_and_save_results(index_save_dir: str, query_embd_path: str, result_save_path: str, batch_size: int = 32, threads: int = 12, hits: int = 1000):
    """Run impact (learned-sparse) retrieval and write a TREC run file."""
    flags = " ".join([
        f"--index {index_save_dir}",
        f"--topics {query_embd_path}",
        f"--output {result_save_path}",
        "--output-format trec",
        f"--batch {batch_size}",
        f"--threads {threads}",
        f"--hits {hits}",
        "--impact",
    ])
    os.system(f"python -m pyserini.search.lucene {flags}")
def parse_corpus(corpus: datasets.Dataset):
    """Flatten docs into an `id`/`content` dataset (title joined with text)."""
    records = []
    for entry in tqdm(corpus, desc="Generating corpus"):
        records.append({'id': entry['docid'], 'content': f"{entry['title']}\n{entry['text']}"})
    return datasets.Dataset.from_list(records)
def main():
    """Index the encoded corpus (once), then run impact search for every
    requested language, writing TREC run files per language."""
    parser = HfArgumentParser([ModelArgs, EvalArgs])
    model_args, eval_args = parser.parse_args_into_dataclasses()
    model_args: ModelArgs
    eval_args: EvalArgs
    languages = check_languages(eval_args.languages)
    # Strip a trailing '/' so os.path.basename() yields the model name, not ''.
    if model_args.encoder[-1] == '/':
        model_args.encoder = model_args.encoder[:-1]
    encoder = model_args.encoder
    # For intermediate checkpoints, fold the checkpoint id into the dir name
    # (e.g. `run/checkpoint-100` -> `run_checkpoint-100`), matching step0.
    if os.path.basename(encoder).startswith('checkpoint-'):
        encoder = os.path.dirname(encoder) + '_' + os.path.basename(encoder)
    print("==================================================")
    print("Start generating search results with model:")
    print(model_args.encoder)
    corpus_embd_dir = os.path.join(eval_args.encoded_query_and_corpus_save_dir, os.path.basename(encoder), 'corpus')
    index_save_dir = os.path.join(eval_args.encoded_query_and_corpus_save_dir, os.path.basename(encoder), 'index')
    # The index is shared by all languages; build it at most once.
    if os.path.exists(index_save_dir) and not eval_args.overwrite:
        print(f'Index already exists')
    else:
        generate_index(
            corpus_embd_dir=corpus_embd_dir,
            index_save_dir=index_save_dir,
            threads=eval_args.threads
        )
    print('Generate search results of following languages: ', languages)
    for lang in languages:
        print("**************************************************")
        print(f"Start searching results of {lang} ...")
        result_save_path = os.path.join(eval_args.result_save_dir, os.path.basename(encoder), f"{lang}.txt")
        if not os.path.exists(os.path.dirname(result_save_path)):
            os.makedirs(os.path.dirname(result_save_path))
        # Existing run files are reused unless --overwrite.
        if os.path.exists(result_save_path) and not eval_args.overwrite:
            print(f'Search results of {lang} already exists. Skip...')
            continue
        encoded_query_and_corpus_save_dir = os.path.join(eval_args.encoded_query_and_corpus_save_dir, os.path.basename(encoder), lang)
        if not os.path.exists(encoded_query_and_corpus_save_dir):
            raise FileNotFoundError(f"{encoded_query_and_corpus_save_dir} not found")
        query_embd_path = os.path.join(encoded_query_and_corpus_save_dir, 'query_embd.tsv')
        search_and_save_results(
            index_save_dir=index_save_dir,
            query_embd_path=query_embd_path,
            result_save_path=result_save_path,
            batch_size=eval_args.batch_size,
            threads=eval_args.threads,
            hits=eval_args.hits
        )
    print("==================================================")
    print("Finish generating search results with following model:")
    pprint(model_args.encoder)
if __name__ == "__main__":
    main()
"""
# Ref: https://github.com/texttron/tevatron/tree/main/examples/unicoil
# 1. Generate Query and Corpus Sparse Vector
python step0-encode_query-and-corpus.py \
--encoder BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--qa_data_dir ../qa_data \
--save_dir ./encoded_query-and-corpus \
--max_query_length 512 \
--max_passage_length 512 \
--batch_size 1024 \
--pooling_method cls \
--normalize_embeddings True
# 2. Output Search Results
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--encoded_query_and_corpus_save_dir ./encoded_query-and-corpus \
--result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--threads 16 \
--hits 1000
# 3. Print and Save Evaluation Results
python step2-eval_sparse_mkqa.py \
--encoder BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--search_result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--eval_result_save_dir ./eval_results \
--metrics recall@20 recall@100 \
--threads 32 \
--pooling_method cls \
--normalize_embeddings True
"""
import os
import sys
import json
import datasets
import numpy as np
import pandas as pd
from tqdm import tqdm
import multiprocessing
from pprint import pprint
from dataclasses import dataclass, field
from transformers import HfArgumentParser
sys.path.append("..")
from utils.normalize_text import normalize
from utils.evaluation import evaluate_recall_qa
@dataclass
class EvalArgs:
    """CLI arguments for evaluating saved search results against QA answers."""
    languages: str = field(
        default="en",
        # Typo fix in user-facing help: "Avaliable" -> "Available".
        metadata={'help': 'Languages to evaluate. Available languages: en ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw',
                  "nargs": "+"}
    )
    encoder: str = field(
        default='BAAI/bge-m3',
        metadata={'help': 'Name or path of encoder'}
    )
    pooling_method: str = field(
        default='cls',
        metadata={'help': "Pooling method. Available methods: 'cls', 'mean'"}
    )
    normalize_embeddings: bool = field(
        default=True,
        metadata={'help': "Normalize embeddings or not"}
    )
    search_result_save_dir: str = field(
        default='./search_results',
        metadata={'help': 'Dir to saving search results. Search results path is `result_save_dir/{encoder}/{lang}.txt`'}
    )
    qa_data_dir: str = field(
        default='../qa_data',
        metadata={'help': 'Dir to qa data.'}
    )
    metrics: str = field(
        default="recall@20",
        metadata={'help': 'Metrics to evaluate. Available metrics: recall@k',
                  "nargs": "+"}
    )
    eval_result_save_dir: str = field(
        default='./eval_results',
        metadata={'help': 'Dir to saving evaluation results. Evaluation results will be saved to `eval_result_save_dir/{encoder}.json`'}
    )
    threads: int = field(
        default=1,
        metadata={"help": "num of evaluation threads. <= 1 means single thread"}
    )
def check_languages(languages):
    """Validate language codes; a bare string is promoted to a one-item list.

    Returns:
        The validated list of language codes.

    Raises:
        ValueError: if any code is not in the supported set.
    """
    if isinstance(languages, str):
        languages = [languages]
    # Spelling fixed ("avaliable" -> "available") in both the local name and
    # the user-facing error message.
    available_languages = ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
    for lang in languages:
        if lang not in available_languages:
            raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
    return languages
def compute_average(results: dict):
    """Average each metric across languages.

    Args:
        results: {lang: {metric_name: score}}.

    Returns:
        {metric_name: mean score over all languages}.
    """
    per_metric_scores = {}
    for lang_result in results.values():
        for metric, score in lang_result.items():
            per_metric_scores.setdefault(metric, []).append(score)
    return {metric: np.mean(scores) for metric, scores in per_metric_scores.items()}
def save_results(model_name: str, pooling_method: str, normalize_embeddings: bool, results: dict, save_path: str, eval_languages: list):
    """Add an 'average' entry to `results`, pretty-print them, and dump
    everything (plus model metadata) as JSON at `save_path`."""
    try:
        results['average'] = compute_average(results)
    except Exception:
        # Averaging is best-effort (a language may have failed); record the
        # absence explicitly. Narrowed from a bare `except:` which also
        # swallowed KeyboardInterrupt/SystemExit.
        results['average'] = None
    pprint(results)
    save_dir = os.path.dirname(save_path)
    # `dirname` is '' for a bare filename; os.makedirs('') would raise.
    if save_dir and not os.path.exists(save_dir):
        os.makedirs(save_dir)
    results_dict = {
        'model': model_name,
        'pooling_method': pooling_method,
        'normalize_embeddings': normalize_embeddings,
        'results': results
    }
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(results_dict, f, indent=4, ensure_ascii=False)
    print(f'Results of evaluating `{model_name}` on `{eval_languages}` saved at `{save_path}`')
def get_corpus_dict():
    """Map BeIR/nq passage id -> normalized, lower-cased passage text
    (title and body joined by a newline)."""
    corpus = datasets.load_dataset('BeIR/nq', 'corpus')['corpus']
    return {
        str(entry['_id']): normalize(f"{entry['title']}\n{entry['text']}".lower())
        for entry in tqdm(corpus, desc="Loading corpus")
    }
def get_qa_dict(qa_path: str):
    """Map question id -> gold answer list from a jsonl QA file."""
    dataset = datasets.load_dataset('json', data_files=qa_path)['train']
    return {str(entry['id']): entry['answers'] for entry in dataset}
def get_search_result_dict(search_result_path: str, top_k: int=100):
    """Parse a space-separated TREC run file into qid -> [docid, ...],
    keeping at most the hits ranked <= `top_k` per query.

    NOTE(review): once a row with rank > top_k is seen, the remaining rows of
    that query are skipped until the next query starts. Run files are
    rank-ordered per query, so this equals a plain rank <= top_k filter.
    """
    search_result_dict = {}
    skipping = True  # True while rows of the current query are being dropped
    frame = pd.read_csv(search_result_path, sep=' ', header=None)
    for _, row in frame.iterrows():
        qid, docid, rank = str(row.iloc[0]), str(row.iloc[2]), int(row.iloc[3])
        if qid not in search_result_dict:
            search_result_dict[qid] = []
            skipping = False
        if rank > top_k:
            skipping = True
        if not skipping:
            search_result_dict[qid].append(docid)
    return search_result_dict
def evaluate(corpus_dict: dict, qa_dict: dict, search_result_path: str, metrics: list):
    """Evaluate one language's run file against gold answers.

    Args:
        corpus_dict: docid -> passage text.
        qa_dict: qid -> list of gold answer strings.
        search_result_path: TREC run file for this language.
        metrics: metric specs like ['recall@20', 'recall@100'].

    Returns:
        dict mapping 'Recall@k' to the recall score for each requested k.
    """
    # Read only as many hits as the largest requested cutoff needs.
    top_k = max([int(metric.split('@')[-1]) for metric in metrics])
    search_result_dict = get_search_result_dict(search_result_path, top_k=int(top_k))
    search_results = []
    ground_truths = []
    for qid, docid_list in search_result_dict.items():
        answers = qa_dict[qid]
        # Resolve retrieved docids to passage texts for answer matching.
        doc_list = [corpus_dict[docid] for docid in docid_list]
        search_results.append(doc_list)
        ground_truths.append(answers)
    results = {}
    metrics = sorted([metric.lower() for metric in metrics])
    for metric in metrics:
        # Split 'recall@20' into metric name and cutoff.
        metric, k = metric.split('@')
        k = int(k)
        assert metric in ['recall'], f"Metric `{metric}` is not supported."
        if metric == 'recall':
            results[f'Recall@{k}'] = evaluate_recall_qa(search_results, ground_truths, k=k)
    return results
def main():
    """Evaluate saved search results for every requested language and write a
    JSON summary (optionally fanning evaluation out over a process pool)."""
    parser = HfArgumentParser([EvalArgs])
    eval_args = parser.parse_args_into_dataclasses()[0]
    eval_args: EvalArgs
    # Loaded once and shared by every language (the corpus is English NQ).
    corpus_dict = get_corpus_dict()
    languages = check_languages(eval_args.languages)
    # Strip a trailing '/' so os.path.basename() yields the model name, not ''.
    if eval_args.encoder[-1] == '/':
        eval_args.encoder = eval_args.encoder[:-1]
    # Match the naming scheme used when the search results were generated.
    if os.path.basename(eval_args.encoder).startswith('checkpoint-'):
        eval_args.encoder = os.path.dirname(eval_args.encoder) + '_' + os.path.basename(eval_args.encoder)
    results = {}
    if eval_args.threads > 1:
        # Parallel path: one evaluation task per language.
        threads = min(len(languages), eval_args.threads)
        pool = multiprocessing.Pool(processes=threads)
        results_list = []
        for lang in languages:
            print("*****************************")
            print(f"Start evaluating {lang} ...")
            qa_path = os.path.join(eval_args.qa_data_dir, f"{lang}.jsonl")
            qa_dict = get_qa_dict(qa_path)
            search_result_save_dir = os.path.join(eval_args.search_result_save_dir, os.path.basename(eval_args.encoder))
            search_result_path = os.path.join(search_result_save_dir, f"{lang}.txt")
            results_list.append(pool.apply_async(evaluate, args=(corpus_dict, qa_dict, search_result_path, eval_args.metrics)))
        pool.close()
        pool.join()
        # results_list preserves submission order, i.e. `languages` order.
        for i, lang in enumerate(languages):
            results[lang] = results_list[i].get()
    else:
        # Sequential fallback.
        for lang in languages:
            print("*****************************")
            print(f"Start evaluating {lang} ...")
            qa_path = os.path.join(eval_args.qa_data_dir, f"{lang}.jsonl")
            qa_dict = get_qa_dict(qa_path)
            search_result_save_dir = os.path.join(eval_args.search_result_save_dir, os.path.basename(eval_args.encoder))
            search_result_path = os.path.join(search_result_save_dir, f"{lang}.txt")
            result = evaluate(corpus_dict, qa_dict, search_result_path, eval_args.metrics)
            results[lang] = result
    save_results(
        model_name=eval_args.encoder,
        pooling_method=eval_args.pooling_method,
        normalize_embeddings=eval_args.normalize_embeddings,
        results=results,
        save_path=os.path.join(eval_args.eval_result_save_dir, f"{os.path.basename(eval_args.encoder)}.json"),
        eval_languages=languages
    )
    print("==================================================")
    print("Finish generating evaluation results with following model:")
    print(eval_args.encoder)
if __name__ == "__main__":
    main()
# Ref: https://github.com/facebookresearch/contriever
import regex
import unicodedata
from functools import partial
from typing import List
class SimpleTokenizer:
    """Regex word tokenizer (from facebookresearch/contriever): matches runs
    of letters/digits/marks, or any single non-whitespace character."""
    ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
    NON_WS = r'[^\p{Z}\p{C}]'

    def __init__(self):
        """
        Args:
            annotators: None or empty set (only tokenizes).
        """
        pattern = '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS)
        self._regexp = regex.compile(
            pattern,
            flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
        )

    def tokenize(self, text, uncased=False):
        """Return the token strings found in `text` (lower-cased if `uncased`)."""
        tokens = [m.group() for m in self._regexp.finditer(text)]
        if uncased:
            tokens = [t.lower() for t in tokens]
        return tokens
def _normalize(text):
return unicodedata.normalize('NFD', text)
def has_answer(answers, text, tokenizer) -> bool:
    """Check if a document contains an answer string."""
    doc_tokens = tokenizer.tokenize(_normalize(text), uncased=True)
    for answer in answers:
        ans_tokens = tokenizer.tokenize(_normalize(answer), uncased=True)
        n = len(ans_tokens)
        # Slide a window of the answer's length over the document tokens.
        if any(doc_tokens[i:i + n] == ans_tokens
               for i in range(len(doc_tokens) - n + 1)):
            return True
    return False
def check_answer(example, tokenizer) -> List[bool]:
    """Search through all the top docs to see if they have any of the answers."""
    answers = example['answers']
    hits = []
    for passage in example['ctxs']:
        if passage is None:  # cannot find the document for some reason
            hits.append(False)
        else:
            hits.append(has_answer(answers, passage, tokenizer))
    return hits
def evaluate_recall_qa(ctxs, answers, k=100):
    """Compute Recall@k for a QA retrieval task.

    Args:
        ctxs: per-question lists of retrieved document texts, ordered by
            decreasing retrieval score (an entry of None marks a document
            that could not be resolved).
        answers: per-question lists of gold answer strings; must be
            parallel to `ctxs`.
        k: cutoff rank; clamped to the number of retrieved docs per question.

    Returns:
        Fraction of questions whose top-k retrieved documents contain at
        least one gold answer. Returns 0.0 for empty input instead of
        raising (the original crashed on `data[0]`).

    Raises:
        ValueError: if `ctxs` and `answers` have different lengths
            (explicit raise instead of `assert`, which is stripped under -O).
    """
    if len(ctxs) != len(answers):
        raise ValueError(
            f"ctxs and answers must be parallel lists, "
            f"got {len(ctxs)} vs {len(answers)}"
        )
    if not ctxs:
        # No questions: recall is undefined; report 0.0 instead of crashing.
        return 0.0
    data = [{'answers': a, 'ctxs': c} for c, a in zip(ctxs, answers)]
    tokenizer = SimpleTokenizer()
    get_score = partial(check_answer, tokenizer=tokenizer)
    n_docs = len(data[0]['ctxs'])
    # top_k_hits[i] = number of questions answered within the top i+1 docs.
    top_k_hits = [0] * n_docs
    for question_hits in map(get_score, data):
        best_hit = next((i for i, hit in enumerate(question_hits) if hit), None)
        if best_hit is not None:
            # A hit at rank r counts toward every cutoff >= r.
            for i in range(best_hit, n_docs):
                top_k_hits[i] += 1
    k = min(k, len(top_k_hits))
    return top_k_hits[k - 1] / len(data)
"""
adapted from chemdataextractor.text.normalize
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Tools for normalizing text.
https://github.com/mcs07/ChemDataExtractor
:copyright: Copyright 2016 by Matt Swain.
:license: MIT
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
'Software'), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
#: Control characters removed entirely by `normalize`.
CONTROLS = {
    '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007',
    '\u0008', '\u000e', '\u000f', '\u0011', '\u0012', '\u0013', '\u0014',
    '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001a', '\u001b',
}
# Further control characters ('\u0009', '\u000a', '\u000b', '\u000c',
# '\u000d', '\u001c', '\u001d', '\u001e', '\u001f') are instead replaced
# with a space by unicode normalization.

#: Hyphen and dash characters.
HYPHENS = {
    '\u002d',  # Hyphen-minus
    '\u2010',  # Hyphen
    '\u2011',  # Non-breaking hyphen
    '\u2043',  # Hyphen bullet
    '\u2012',  # Figure dash
    '\u2013',  # En dash
    '\u2014',  # Em dash
    '\u2015',  # Horizontal bar
}

#: Minus characters.
MINUSES = {
    '\u002d',  # Hyphen-minus
    '\u2212',  # Minus
    '\uff0d',  # Full-width hyphen-minus
    '\u207b',  # Superscript minus
}

#: Plus characters.
PLUSES = {
    '\u002b',  # Plus
    '\uff0b',  # Full-width plus
    '\u207a',  # Superscript plus
}

#: Slash characters.
SLASHES = {
    '\u002f',  # Solidus
    '\u2044',  # Fraction slash
    '\u2215',  # Division slash
}

#: Tilde characters (deliberately NOT folded by `normalize`; see below).
TILDES = {
    '\u007e',  # Tilde
    '\u02dc',  # Small tilde
    '\u2053',  # Swung dash
    '\u223c',  # Tilde operator (in mbert vocab)
    '\u223d',  # Reversed tilde
    '\u223f',  # Sine wave
    '\u301c',  # Wave dash (in mbert vocab)
    '\uff5e',  # Full-width tilde (in mbert vocab)
}

#: Apostrophe characters.
APOSTROPHES = {
    '\u0027',  # Apostrophe
    '\u2019',  # Right single quotation mark
    '\u055a',  # Armenian apostrophe
    '\ua78b',  # Latin capital letter saltillo
    '\ua78c',  # Latin small letter saltillo
    '\uff07',  # Full-width apostrophe
}

#: Single quote characters.
SINGLE_QUOTES = {
    '\u0027',  # Apostrophe
    '\u2018',  # Left single quotation mark
    '\u2019',  # Right single quotation mark
    '\u201a',  # Single low-9 quotation mark
    '\u201b',  # Single high-reversed-9 quotation mark
}

#: Double quote characters.
DOUBLE_QUOTES = {
    '\u0022',  # Quotation mark
    '\u201c',  # Left double quotation mark
    '\u201d',  # Right double quotation mark
    '\u201e',  # Double low-9 quotation mark
    '\u201f',  # Double high-reversed-9 quotation mark
}

#: Accent characters.
ACCENTS = {
    '\u0060',  # Grave accent
    '\u00b4',  # Acute accent
}

#: Prime characters.
PRIMES = {
    '\u2032',  # Prime
    '\u2033',  # Double prime
    '\u2034',  # Triple prime
    '\u2035',  # Reversed prime
    '\u2036',  # Reversed double prime
    '\u2037',  # Reversed triple prime
    '\u2057',  # Quadruple prime
}

#: Quote characters, including apostrophes, single quotes, double quotes,
#: accents and primes.
QUOTES = APOSTROPHES | SINGLE_QUOTES | DOUBLE_QUOTES | ACCENTS | PRIMES


def normalize(text):
    """Fold typographic character variants in `text` to plain ASCII forms.

    Strips control characters and the soft hyphen, maps dash/minus variants
    to '-', quote/apostrophe/accent variants to ' or ", primes to runs of
    apostrophes, the ellipsis to '...', and slash variants to '/'.
    Tilde variants are intentionally left untouched.
    """
    for ctrl in CONTROLS:
        text = text.replace(ctrl, '')
    for ws in ('\u000b', '\u000c', '\u0085'):
        text = text.replace(ws, ' ')
    for dash in HYPHENS | MINUSES:
        text = text.replace(dash, '-')
    text = text.replace('\u00ad', '')  # soft hyphen
    for dq in DOUBLE_QUOTES:
        text = text.replace(dq, '"')
    for sq in SINGLE_QUOTES | APOSTROPHES | ACCENTS:
        text = text.replace(sq, "'")
    # Primes become equivalent runs of apostrophes.
    for prime, ascii_form in (
        ('\u2032', "'"),    # prime
        ('\u2035', "'"),    # reversed prime
        ('\u2033', "''"),   # double prime
        ('\u2036', "''"),   # reversed double prime
        ('\u2034', "'''"),  # triple prime
        ('\u2037', "'''"),  # reversed triple prime
        ('\u2057', "''''"), # quadruple prime
    ):
        text = text.replace(prime, ascii_form)
    text = text.replace('\u2026', '...').replace(' . . . ', ' ... ')
    for slash in SLASHES:
        text = text.replace(slash, '/')
    # NOTE: TILDES are deliberately not normalized (some variants are in
    # the mBERT vocabulary); the original fold is kept disabled:
    # for tilde in TILDES:
    #     text = text.replace(tilde, '~')
    return text
# MultiLongDocRetrieval
MultiLongDocRetrieval (denoted as MLDR) is a multilingual long-document retrieval dataset. For more details, please refer to [Shitao/MLDR](https://huggingface.co/datasets/Shitao/MLDR).
## Dense Retrieval
This task has been merged into [MTEB](https://github.com/embeddings-benchmark/mteb), you can easily use mteb tool to do evaluation.
We also provide a [script](./mteb_dense_eval/eval_MLDR.py), you can use it following this command:
```bash
cd mteb_dense_eval
# Print and Save Evaluation Results with MTEB
python eval_MLDR.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--results_save_path ./results \
--max_query_length 512 \
--max_passage_length 8192 \
--batch_size 256 \
--corpus_batch_size 1 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False \
--overwrite False
```
There are some important parameters:
- `encoder`: Name or path of the model to evaluate.
- `languages`: The languages you want to evaluate on. Available languages: `ar de en es fr hi it ja ko pt ru th zh`.
- `max_query_length` & `max_passage_length`: Maximum query length and maximum passage length when encoding.
- `batch_size` & `corpus_batch_size`: Batch size for query and corpus when encoding. If `max_query_length == max_passage_length`, you can ignore the `corpus_batch_size` parameter and only set `batch_size` for convenience. For faster evaluation, you should set the `batch_size` and `corpus_batch_size` as large as possible.
- `pooling_method` & `normalize_embeddings`: You should follow the corresponding setting of the model you are evaluating. For example, `BAAI/bge-m3` is `cls` and `True`, `intfloat/multilingual-e5-large` is `mean` and `True`, and `intfloat/e5-mistral-7b-instruct` is `last` and `True`.
- `add_instruction`: Whether to add instruction for query or passage when evaluating. If set `add_instruction=True`, you should also set the following parameters appropriately:
- `query_instruction_for_retrieval`: the query instruction for retrieval
- `passage_instruction_for_retrieval`: the passage instruction for retrieval
If you only add query instruction, just ignore the `passage_instruction_for_retrieval` parameter.
- `overwrite`: Whether to overwrite evaluation results.
## Hybrid Retrieval (Dense & Sparse)
If you want to perform **hybrid retrieval with both dense and sparse methods**, you can follow the following steps:
1. Install Java, Pyserini and Faiss (CPU version or GPU version):
```bash
# install java (Linux)
apt update
apt install openjdk-11-jdk
# install pyserini
pip install pyserini
# install faiss
## CPU version
conda install -c conda-forge faiss-cpu
## GPU version
conda install -c conda-forge faiss-gpu
```
2. Download qrels from [Shitao/MLDR](https://huggingface.co/datasets/Shitao/MLDR/tree/main/qrels):
```bash
mkdir -p qrels
cd qrels
splits=(dev test)
langs=(ar de en es fr hi it ja ko pt ru th zh)
for split in ${splits[*]}; do for lang in ${langs[*]}; do wget "https://huggingface.co/datasets/Shitao/MLDR/resolve/main/qrels/qrels.mldr-v1.0-${lang}-${split}.tsv"; done; done;
```
3. Dense retrieval:
```bash
cd dense_retrieval
# 1. Generate Corpus Embedding
python step0-generate_embedding.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--index_save_dir ./corpus-index \
--max_passage_length 8192 \
--batch_size 4 \
--fp16 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False
# 2. Search Results
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--index_save_dir ./corpus-index \
--result_save_dir ./search_results \
--threads 16 \
--hits 1000 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False
# 3. Print and Save Evaluation Results
python step2-eval_dense_mldr.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--search_result_save_dir ./search_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_results \
--metrics ndcg@10 \
--pooling_method cls \
--normalize_embeddings True
```
> Note: The evaluation results of this method may have slight differences compared to results of the method mentioned earlier (*with MTEB*), which is considered normal.
4. Sparse Retrieval
```bash
cd sparse_retrieval
# 1. Generate Query and Corpus Sparse Vector
python step0-encode_query-and-corpus.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--save_dir ./encoded_query-and-corpus \
--max_query_length 512 \
--max_passage_length 8192 \
--batch_size 1024 \
--corpus_batch_size 4 \
--pooling_method cls \
--normalize_embeddings True
# 2. Output Search Results
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--encoded_query_and_corpus_save_dir ./encoded_query-and-corpus \
--result_save_dir ./search_results \
--threads 16 \
--hits 1000
# 3. Print and Save Evaluation Results
python step2-eval_sparse_mldr.py \
--encoder BAAI/bge-m3 \
--languages ar de es fr hi it ja ko pt ru th en zh \
--search_result_save_dir ./search_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_results \
--metrics ndcg@10 \
--pooling_method cls \
--normalize_embeddings True
```
5. Hybrid Retrieval
```bash
cd hybrid_retrieval
# 1. Search Dense and Sparse Results
Dense Retrieval
Sparse Retrieval
# 2. Hybrid Dense and Sparse Search Results
python step0-hybrid_search_results.py \
--model_name_or_path BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--dense_search_result_save_dir ../dense_retrieval/search_results \
--sparse_search_result_save_dir ../sparse_retrieval/search_results \
--hybrid_result_save_dir ./search_results \
--top_k 1000 \
--dense_weight 0.2 --sparse_weight 0.8
# 3. Print and Save Evaluation Results
python step1-eval_hybrid_mldr.py \
--model_name_or_path BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--search_result_save_dir ./search_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_results \
--metrics ndcg@10 \
--pooling_method cls \
--normalize_embeddings True
```
## MultiVector and All Rerank
If you want to perform **multi-vector reranking** or **all reranking** based on the search results of dense retrieval, you can follow the following steps:
1. Install Java, Pyserini and Faiss (CPU version or GPU version):
```bash
# install java (Linux)
apt update
apt install openjdk-11-jdk
# install pyserini
pip install pyserini
# install faiss
## CPU version
conda install -c conda-forge faiss-cpu
## GPU version
conda install -c conda-forge faiss-gpu
```
2. Download qrels from [Shitao/MLDR](https://huggingface.co/datasets/Shitao/MLDR/tree/main/qrels):
```bash
mkdir -p qrels
cd qrels
splits=(dev test)
langs=(ar de en es fr hi it ja ko pt ru th zh)
for split in ${splits[*]}; do for lang in ${langs[*]}; do wget "https://huggingface.co/datasets/Shitao/MLDR/resolve/main/qrels/qrels.mldr-v1.0-${lang}-${split}.tsv"; done; done;
```
3. Dense retrieval:
```bash
cd dense_retrieval
# 1. Generate Corpus Embedding
python step0-generate_embedding.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--index_save_dir ./corpus-index \
--max_passage_length 8192 \
--batch_size 4 \
--fp16 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False
# 2. Search Results
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--index_save_dir ./corpus-index \
--result_save_dir ./search_results \
--threads 16 \
--hits 1000 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False
# 3. Print and Save Evaluation Results
python step2-eval_dense_mldr.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--search_result_save_dir ./search_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_results \
--metrics ndcg@10 \
--pooling_method cls \
--normalize_embeddings True
```
> **Note**: The evaluation results of this method may have slight differences compared to results of the method mentioned earlier (*with MTEB*), which is considered normal.
4. Rerank search results with multi-vector scores or all scores:
```bash
cd multi_vector_rerank
# 1. Rerank Search Results
python step0-rerank_results.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--search_result_save_dir ../dense_retrieval/search_results \
--rerank_result_save_dir ./rerank_results \
--top_k 200 \
--batch_size 4 \
--max_query_length 512 \
--max_passage_length 8192 \
--pooling_method cls \
--normalize_embeddings True \
--dense_weight 0.15 --sparse_weight 0.5 --colbert_weight 0.35 \
--num_shards 1 --shard_id 0 --cuda_id 0
# 2. Print and Save Evaluation Results
python step1-eval_rerank_mldr.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--search_result_save_dir ./rerank_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_results \
--metrics ndcg@10
```
>**Note**:
>
>- You should set `dense_weight`, `sparse_weight` and `colbert_weight` based on the downstream task scenario. If the dense method performs well while the sparse method does not, you can lower `sparse_weight` and increase `dense_weight` accordingly.
>
>- Based on our experience, dividing the sentence pairs to be reranked into several shards and computing scores for each shard on a single GPU tends to be more efficient than using multiple GPUs to compute scores for all sentence pairs directly. Therefore, if your machine has multiple GPUs, you can set `num_shards` to the number of GPUs, then launch multiple terminals and run one command per shard simultaneously (`shard_id` should be equal to `cuda_id`). Make sure to set the `shard_id` and `cuda_id` appropriately, and ensure that you have computed scores for all shards before proceeding to the second step.
5. (*Optional*) In the 4th step, you can get all three kinds of scores, saved to `rerank_result_save_dir/dense/{encoder}-{reranker}`, `rerank_result_save_dir/sparse/{encoder}-{reranker}` and `rerank_result_save_dir/colbert/{encoder}-{reranker}`. If you want to try other weights, you don't need to rerun the 4th step. Instead, you can use [this script](./multi_vector_rerank/hybrid_all_results.py) to hybrid the three kinds of scores directly.
```bash
cd multi_vector_rerank
# 1. Hybrid All Search Results
python hybrid_all_results.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--dense_search_result_save_dir ./rerank_results/dense \
--sparse_search_result_save_dir ./rerank_results/sparse \
--colbert_search_result_save_dir ./rerank_results/colbert \
--hybrid_result_save_dir ./hybrid_search_results \
--top_k 200 \
--dense_weight 0.2 --sparse_weight 0.4 --colbert_weight 0.4
# 2. Print and Save Evaluation Results
python step1-eval_rerank_mldr.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--search_result_save_dir ./hybrid_search_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_hybrid_results \
--metrics ndcg@10
```
## BM25 Baseline
We provide two methods of evaluating BM25 baseline:
1. Use the same tokenizer with [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) (i.e., tokenizer of [XLM-Roberta](https://huggingface.co/FacebookAI/xlm-roberta-large)):
```bash
cd sparse_retrieval
# 1. Output Search Results with BM25 (same)
python bm25_baseline_same_tokenizer.py
# 2. Print and Save Evaluation Results
python step2-eval_sparse_mldr.py \
--encoder bm25_same_tokenizer \
--languages ar de es fr hi it ja ko pt ru th en zh \
--search_result_save_dir ./search_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_results \
--metrics ndcg@10
```
2. Use the language analyzer provided by [Anserini](https://github.com/castorini/anserini/blob/master/src/main/java/io/anserini/analysis/AnalyzerMap.java) ([Lucene Tokenizer](https://github.com/apache/lucene/tree/main/lucene/analysis/common/src/java/org/apache/lucene/analysis)):
```bash
cd sparse_retrieval
# 1. Output Search Results with BM25
python bm25_baseline.py
# 2. Print and Save Evaluation Results
python step2-eval_sparse_mldr.py \
--encoder bm25 \
--languages ar de es fr hi it ja ko pt ru th en zh \
--search_result_save_dir ./search_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_results \
--metrics ndcg@10
```
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment