"""
# 1. Output Search Results with BM25
python bm25_baseline.py
# 2. Print and Save Evaluation Results
python step2-eval_sparse_mkqa.py \
--encoder bm25 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--search_result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--eval_result_save_dir ./eval_results \
--metrics recall@20 recall@100 \
--threads 32
"""
import os
import sys
import datasets
from tqdm import tqdm
sys.path.append("..")
from utils.normalize_text import normalize
def generate_corpus(corpus_save_path: str):
if os.path.exists(corpus_save_path):
print("Corpus already exists. Skip generating ...")
return
corpus = datasets.load_dataset('BeIR/nq', 'corpus')['corpus']
corpus_list = []
for data in tqdm(corpus, desc="Generating corpus"):
_id = str(data['_id'])
content = f"{data['title']}\n{data['text']}".lower()
content = normalize(content)
corpus_list.append({"id": _id, "contents": content})
corpus = datasets.Dataset.from_list(corpus_list)
corpus.to_json(corpus_save_path, force_ascii=False)
def generate_queries(qa_data_dir: str, lang: str, queries_save_dir: str):
queries_save_path = os.path.join(queries_save_dir, f"{lang}.tsv")
if os.path.exists(queries_save_path) and os.path.getsize(queries_save_path) > 0:
return
queries_path = os.path.join(qa_data_dir, f"{lang}.jsonl")
queries = datasets.load_dataset('json', data_files=queries_path)['train']
queries_list = []
for data in queries:
_id = str(data['id'])
query = data['question']
queries_list.append({
'id': _id,
'content': query
})
with open(queries_save_path, 'w', encoding='utf-8') as f:
for query in queries_list:
line = f"{query['id']}\t{query['content']}"
f.write(line + '\n')
def index(corpus_save_dir: str, index_save_dir: str):
cmd = f"python -m pyserini.index.lucene \
--collection JsonCollection \
--input {corpus_save_dir} \
--index {index_save_dir} \
--generator DefaultLuceneDocumentGenerator \
--threads 1 \
--storePositions --storeDocvectors --storeRaw \
"
os.system(cmd)
def search(index_save_dir: str, queries_save_dir: str, lang: str, result_save_path: str):
queries_save_path = os.path.join(queries_save_dir, f"{lang}.tsv")
    # Note: using `--lang {lang}` here would degrade performance, since the queries and the corpus are in different languages.
cmd = f"python -m pyserini.search.lucene \
--index {index_save_dir} \
--topics {queries_save_path} \
--output {result_save_path} \
--bm25 \
--hits 1000 \
--batch-size 128 \
--threads 16 \
"
os.system(cmd)
def main():
bm25_dir = './bm25_baseline'
qa_data_dir = '../qa_data'
result_save_dir = os.path.join('./search_results', 'bm25')
if not os.path.exists(result_save_dir):
os.makedirs(result_save_dir)
corpus_save_dir = os.path.join(bm25_dir, 'corpus')
if not os.path.exists(corpus_save_dir):
os.makedirs(corpus_save_dir)
corpus_save_path = os.path.join(corpus_save_dir, 'corpus.jsonl')
generate_corpus(corpus_save_path)
index_save_dir = os.path.join(bm25_dir, 'index')
if not os.path.exists(index_save_dir):
os.makedirs(index_save_dir)
index(corpus_save_dir, index_save_dir)
queries_save_dir = os.path.join(bm25_dir, 'queries')
if not os.path.exists(queries_save_dir):
os.makedirs(queries_save_dir)
languages = ['ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
for lang in languages:
generate_queries(qa_data_dir, lang, queries_save_dir)
result_save_path = os.path.join(result_save_dir, f'{lang}.txt')
search(index_save_dir, queries_save_dir, lang, result_save_path)
if __name__ == '__main__':
main()
"""
# 1. Output Search Results with BM25
python bm25_baseline_same_tokenizer.py
# 2. Print and Save Evaluation Results
python step2-eval_sparse_mkqa.py \
--encoder bm25_same_tokenizer \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--search_result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--eval_result_save_dir ./eval_results \
--metrics recall@20 recall@100 \
--threads 32
"""
import os
import sys
import datasets
from tqdm import tqdm
from transformers import AutoTokenizer
sys.path.append("..")
from utils.normalize_text import normalize
tokenizer = AutoTokenizer.from_pretrained(
'BAAI/bge-m3',
use_fast=False,
)
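# The map functions below replace raw text with space-joined XLM-RoBERTa token ids
# (with the leading/trailing special tokens stripped), so that Lucene BM25 operates
# over the same subword vocabulary as bge-m3.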
def _map_func_corpus(examples):
results = {}
results['id'] = examples['id']
results['contents'] = []
inputs = tokenizer(
examples['contents'],
padding=False,
truncation=True,
max_length=512
)
input_ids_list = inputs['input_ids']
for i in range(len(examples['id'])):
token_ids = input_ids_list[i][1:-1]
token_ids = [str(_id) for _id in token_ids]
results['contents'].append(" ".join(token_ids))
return results
def _map_func_query(examples):
results = {}
results['id'] = examples['id']
results['question'] = []
inputs = tokenizer(
examples['question'],
padding=False,
truncation=True,
max_length=512
)
input_ids_list = inputs['input_ids']
for i in range(len(examples['id'])):
token_ids = input_ids_list[i][1:-1]
token_ids = [str(_id) for _id in token_ids]
results['question'].append(" ".join(token_ids))
return results
def generate_corpus(corpus_save_path: str):
if os.path.exists(corpus_save_path):
print("Corpus already exists. Skip generating ...")
return
corpus = datasets.load_dataset('BeIR/nq', 'corpus')['corpus']
corpus_list = []
for data in tqdm(corpus, desc="Generating corpus"):
_id = str(data['_id'])
content = f"{data['title']}\n{data['text']}".lower()
content = normalize(content)
corpus_list.append({"id": _id, "contents": content})
corpus = datasets.Dataset.from_list(corpus_list)
corpus = corpus.map(_map_func_corpus, batched=True, num_proc=48)
corpus.to_json(corpus_save_path, force_ascii=False)
def generate_queries(qa_data_dir: str, lang: str, queries_save_dir: str):
queries_save_path = os.path.join(queries_save_dir, f"{lang}.tsv")
if os.path.exists(queries_save_path) and os.path.getsize(queries_save_path) > 0:
return
queries_path = os.path.join(qa_data_dir, f"{lang}.jsonl")
queries = datasets.load_dataset('json', data_files=queries_path)['train']
queries = queries.map(_map_func_query, batched=True, num_proc=48)
queries_list = []
for data in queries:
_id = str(data['id'])
query = data['question']
queries_list.append({
'id': _id,
'content': query
})
with open(queries_save_path, 'w', encoding='utf-8') as f:
for query in queries_list:
line = f"{query['id']}\t{query['content']}"
f.write(line + '\n')
def index(corpus_save_dir: str, index_save_dir: str):
cmd = f"python3 -m pyserini.index.lucene \
--collection JsonCollection \
--input {corpus_save_dir} \
--index {index_save_dir} \
--generator DefaultLuceneDocumentGenerator \
--threads 1 \
--storePositions --storeDocvectors --storeRaw \
"
os.system(cmd)
def search(index_save_dir: str, queries_save_dir: str, lang: str, result_save_path: str):
queries_save_path = os.path.join(queries_save_dir, f"{lang}.tsv")
cmd = f"python3 -m pyserini.search.lucene \
--index {index_save_dir} \
--topics {queries_save_path} \
--output {result_save_path} \
--bm25 \
--hits 1000 \
--batch-size 128 \
--threads 16 \
"
os.system(cmd)
def main():
qa_data_dir = '../qa_data'
bm25_dir = './bm25_baseline_same_tokenizer'
result_save_dir = os.path.join('./search_results', 'bm25_same_tokenizer')
if not os.path.exists(result_save_dir):
os.makedirs(result_save_dir)
corpus_save_dir = os.path.join(bm25_dir, 'corpus')
if not os.path.exists(corpus_save_dir):
os.makedirs(corpus_save_dir)
corpus_save_path = os.path.join(corpus_save_dir, 'corpus.jsonl')
generate_corpus(corpus_save_path)
index_save_dir = os.path.join(bm25_dir, 'index')
if not os.path.exists(index_save_dir):
os.makedirs(index_save_dir)
index(corpus_save_dir, index_save_dir)
queries_save_dir = os.path.join(bm25_dir, 'queries')
if not os.path.exists(queries_save_dir):
os.makedirs(queries_save_dir)
languages = ['ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
for lang in languages:
generate_queries(qa_data_dir, lang, queries_save_dir)
result_save_path = os.path.join(result_save_dir, f'{lang}.txt')
search(index_save_dir, queries_save_dir, lang, result_save_path)
if __name__ == '__main__':
main()
"""
python step0-encode_query-and-corpus.py \
--encoder BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--qa_data_dir ../qa_data \
--save_dir ./encoded_query-and-corpus \
--max_query_length 512 \
--max_passage_length 512 \
--batch_size 1024 \
--pooling_method cls \
--normalize_embeddings True
"""
import os
import sys
import json
import datasets
import numpy as np
from tqdm import tqdm
from pprint import pprint
from FlagEmbedding import BGEM3FlagModel
from dataclasses import dataclass, field
from transformers import HfArgumentParser
sys.path.append("..")
from utils.normalize_text import normalize
@dataclass
class ModelArgs:
encoder: str = field(
default="BAAI/bge-m3",
metadata={'help': 'Name or path of encoder'}
)
pooling_method: str = field(
default='cls',
metadata={'help': "Pooling method. Avaliable methods: 'cls', 'mean'"}
)
normalize_embeddings: bool = field(
default=True,
metadata={'help': "Normalize embeddings or not"}
)
fp16: bool = field(
default=True,
metadata={'help': 'Use fp16 in inference?'}
)
@dataclass
class EvalArgs:
languages: str = field(
default="en",
        metadata={'help': 'Languages to evaluate. Available languages: en ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw',
"nargs": "+"}
)
qa_data_dir: str = field(
default='../qa_data',
metadata={'help': 'Dir to qa data.'}
)
save_dir: str = field(
default='./encoded_query-and-corpus',
        metadata={'help': 'Dir to save the encoded queries and corpus. The encoded queries and corpus will be saved to `save_dir/{encoder_name}/{lang}/query_embd.tsv` and `save_dir/{encoder_name}/corpus/corpus_embd.jsonl`, respectively.'}
)
max_query_length: int = field(
default=512,
metadata={'help': 'Max query length.'}
)
max_passage_length: int = field(
default=512,
metadata={'help': 'Max passage length.'}
)
batch_size: int = field(
default=256,
metadata={'help': 'Inference batch size.'}
)
overwrite: bool = field(
default=False,
metadata={'help': 'Whether to overwrite embedding'}
)
def get_model(model_args: ModelArgs):
model = BGEM3FlagModel(
model_name_or_path=model_args.encoder,
pooling_method=model_args.pooling_method,
normalize_embeddings=model_args.normalize_embeddings,
use_fp16=model_args.fp16
)
return model
def check_languages(languages):
if isinstance(languages, str):
languages = [languages]
    available_languages = ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
    for lang in languages:
        if lang not in available_languages:
            raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
return languages
def parse_corpus(corpus: datasets.Dataset):
corpus_list = []
for data in tqdm(corpus, desc="Generating corpus"):
_id = str(data['_id'])
content = f"{data['title']}\n{data['text']}".lower()
content = normalize(content)
corpus_list.append({"id": _id, "content": content})
corpus = datasets.Dataset.from_list(corpus_list)
return corpus
def get_queries(qa_data_dir: str, lang: str):
topics_path = os.path.join(qa_data_dir, f"{lang}.jsonl")
if not os.path.exists(topics_path):
raise FileNotFoundError(f"{topics_path} not found")
dataset = datasets.load_dataset('json', data_files=topics_path)['train']
queries_list = []
for data in dataset:
_id = str(data['id'])
query = data['question']
queries_list.append({
'id': _id,
'content': query
})
queries = datasets.Dataset.from_list(queries_list)
return queries
def encode_and_save_corpus(corpus_save_path: str, model: BGEM3FlagModel, corpus: datasets.Dataset, max_passage_length: int=512, batch_size: int=256):
docids = list(corpus["id"])
vectors = model.encode(
corpus["content"],
batch_size=batch_size,
max_length=max_passage_length,
return_dense=False,
return_sparse=True,
return_colbert_vecs=False
)['lexical_weights']
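    # Quantize each lexical weight to an integer "impact" score (ceil of 100x the weight);
    # Pyserini's JsonVectorCollection / impact index expects integer term weights.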
encoded_corpus_list = []
for docid, vector in zip(docids, vectors):
for key, value in vector.items():
vector[key] = int(np.ceil(value * 100))
encoded_corpus_list.append({
'id': docid,
'contents': '',
'vector': vector
})
with open(corpus_save_path, 'w', encoding='utf-8') as f:
for line in tqdm(encoded_corpus_list, desc="Saving encoded corpus"):
f.write(json.dumps(line, ensure_ascii=False) + "\n")
def encode_and_save_queries(queries_save_path: str, model: BGEM3FlagModel, queries: datasets.Dataset, max_query_length: int=512, batch_size: int=256):
qids = list(queries["id"])
vectors = model.encode(
queries["content"],
batch_size=batch_size,
max_length=max_query_length,
return_dense=False,
return_sparse=True,
return_colbert_vecs=False
)['lexical_weights']
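    # Pyserini impact search reads pre-tokenized queries, so each token id is repeated
    # `weight` times below; Lucene's term-frequency count then reproduces the impact score.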
encoded_queries_list = []
for qid, vector in zip(qids, vectors):
for key, value in vector.items():
vector[key] = int(np.ceil(value * 100))
topic_str = []
for token in vector:
topic_str += [str(token)] * vector[token]
if len(topic_str) == 0:
topic_str = "0"
else:
topic_str = " ".join(topic_str)
encoded_queries_list.append(f"{str(qid)}\t{topic_str}")
with open(queries_save_path, 'w', encoding='utf-8') as f:
for line in tqdm(encoded_queries_list, desc="Saving encoded queries"):
f.write(line + '\n')
def main():
parser = HfArgumentParser([ModelArgs, EvalArgs])
model_args, eval_args = parser.parse_args_into_dataclasses()
model_args: ModelArgs
eval_args: EvalArgs
languages = check_languages(eval_args.languages)
# languages.reverse()
if model_args.encoder[-1] == '/':
model_args.encoder = model_args.encoder[:-1]
model = get_model(model_args=model_args)
encoder = model_args.encoder
if os.path.basename(encoder).startswith('checkpoint-'):
encoder = os.path.dirname(encoder) + '_' + os.path.basename(encoder)
print("==================================================")
print("Start generating embedding with model:")
print(model_args.encoder)
print('Generating corpus embedding ...')
corpus_save_dir = os.path.join(eval_args.save_dir, os.path.basename(encoder), 'corpus')
if not os.path.exists(corpus_save_dir):
os.makedirs(corpus_save_dir)
corpus_save_path = os.path.join(corpus_save_dir, 'corpus_embd.jsonl')
if os.path.exists(corpus_save_path) and os.path.getsize(corpus_save_path) > 0 and not eval_args.overwrite:
print(f'Corpus embedding already exists. Skip...')
else:
corpus = datasets.load_dataset("BeIR/nq", 'corpus')['corpus']
corpus = parse_corpus(corpus=corpus)
encode_and_save_corpus(
corpus_save_path=corpus_save_path,
model=model,
corpus=corpus,
max_passage_length=eval_args.max_passage_length,
batch_size=eval_args.batch_size
)
print('Generate query embedding of following languages: ', languages)
for lang in languages:
print("**************************************************")
save_dir = os.path.join(eval_args.save_dir, os.path.basename(encoder), lang)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
queries_save_path = os.path.join(save_dir, 'query_embd.tsv')
if os.path.exists(queries_save_path) and not eval_args.overwrite:
print(f'Query embedding of {lang} already exists. Skip...')
continue
print(f"Start generating query embedding of {lang} ...")
queries = get_queries(eval_args.qa_data_dir, lang)
encode_and_save_queries(
queries_save_path=queries_save_path,
model=model,
queries=queries,
max_query_length=eval_args.max_query_length,
batch_size=eval_args.batch_size
)
print("==================================================")
print("Finish generating embeddings with following model:")
pprint(model_args.encoder)
if __name__ == "__main__":
main()
"""
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--encoded_query_and_corpus_save_dir ./encoded_query-and-corpus \
--result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--threads 16 \
--hits 1000
"""
import os
import datasets
from tqdm import tqdm
from pprint import pprint
from dataclasses import dataclass, field
from transformers import HfArgumentParser
@dataclass
class ModelArgs:
encoder: str = field(
default="BAAI/bge-m3",
metadata={'help': 'Name or path of encoder'}
)
@dataclass
class EvalArgs:
languages: str = field(
default="en",
        metadata={'help': 'Languages to evaluate. Available languages: en ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw',
"nargs": "+"}
)
encoded_query_and_corpus_save_dir: str = field(
default='./encoded_query-and-corpus',
        metadata={'help': 'Dir containing the encoded queries and corpus. The encoded queries and corpus are saved in `save_dir/{encoder_name}/{lang}/query_embd.tsv` and `save_dir/{encoder_name}/corpus/corpus_embd.jsonl`, respectively.'}
)
result_save_dir: str = field(
default='./search_results',
        metadata={'help': 'Dir for saving search results. Search results will be saved to `result_save_dir/{encoder_name}/{lang}.txt`'}
)
qa_data_dir: str = field(
default='../qa_data',
metadata={'help': 'Dir to qa data.'}
)
batch_size: int = field(
default=32,
metadata={'help': 'Batch size to use during search'}
)
threads: int = field(
default=1,
metadata={'help': 'Maximum threads to use during search'}
)
hits: int = field(
default=1000,
metadata={'help': 'Number of hits'}
)
overwrite: bool = field(
default=False,
metadata={'help': 'Whether to overwrite embedding'}
)
def check_languages(languages):
if isinstance(languages, str):
languages = [languages]
    available_languages = ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
    for lang in languages:
        if lang not in available_languages:
            raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
return languages
def generate_index(corpus_embd_dir: str, index_save_dir: str, threads: int=12):
cmd = f"python -m pyserini.index.lucene \
--language en \
--collection JsonVectorCollection \
--input {corpus_embd_dir} \
--index {index_save_dir} \
--generator DefaultLuceneDocumentGenerator \
--threads {threads} \
--impact --pretokenized --optimize \
"
os.system(cmd)
def search_and_save_results(index_save_dir: str, query_embd_path: str, result_save_path: str, batch_size: int = 32, threads: int = 12, hits: int = 1000):
cmd = f"python -m pyserini.search.lucene \
--index {index_save_dir} \
--topics {query_embd_path} \
--output {result_save_path} \
--output-format trec \
--batch {batch_size} \
--threads {threads} \
--hits {hits} \
--impact \
"
os.system(cmd)
def parse_corpus(corpus: datasets.Dataset):
corpus_list = [{'id': e['docid'], 'content': f"{e['title']}\n{e['text']}"} for e in tqdm(corpus, desc="Generating corpus")]
corpus = datasets.Dataset.from_list(corpus_list)
return corpus
def main():
parser = HfArgumentParser([ModelArgs, EvalArgs])
model_args, eval_args = parser.parse_args_into_dataclasses()
model_args: ModelArgs
eval_args: EvalArgs
languages = check_languages(eval_args.languages)
if model_args.encoder[-1] == '/':
model_args.encoder = model_args.encoder[:-1]
encoder = model_args.encoder
if os.path.basename(encoder).startswith('checkpoint-'):
encoder = os.path.dirname(encoder) + '_' + os.path.basename(encoder)
print("==================================================")
print("Start generating search results with model:")
print(model_args.encoder)
corpus_embd_dir = os.path.join(eval_args.encoded_query_and_corpus_save_dir, os.path.basename(encoder), 'corpus')
index_save_dir = os.path.join(eval_args.encoded_query_and_corpus_save_dir, os.path.basename(encoder), 'index')
if os.path.exists(index_save_dir) and not eval_args.overwrite:
print(f'Index already exists')
else:
generate_index(
corpus_embd_dir=corpus_embd_dir,
index_save_dir=index_save_dir,
threads=eval_args.threads
)
print('Generate search results of following languages: ', languages)
for lang in languages:
print("**************************************************")
print(f"Start searching results of {lang} ...")
result_save_path = os.path.join(eval_args.result_save_dir, os.path.basename(encoder), f"{lang}.txt")
if not os.path.exists(os.path.dirname(result_save_path)):
os.makedirs(os.path.dirname(result_save_path))
if os.path.exists(result_save_path) and not eval_args.overwrite:
            print(f'Search results of {lang} already exist. Skip...')
continue
encoded_query_and_corpus_save_dir = os.path.join(eval_args.encoded_query_and_corpus_save_dir, os.path.basename(encoder), lang)
if not os.path.exists(encoded_query_and_corpus_save_dir):
raise FileNotFoundError(f"{encoded_query_and_corpus_save_dir} not found")
query_embd_path = os.path.join(encoded_query_and_corpus_save_dir, 'query_embd.tsv')
search_and_save_results(
index_save_dir=index_save_dir,
query_embd_path=query_embd_path,
result_save_path=result_save_path,
batch_size=eval_args.batch_size,
threads=eval_args.threads,
hits=eval_args.hits
)
print("==================================================")
print("Finish generating search results with following model:")
pprint(model_args.encoder)
if __name__ == "__main__":
main()
"""
# Ref: https://github.com/texttron/tevatron/tree/main/examples/unicoil
# 1. Generate Query and Corpus Sparse Vector
python step0-encode_query-and-corpus.py \
--encoder BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--qa_data_dir ../qa_data \
--save_dir ./encoded_query-and-corpus \
--max_query_length 512 \
--max_passage_length 512 \
--batch_size 1024 \
--pooling_method cls \
--normalize_embeddings True
# 2. Output Search Results
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--encoded_query_and_corpus_save_dir ./encoded_query-and-corpus \
--result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--threads 16 \
--hits 1000
# 3. Print and Save Evaluation Results
python step2-eval_sparse_mkqa.py \
--encoder BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--search_result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--eval_result_save_dir ./eval_results \
--metrics recall@20 recall@100 \
--threads 32 \
--pooling_method cls \
--normalize_embeddings True
"""
import os
import sys
import json
import datasets
import numpy as np
import pandas as pd
from tqdm import tqdm
import multiprocessing
from pprint import pprint
from dataclasses import dataclass, field
from transformers import HfArgumentParser
sys.path.append("..")
from utils.normalize_text import normalize
from utils.evaluation import evaluate_recall_qa
@dataclass
class EvalArgs:
languages: str = field(
default="en",
        metadata={'help': 'Languages to evaluate. Available languages: en ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw',
"nargs": "+"}
)
encoder: str = field(
default='BAAI/bge-m3',
metadata={'help': 'Name or path of encoder'}
)
pooling_method: str = field(
default='cls',
metadata={'help': "Pooling method. Avaliable methods: 'cls', 'mean'"}
)
normalize_embeddings: bool = field(
default=True,
metadata={'help': "Normalize embeddings or not"}
)
search_result_save_dir: str = field(
default='./search_results',
        metadata={'help': 'Dir for saving search results. The search result path is `result_save_dir/{encoder}/{lang}.txt`'}
)
qa_data_dir: str = field(
default='../qa_data',
metadata={'help': 'Dir to qa data.'}
)
metrics: str = field(
default="recall@20",
        metadata={'help': 'Metrics to evaluate. Available metrics: recall@k',
"nargs": "+"}
)
eval_result_save_dir: str = field(
default='./eval_results',
        metadata={'help': 'Dir for saving evaluation results. Evaluation results will be saved to `eval_result_save_dir/{encoder}.json`'}
)
threads: int = field(
default=1,
metadata={"help": "num of evaluation threads. <= 1 means single thread"}
)
def check_languages(languages):
if isinstance(languages, str):
languages = [languages]
    available_languages = ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
    for lang in languages:
        if lang not in available_languages:
            raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
return languages
def compute_average(results: dict):
average_results = {}
for _, result in results.items():
for metric, score in result.items():
if metric not in average_results:
average_results[metric] = []
average_results[metric].append(score)
for metric, scores in average_results.items():
average_results[metric] = np.mean(scores)
return average_results
def save_results(model_name: str, pooling_method: str, normalize_embeddings: bool, results: dict, save_path: str, eval_languages: list):
    try:
        results['average'] = compute_average(results)
    except Exception:
        results['average'] = None
pprint(results)
if not os.path.exists(os.path.dirname(save_path)):
os.makedirs(os.path.dirname(save_path))
results_dict = {
'model': model_name,
'pooling_method': pooling_method,
'normalize_embeddings': normalize_embeddings,
'results': results
}
with open(save_path, 'w', encoding='utf-8') as f:
json.dump(results_dict, f, indent=4, ensure_ascii=False)
print(f'Results of evaluating `{model_name}` on `{eval_languages}` saved at `{save_path}`')
def get_corpus_dict():
corpus_dict = {}
corpus = datasets.load_dataset('BeIR/nq', 'corpus')['corpus']
for data in tqdm(corpus, desc="Loading corpus"):
_id = str(data['_id'])
content = f"{data['title']}\n{data['text']}".lower()
content = normalize(content)
corpus_dict[_id] = content
return corpus_dict
def get_qa_dict(qa_path: str):
qa_dict = {}
dataset = datasets.load_dataset('json', data_files=qa_path)['train']
for data in dataset:
qid = str(data['id'])
answers = data['answers']
qa_dict[qid] = answers
return qa_dict
def get_search_result_dict(search_result_path: str, top_k: int=100):
    # Parse a TREC-format run file (columns: qid Q0 docid rank score tag) and
    # keep the top-k docids per query; ranks are assumed increasing within a query.
    search_result_dict = {}
    for _, row in pd.read_csv(search_result_path, sep=' ', header=None).iterrows():
        qid = str(row.iloc[0])
        docid = str(row.iloc[2])
        rank = int(row.iloc[3])
        if qid not in search_result_dict:
            search_result_dict[qid] = []
        if rank <= top_k:
            search_result_dict[qid].append(docid)
    return search_result_dict
def evaluate(corpus_dict: dict, qa_dict: dict, search_result_path: str, metrics: list):
top_k = max([int(metric.split('@')[-1]) for metric in metrics])
search_result_dict = get_search_result_dict(search_result_path, top_k=int(top_k))
search_results = []
ground_truths = []
for qid, docid_list in search_result_dict.items():
answers = qa_dict[qid]
doc_list = [corpus_dict[docid] for docid in docid_list]
search_results.append(doc_list)
ground_truths.append(answers)
results = {}
metrics = sorted([metric.lower() for metric in metrics])
for metric in metrics:
metric, k = metric.split('@')
k = int(k)
assert metric in ['recall'], f"Metric `{metric}` is not supported."
if metric == 'recall':
results[f'Recall@{k}'] = evaluate_recall_qa(search_results, ground_truths, k=k)
return results
def main():
parser = HfArgumentParser([EvalArgs])
eval_args = parser.parse_args_into_dataclasses()[0]
eval_args: EvalArgs
corpus_dict = get_corpus_dict()
languages = check_languages(eval_args.languages)
if eval_args.encoder[-1] == '/':
eval_args.encoder = eval_args.encoder[:-1]
if os.path.basename(eval_args.encoder).startswith('checkpoint-'):
eval_args.encoder = os.path.dirname(eval_args.encoder) + '_' + os.path.basename(eval_args.encoder)
results = {}
if eval_args.threads > 1:
threads = min(len(languages), eval_args.threads)
pool = multiprocessing.Pool(processes=threads)
results_list = []
for lang in languages:
print("*****************************")
print(f"Start evaluating {lang} ...")
qa_path = os.path.join(eval_args.qa_data_dir, f"{lang}.jsonl")
qa_dict = get_qa_dict(qa_path)
search_result_save_dir = os.path.join(eval_args.search_result_save_dir, os.path.basename(eval_args.encoder))
search_result_path = os.path.join(search_result_save_dir, f"{lang}.txt")
results_list.append(pool.apply_async(evaluate, args=(corpus_dict, qa_dict, search_result_path, eval_args.metrics)))
pool.close()
pool.join()
for i, lang in enumerate(languages):
results[lang] = results_list[i].get()
else:
for lang in languages:
print("*****************************")
print(f"Start evaluating {lang} ...")
qa_path = os.path.join(eval_args.qa_data_dir, f"{lang}.jsonl")
qa_dict = get_qa_dict(qa_path)
search_result_save_dir = os.path.join(eval_args.search_result_save_dir, os.path.basename(eval_args.encoder))
search_result_path = os.path.join(search_result_save_dir, f"{lang}.txt")
result = evaluate(corpus_dict, qa_dict, search_result_path, eval_args.metrics)
results[lang] = result
save_results(
model_name=eval_args.encoder,
pooling_method=eval_args.pooling_method,
normalize_embeddings=eval_args.normalize_embeddings,
results=results,
save_path=os.path.join(eval_args.eval_result_save_dir, f"{os.path.basename(eval_args.encoder)}.json"),
eval_languages=languages
)
print("==================================================")
print("Finish generating evaluation results with following model:")
print(eval_args.encoder)
if __name__ == "__main__":
main()
# Ref: https://github.com/facebookresearch/contriever
import regex
import unicodedata
from functools import partial
from typing import List
class SimpleTokenizer:
ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
NON_WS = r'[^\p{Z}\p{C}]'
def __init__(self):
"""
Args:
annotators: None or empty set (only tokenizes).
"""
self._regexp = regex.compile(
'(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
)
def tokenize(self, text, uncased=False):
matches = [m for m in self._regexp.finditer(text)]
if uncased:
tokens = [m.group().lower() for m in matches]
else:
tokens = [m.group() for m in matches]
return tokens
def _normalize(text):
return unicodedata.normalize('NFD', text)
def has_answer(answers, text, tokenizer) -> bool:
"""Check if a document contains an answer string."""
text = _normalize(text)
text = tokenizer.tokenize(text, uncased=True)
for answer in answers:
answer = _normalize(answer)
answer = tokenizer.tokenize(answer, uncased=True)
for i in range(0, len(text) - len(answer) + 1):
if answer == text[i: i + len(answer)]:
return True
return False
def check_answer(example, tokenizer) -> List[bool]:
"""Search through all the top docs to see if they have any of the answers."""
answers = example['answers']
ctxs = example['ctxs']
hits = []
for i, text in enumerate(ctxs):
if text is None: # cannot find the document for some reason
hits.append(False)
continue
hits.append(has_answer(answers, text, tokenizer))
return hits
def evaluate_recall_qa(ctxs, answers, k=100):
# compute Recall@k for QA task
data = []
assert len(ctxs) == len(answers)
for i in range(len(ctxs)):
_ctxs, _answers = ctxs[i], answers[i]
data.append({
'answers': _answers,
'ctxs': _ctxs,
})
tokenizer = SimpleTokenizer()
get_score_partial = partial(check_answer, tokenizer=tokenizer)
scores = map(get_score_partial, data)
n_docs = len(data[0]['ctxs'])
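    # top_k_hits[i] counts the questions whose first answer-bearing document
    # appears within the top (i + 1) retrieved contexts.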
top_k_hits = [0] * n_docs
for question_hits in scores:
best_hit = next((i for i, x in enumerate(question_hits) if x), None)
if best_hit is not None:
top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]
k = min(k, len(top_k_hits))
return top_k_hits[k - 1] / len(data)
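# A tiny usage example with hypothetical data (two questions, two retrieved contexts each):
#   ctxs = [['paris is the capital of france', 'berlin is in germany'],
#           ['the sky is blue', 'grass is green']]
#   answers = [['Paris'], ['red']]
#   evaluate_recall_qa(ctxs, answers, k=2) -> 0.5 (only the first question is answered)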
"""
adapted from chemdataextractor.text.normalize
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Tools for normalizing text.
https://github.com/mcs07/ChemDataExtractor
:copyright: Copyright 2016 by Matt Swain.
:license: MIT
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
'Software'), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
#: Control characters.
CONTROLS = {
'\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u000e', '\u000f', '\u0011',
'\u0012', '\u0013', '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001a', '\u001b',
}
# There are further control characters, but they are instead replaced with a space by unicode normalization
# '\u0009', '\u000a', '\u000b', '\u000c', '\u000d', '\u001c', '\u001d', '\u001e', '\u001f'
#: Hyphen and dash characters.
HYPHENS = {
'-', # \u002d Hyphen-minus
'‐', # \u2010 Hyphen
'‑', # \u2011 Non-breaking hyphen
'⁃', # \u2043 Hyphen bullet
'‒', # \u2012 figure dash
'–', # \u2013 en dash
'—', # \u2014 em dash
'―', # \u2015 horizontal bar
}
#: Minus characters.
MINUSES = {
'-', # \u002d Hyphen-minus
'−', # \u2212 Minus
    '－', # \uff0d Full-width Hyphen-minus
'⁻', # \u207b Superscript minus
}
#: Plus characters.
PLUSES = {
'+', # \u002b Plus
    '＋', # \uff0b Full-width Plus
'⁺', # \u207a Superscript plus
}
#: Slash characters.
SLASHES = {
'/', # \u002f Solidus
'⁄', # \u2044 Fraction slash
'∕', # \u2215 Division slash
}
#: Tilde characters.
TILDES = {
'~', # \u007e Tilde
'˜', # \u02dc Small tilde
'⁓', # \u2053 Swung dash
'∼', # \u223c Tilde operator #in mbert vocab
'∽', # \u223d Reversed tilde
'∿', # \u223f Sine wave
'〜', # \u301c Wave dash #in mbert vocab
    '～', # \uff5e Full-width tilde #in mbert vocab
}
#: Apostrophe characters.
APOSTROPHES = {
"'", # \u0027
'’', # \u2019
'՚', # \u055a
'Ꞌ', # \ua78b
'ꞌ', # \ua78c
    '＇', # \uff07 Full-width apostrophe
}
#: Single quote characters.
SINGLE_QUOTES = {
"'", # \u0027
'‘', # \u2018
'’', # \u2019
'‚', # \u201a
'‛', # \u201b
}
#: Double quote characters.
DOUBLE_QUOTES = {
'"', # \u0022
'“', # \u201c
'”', # \u201d
'„', # \u201e
'‟', # \u201f
}
#: Accent characters.
ACCENTS = {
'`', # \u0060
'´', # \u00b4
}
#: Prime characters.
PRIMES = {
'′', # \u2032
'″', # \u2033
'‴', # \u2034
'‵', # \u2035
'‶', # \u2036
'‷', # \u2037
'⁗', # \u2057
}
#: Quote characters, including apostrophes, single quotes, double quotes, accents and primes.
QUOTES = APOSTROPHES | SINGLE_QUOTES | DOUBLE_QUOTES | ACCENTS | PRIMES
def normalize(text):
for control in CONTROLS:
text = text.replace(control, '')
text = text.replace('\u000b', ' ').replace('\u000c', ' ').replace(u'\u0085', ' ')
for hyphen in HYPHENS | MINUSES:
text = text.replace(hyphen, '-')
text = text.replace('\u00ad', '')
for double_quote in DOUBLE_QUOTES:
text = text.replace(double_quote, '"') # \u0022
for single_quote in (SINGLE_QUOTES | APOSTROPHES | ACCENTS):
text = text.replace(single_quote, "'") # \u0027
text = text.replace('′', "'") # \u2032 prime
text = text.replace('‵', "'") # \u2035 reversed prime
text = text.replace('″', "''") # \u2033 double prime
text = text.replace('‶', "''") # \u2036 reversed double prime
text = text.replace('‴', "'''") # \u2034 triple prime
text = text.replace('‷', "'''") # \u2037 reversed triple prime
text = text.replace('⁗', "''''") # \u2057 quadruple prime
text = text.replace('…', '...').replace(' . . . ', ' ... ') # \u2026
for slash in SLASHES:
text = text.replace(slash, '/')
#for tilde in TILDES:
# text = text.replace(tilde, '~')
return text
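# Example: curly quotes, dashes, and ellipses are folded to their ASCII equivalents.
#   normalize('“smart quotes” – and … more') -> '"smart quotes" - and ... more'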
# MultiLongDocRetrieval
MultiLongDocRetrieval (denoted as MLDR) is a multilingual long-document retrieval dataset. For more details, please refer to [Shitao/MLDR](https://huggingface.co/datasets/Shitao/MLDR).
## Dense Retrieval
This task has been merged into [MTEB](https://github.com/embeddings-benchmark/mteb), so you can easily use the mteb tool to run the evaluation.
We also provide a [script](./mteb_dense_eval/eval_MLDR.py), which you can run with the following command:
```bash
cd mteb_dense_eval
# Print and Save Evaluation Results with MTEB
python eval_MLDR.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--results_save_path ./results \
--max_query_length 512 \
--max_passage_length 8192 \
--batch_size 256 \
--corpus_batch_size 1 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False \
--overwrite False
```
There are some important parameters:
- `encoder`: Name or path of the model to evaluate.
- `languages`: The languages you want to evaluate on. Available languages: `ar de en es fr hi it ja ko pt ru th zh`.
- `max_query_length` & `max_passage_length`: Maximum query length and maximum passage length when encoding.
- `batch_size` & `corpus_batch_size`: Batch size for query and corpus when encoding. If `max_query_length == max_passage_length`, you can ignore the `corpus_batch_size` parameter and only set `batch_size` for convenience. For faster evaluation, you should set the `batch_size` and `corpus_batch_size` as large as possible.
- `pooling_method` & `normalize_embeddings`: You should follow the corresponding setting of the model you are evaluating. For example, `BAAI/bge-m3` is `cls` and `True`, `intfloat/multilingual-e5-large` is `mean` and `True`, and `intfloat/e5-mistral-7b-instruct` is `last` and `True`.
- `add_instruction`: Whether to add an instruction to the query or passage when evaluating. If you set `add_instruction=True`, you should also set the following parameters appropriately:
- `query_instruction_for_retrieval`: the query instruction for retrieval
- `passage_instruction_for_retrieval`: the passage instruction for retrieval
If you only add a query instruction, just ignore the `passage_instruction_for_retrieval` parameter (see the sketch after this list).
- `overwrite`: Whether to overwrite evaluation results.
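For reference, instruction handling in the evaluation scripts simply prepends the instruction string to each query before encoding. A minimal sketch of this behavior (the instruction text below is an illustrative placeholder, not a recommended prompt):

```python
add_instruction = True
query_instruction_for_retrieval = "Represent this question for retrieval: "  # placeholder instruction

queries = ["what is dense retrieval", "how long can a passage be"]
if add_instruction and query_instruction_for_retrieval is not None:
    queries = [f"{query_instruction_for_retrieval}{query}" for query in queries]
```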
## Hybrid Retrieval (Dense & Sparse)
If you want to perform **hybrid retrieval with both dense and sparse methods**, follow these steps:
1. Install Java, Pyserini and Faiss (CPU version or GPU version):
```bash
# install java (Linux)
apt update
apt install openjdk-11-jdk
# install pyserini
pip install pyserini
# install faiss
## CPU version
conda install -c conda-forge faiss-cpu
## GPU version
conda install -c conda-forge faiss-gpu
```
2. Download qrels from [Shitao/MLDR](https://huggingface.co/datasets/Shitao/MLDR/tree/main/qrels):
```bash
mkdir -p qrels
cd qrels
splits=(dev test)
langs=(ar de en es fr hi it ja ko pt ru th zh)
for split in ${splits[*]}; do for lang in ${langs[*]}; do wget "https://huggingface.co/datasets/Shitao/MLDR/resolve/main/qrels/qrels.mldr-v1.0-${lang}-${split}.tsv"; done; done;
```
3. Dense retrieval:
```bash
cd dense_retrieval
# 1. Generate Corpus Embedding
python step0-generate_embedding.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--index_save_dir ./corpus-index \
--max_passage_length 8192 \
--batch_size 4 \
--fp16 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False
# 2. Search Results
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--index_save_dir ./corpus-index \
--result_save_dir ./search_results \
--threads 16 \
--hits 1000 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False
# 3. Print and Save Evaluation Results
python step2-eval_dense_mldr.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--search_result_save_dir ./search_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_results \
--metrics ndcg@10 \
--pooling_method cls \
--normalize_embeddings True
```
> Note: The evaluation results of this method may differ slightly from those of the MTEB-based method mentioned earlier; this is normal.
4. Sparse Retrieval
```bash
cd sparse_retrieval
# 1. Generate Query and Corpus Sparse Vector
python step0-encode_query-and-corpus.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--save_dir ./encoded_query-and-corpus \
--max_query_length 512 \
--max_passage_length 8192 \
--batch_size 1024 \
--corpus_batch_size 4 \
--pooling_method cls \
--normalize_embeddings True
# 2. Output Search Results
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--encoded_query_and_corpus_save_dir ./encoded_query-and-corpus \
--result_save_dir ./search_results \
--threads 16 \
--hits 1000
# 3. Print and Save Evaluation Results
python step2-eval_sparse_mldr.py \
--encoder BAAI/bge-m3 \
--languages ar de es fr hi it ja ko pt ru th en zh \
--search_result_save_dir ./search_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_results \
--metrics ndcg@10 \
--pooling_method cls \
--normalize_embeddings True
```
5. Hybrid Retrieval
```bash
cd hybrid_retrieval
# 1. Search Dense and Sparse Results
#    (run the Dense Retrieval and Sparse Retrieval steps above first)
# 2. Hybrid Dense and Sparse Search Results
python step0-hybrid_search_results.py \
--model_name_or_path BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--dense_search_result_save_dir ../dense_retrieval/search_results \
--sparse_search_result_save_dir ../sparse_retrieval/search_results \
--hybrid_result_save_dir ./search_results \
--top_k 1000 \
--dense_weight 0.2 --sparse_weight 0.8
# 3. Print and Save Evaluation Results
python step1-eval_hybrid_mldr.py \
--model_name_or_path BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--search_result_save_dir ./search_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_results \
--metrics ndcg@10 \
--pooling_method cls \
--normalize_embeddings True
```
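For intuition, the hybrid step computes a per-query weighted sum of the dense and sparse scores over the union of their candidate lists. A minimal sketch of this fusion, assuming the per-query scores have already been loaded from the two run files (`hybrid_fuse` is a hypothetical helper, not the actual script):

```python
from collections import defaultdict

def hybrid_fuse(dense_scores: dict, sparse_scores: dict,
                dense_weight: float = 0.2, sparse_weight: float = 0.8, top_k: int = 1000):
    """Fuse two {docid: score} dicts for one query; return the top-k (docid, score) pairs."""
    fused = defaultdict(float)
    for docid, score in dense_scores.items():
        fused[docid] += dense_weight * score
    for docid, score in sparse_scores.items():
        fused[docid] += sparse_weight * score
    return sorted(fused.items(), key=lambda x: x[1], reverse=True)[:top_k]
```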
## MultiVector and All Rerank
If you want to perform **multi-vector reranking** or **all reranking** based on the search results of dense retrieval, follow these steps:
1. Install Java, Pyserini and Faiss (CPU version or GPU version):
```bash
# install java (Linux)
apt update
apt install openjdk-11-jdk
# install pyserini
pip install pyserini
# install faiss
## CPU version
conda install -c conda-forge faiss-cpu
## GPU version
conda install -c conda-forge faiss-gpu
```
2. Download qrels from [Shitao/MLDR](https://huggingface.co/datasets/Shitao/MLDR/tree/main/qrels):
```bash
mkdir -p qrels
cd qrels
splits=(dev test)
langs=(ar de en es fr hi it ja ko pt ru th zh)
for split in ${splits[*]}; do for lang in ${langs[*]}; do wget "https://huggingface.co/datasets/Shitao/MLDR/resolve/main/qrels/qrels.mldr-v1.0-${lang}-${split}.tsv"; done; done;
```
3. Dense retrieval:
```bash
cd dense_retrieval
# 1. Generate Corpus Embedding
python step0-generate_embedding.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--index_save_dir ./corpus-index \
--max_passage_length 8192 \
--batch_size 4 \
--fp16 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False
# 2. Search Results
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--index_save_dir ./corpus-index \
--result_save_dir ./search_results \
--threads 16 \
--hits 1000 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False
# 3. Print and Save Evaluation Results
python step2-eval_dense_mldr.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--search_result_save_dir ./search_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_results \
--metrics ndcg@10 \
--pooling_method cls \
--normalize_embeddings True
```
> **Note**: The evaluation results of this method may differ slightly from those of the MTEB-based method mentioned earlier; this is normal.
4. Rerank search results with multi-vector scores or all scores:
```bash
cd multi_vector_rerank
# 1. Rerank Search Results
python step0-rerank_results.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--search_result_save_dir ../dense_retrieval/search_results \
--rerank_result_save_dir ./rerank_results \
--top_k 200 \
--batch_size 4 \
--max_query_length 512 \
--max_passage_length 8192 \
--pooling_method cls \
--normalize_embeddings True \
--dense_weight 0.15 --sparse_weight 0.5 --colbert_weight 0.35 \
--num_shards 1 --shard_id 0 --cuda_id 0
# 2. Print and Save Evaluation Results
python step1-eval_rerank_mldr.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--search_result_save_dir ./rerank_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_results \
--metrics ndcg@10
```
>**Note**:
>
>- You should set `dense_weight`, `sparse_weight` and `colbert_weight` based on the downstream task scenario. If the dense method performs well while the sparse method does not, you can lower `sparse_weight` and increase `dense_weight` accordingly.
>
>- Based on our experience, dividing the sentence pairs to be reranked into several shards and computing scores for each shard on a single GPU tends to be more efficient than using multiple GPUs to compute scores for all sentence pairs directly. Therefore, if your machine has multiple GPUs, you can set `num_shards` to the number of GPUs and launch one terminal per shard to run the command (`shard_id` should be equal to `cuda_id`). Make sure to set `shard_id` and `cuda_id` appropriately, and ensure that you have computed scores for all shards before proceeding to the second step.
5. (*Optional*) In the 4th step, you can get all three kinds of scores, saved to `rerank_result_save_dir/dense/{encoder}-{reranker}`, `rerank_result_save_dir/sparse/{encoder}-{reranker}` and `rerank_result_save_dir/colbert/{encoder}-{reranker}`. If you want to try other weights, you don't need to rerun the 4th step. Instead, you can use [this script](./multi_vector_rerank/hybrid_all_results.py) to combine the three kinds of scores directly.
```bash
cd multi_vector_rerank
# 1. Hybrid All Search Results
python hybrid_all_results.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--dense_search_result_save_dir ./rerank_results/dense \
--sparse_search_result_save_dir ./rerank_results/sparse \
--colbert_search_result_save_dir ./rerank_results/colbert \
--hybrid_result_save_dir ./hybrid_search_results \
--top_k 200 \
--dense_weight 0.2 --sparse_weight 0.4 --colbert_weight 0.4
# 2. Print and Save Evaluation Results
python step1-eval_rerank_mldr.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--search_result_save_dir ./hybrid_search_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_hybrid_results \
--metrics ndcg@10
```
## BM25 Baseline
We provide two methods of evaluating the BM25 baseline:
1. Use the same tokenizer as [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) (i.e., the tokenizer of [XLM-Roberta](https://huggingface.co/FacebookAI/xlm-roberta-large)):
```bash
cd sparse_retrieval
# 1. Output Search Results with BM25 (same)
python bm25_baseline_same_tokenizer.py
# 2. Print and Save Evaluation Results
python step2-eval_sparse_mldr.py \
--encoder bm25_same_tokenizer \
--languages ar de es fr hi it ja ko pt ru th en zh \
--search_result_save_dir ./search_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_results \
--metrics ndcg@10
```
2. Use the language analyzer provided by [Anserini](https://github.com/castorini/anserini/blob/master/src/main/java/io/anserini/analysis/AnalyzerMap.java) ([Lucene Tokenizer](https://github.com/apache/lucene/tree/main/lucene/analysis/common/src/java/org/apache/lucene/analysis)):
```bash
cd sparse_retrieval
# 1. Output Search Results with BM25
python bm25_baseline.py
# 2. Print and Save Evaluation Results
python step2-eval_sparse_mldr.py \
--encoder bm25 \
--languages ar de es fr hi it ja ko pt ru th en zh \
--search_result_save_dir ./search_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_results \
--metrics ndcg@10
```
"""
python step0-generate_embedding.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--index_save_dir ./corpus-index \
--max_passage_length 8192 \
--batch_size 4 \
--fp16 \
--pooling_method cls \
--normalize_embeddings True
"""
import os
import faiss
import datasets
import numpy as np
from tqdm import tqdm
from FlagEmbedding import FlagModel
from dataclasses import dataclass, field
from transformers import HfArgumentParser
@dataclass
class ModelArgs:
encoder: str = field(
default="BAAI/bge-m3",
metadata={'help': 'Name or path of encoder'}
)
fp16: bool = field(
default=True,
metadata={'help': 'Use fp16 in inference?'}
)
pooling_method: str = field(
default='cls',
metadata={'help': "Pooling method. Avaliable methods: 'cls', 'mean'"}
)
normalize_embeddings: bool = field(
default=True,
metadata={'help': "Normalize embeddings or not"}
)
@dataclass
class EvalArgs:
languages: str = field(
default="en",
        metadata={'help': 'Languages to evaluate. Available languages: ar de en es fr hi it ja ko pt ru th zh',
"nargs": "+"}
)
index_save_dir: str = field(
default='./corpus-index',
metadata={'help': 'Dir to save index. Corpus index will be saved to `index_save_dir/{encoder_name}/{lang}/index`. Corpus ids will be saved to `index_save_dir/{encoder_name}/{lang}/docid` .'}
)
max_passage_length: int = field(
default=512,
metadata={'help': 'Max passage length.'}
)
batch_size: int = field(
default=256,
metadata={'help': 'Inference batch size.'}
)
overwrite: bool = field(
default=False,
metadata={'help': 'Whether to overwrite embedding'}
)
def get_model(model_args: ModelArgs):
model = FlagModel(
model_args.encoder,
pooling_method=model_args.pooling_method,
normalize_embeddings=model_args.normalize_embeddings,
use_fp16=model_args.fp16
)
return model
def check_languages(languages):
if isinstance(languages, str):
languages = [languages]
    available_languages = ['ar', 'de', 'en', 'es', 'fr', 'hi', 'it', 'ja', 'ko', 'pt', 'ru', 'th', 'zh']
    for lang in languages:
        if lang not in available_languages:
            raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
return languages
def load_corpus(lang: str):
corpus = datasets.load_dataset('Shitao/MLDR', f'corpus-{lang}', split='corpus')
corpus_list = [{'id': e['docid'], 'content': e['text']} for e in tqdm(corpus, desc="Generating corpus")]
corpus = datasets.Dataset.from_list(corpus_list)
return corpus
def generate_index(model: FlagModel, corpus: datasets.Dataset, max_passage_length: int=512, batch_size: int=256):
corpus_embeddings = model.encode_corpus(corpus["content"], batch_size=batch_size, max_length=max_passage_length)
dim = corpus_embeddings.shape[-1]
faiss_index = faiss.index_factory(dim, "Flat", faiss.METRIC_INNER_PRODUCT)
corpus_embeddings = corpus_embeddings.astype(np.float32)
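    # "Flat" builds an exhaustive inner-product index, so train() below is a no-op;
    # it is kept for compatibility with factory strings that do require training.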
faiss_index.train(corpus_embeddings)
faiss_index.add(corpus_embeddings)
return faiss_index, list(corpus["id"])
def save_result(index: faiss.Index, docid: list, index_save_dir: str):
docid_save_path = os.path.join(index_save_dir, 'docid')
index_save_path = os.path.join(index_save_dir, 'index')
with open(docid_save_path, 'w', encoding='utf-8') as f:
for _id in docid:
f.write(str(_id) + '\n')
faiss.write_index(index, index_save_path)
def main():
parser = HfArgumentParser([ModelArgs, EvalArgs])
model_args, eval_args = parser.parse_args_into_dataclasses()
model_args: ModelArgs
eval_args: EvalArgs
languages = check_languages(eval_args.languages)
if model_args.encoder[-1] == '/':
model_args.encoder = model_args.encoder[:-1]
model = get_model(model_args=model_args)
encoder = model_args.encoder
if os.path.basename(encoder).startswith('checkpoint-'):
encoder = os.path.dirname(encoder) + '_' + os.path.basename(encoder)
print("==================================================")
print("Start generating embedding with model:")
print(model_args.encoder)
print('Generate embedding of following languages: ', languages)
for lang in languages:
print("**************************************************")
index_save_dir = os.path.join(eval_args.index_save_dir, os.path.basename(encoder), lang)
if not os.path.exists(index_save_dir):
os.makedirs(index_save_dir)
if os.path.exists(os.path.join(index_save_dir, 'index')) and not eval_args.overwrite:
print(f'Embedding of {lang} already exists. Skip...')
continue
print(f"Start generating embedding of {lang} ...")
corpus = load_corpus(lang)
index, docid = generate_index(
model=model,
corpus=corpus,
max_passage_length=eval_args.max_passage_length,
batch_size=eval_args.batch_size
)
save_result(index, docid, index_save_dir)
print("==================================================")
print("Finish generating embeddings with model:")
print(model_args.encoder)
if __name__ == "__main__":
main()
"""
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--index_save_dir ./corpus-index \
--result_save_dir ./search_results \
--threads 16 \
--hits 1000 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False
"""
import os
import torch
import datasets
from pprint import pprint
from dataclasses import dataclass, field
from transformers import HfArgumentParser, is_torch_npu_available
from pyserini.search.faiss import FaissSearcher, AutoQueryEncoder
from pyserini.output_writer import get_output_writer, OutputFormat
@dataclass
class ModelArgs:
encoder: str = field(
default="BAAI/bge-m3",
metadata={'help': 'Name or path of encoder'}
)
add_instruction: bool = field(
default=False,
metadata={'help': 'Add instruction?'}
)
query_instruction_for_retrieval: str = field(
default=None,
metadata={'help': 'query instruction for retrieval'}
)
pooling_method: str = field(
default='cls',
metadata={'help': "Pooling method. Avaliable methods: 'cls', 'mean'"}
)
normalize_embeddings: bool = field(
default=True,
metadata={'help': "Normalize embeddings or not"}
)
@dataclass
class EvalArgs:
languages: str = field(
default="en",
        metadata={'help': 'Languages to evaluate. Available languages: ar de en es fr hi it ja ko pt ru th zh',
"nargs": "+"}
)
index_save_dir: str = field(
default='./corpus-index',
metadata={'help': 'Dir to index and docid. Corpus index path is `index_save_dir/{encoder_name}/{lang}/index`. Corpus ids path is `index_save_dir/{encoder_name}/{lang}/docid` .'}
)
result_save_dir: str = field(
default='./search_results',
        metadata={'help': 'Dir for saving search results. Search results will be saved to `result_save_dir/{encoder_name}/{lang}.txt`'}
)
threads: int = field(
default=1,
metadata={'help': 'Maximum threads to use during search'}
)
hits: int = field(
default=1000,
metadata={'help': 'Number of hits'}
)
overwrite: bool = field(
default=False,
metadata={'help': 'Whether to overwrite embedding'}
)
def get_query_encoder(model_args: ModelArgs):
if torch.cuda.is_available():
device = torch.device("cuda")
elif is_torch_npu_available():
device = torch.device("npu")
else:
device = torch.device("cpu")
model = AutoQueryEncoder(
encoder_dir=model_args.encoder,
device=device,
pooling=model_args.pooling_method,
l2_norm=model_args.normalize_embeddings
)
return model
def check_languages(languages):
if isinstance(languages, str):
languages = [languages]
    available_languages = ['ar', 'de', 'en', 'es', 'fr', 'hi', 'it', 'ja', 'ko', 'pt', 'ru', 'th', 'zh']
    for lang in languages:
        if lang not in available_languages:
            raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
return languages
def get_queries_and_qids(lang: str, split: str='test', add_instruction: bool=False, query_instruction_for_retrieval: str=None):
dataset = datasets.load_dataset('Shitao/MLDR', lang, split=split)
queries = []
qids = []
for data in dataset:
qids.append(str(data['query_id']))
queries.append(str(data['query']))
if add_instruction and query_instruction_for_retrieval is not None:
queries = [f"{query_instruction_for_retrieval}{query}" for query in queries]
return queries, qids
def save_result(search_results, result_save_path: str, qids: list, max_hits: int):
output_writer = get_output_writer(result_save_path, OutputFormat(OutputFormat.TREC.value), 'w',
max_hits=max_hits, tag='Faiss', topics=qids,
use_max_passage=False,
max_passage_delimiter='#',
max_passage_hits=1000)
with output_writer:
for topic, hits in search_results:
output_writer.write(topic, hits)
def main():
parser = HfArgumentParser([ModelArgs, EvalArgs])
model_args, eval_args = parser.parse_args_into_dataclasses()
model_args: ModelArgs
eval_args: EvalArgs
languages = check_languages(eval_args.languages)
if model_args.encoder[-1] == '/':
model_args.encoder = model_args.encoder[:-1]
query_encoder = get_query_encoder(model_args=model_args)
encoder = model_args.encoder
if os.path.basename(encoder).startswith('checkpoint-'):
encoder = os.path.dirname(encoder) + '_' + os.path.basename(encoder)
print("==================================================")
print("Start generating search results with model:")
print(model_args.encoder)
print('Generating search results for the following languages: ', languages)
for lang in languages:
print("**************************************************")
print(f"Start searching results of {lang} ...")
result_save_path = os.path.join(eval_args.result_save_dir, os.path.basename(encoder), f"{lang}.txt")
if not os.path.exists(os.path.dirname(result_save_path)):
os.makedirs(os.path.dirname(result_save_path))
if os.path.exists(result_save_path) and not eval_args.overwrite:
print(f'Search results of {lang} already exist. Skip...')
continue
index_save_dir = os.path.join(eval_args.index_save_dir, os.path.basename(encoder), lang)
if not os.path.exists(index_save_dir):
raise FileNotFoundError(f"{index_save_dir} not found")
searcher = FaissSearcher(
index_dir=index_save_dir,
query_encoder=query_encoder
)
queries, qids = get_queries_and_qids(
lang=lang,
split='test',
add_instruction=model_args.add_instruction,
query_instruction_for_retrieval=model_args.query_instruction_for_retrieval
)
search_results = searcher.batch_search(
queries=queries,
q_ids=qids,
k=eval_args.hits,
threads=eval_args.threads
)
search_results = [(_id, search_results[_id]) for _id in qids]
save_result(
search_results=search_results,
result_save_path=result_save_path,
qids=qids,
max_hits=eval_args.hits
)
print("==================================================")
print("Finish generating search results with model:")
pprint(model_args.encoder)
if __name__ == "__main__":
main()
"""
# 1. Generate Corpus Embedding
python step0-generate_embedding.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--index_save_dir ./corpus-index \
--max_passage_length 8192 \
--batch_size 4 \
--fp16 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False
# 2. Search Results
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--index_save_dir ./corpus-index \
--result_save_dir ./search_results \
--threads 16 \
--hits 1000 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False
# 3. Print and Save Evaluation Results
python step2-eval_dense_mldr.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--search_result_save_dir ./search_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_results \
--metrics ndcg@10 \
--pooling_method cls \
--normalize_embeddings True
"""
import os
import json
import platform
import subprocess
import numpy as np
from pprint import pprint
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from pyserini.util import download_evaluation_script
@dataclass
class EvalArgs:
languages: str = field(
default="en",
metadata={'help': 'Languages to evaluate. Available languages: ar de en es fr hi it ja ko pt ru th zh',
"nargs": '+'}
)
encoder: str = field(
default='BAAI/bge-m3',
metadata={'help': 'Name or path of encoder'}
)
pooling_method: str = field(
default='cls',
metadata={'help': "Pooling method. Avaliable methods: 'cls', 'mean'"}
)
normalize_embeddings: bool = field(
default=True,
metadata={'help': "Normalize embeddings or not"}
)
search_result_save_dir: str = field(
default='./search_results',
metadata={'help': 'Directory for saving search results. Search results path is `result_save_dir/{encoder}/{lang}.txt`'}
)
qrels_dir: str = field(
default='../qrels',
metadata={'help': 'Directory of qrels.'}
)
metrics: str = field(
default="ndcg@10",
metadata={'help': 'Metrics to evaluate. Available metrics: ndcg@k, recall@k',
"nargs": "+"}
)
eval_result_save_dir: str = field(
default='./eval_results',
metadata={'help': 'Directory for saving evaluation results. Evaluation results will be saved to `eval_result_save_dir/{encoder}.json`'}
)
def check_languages(languages):
if isinstance(languages, str):
languages = [languages]
available_languages = ['ar', 'de', 'en', 'es', 'fr', 'hi', 'it', 'ja', 'ko', 'pt', 'ru', 'th', 'zh']
for lang in languages:
if lang not in available_languages:
raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
return languages
def compute_average(results: dict):
average_results = {}
for _, result in results.items():
for metric, score in result.items():
if metric not in average_results:
average_results[metric] = []
average_results[metric].append(score)
for metric, scores in average_results.items():
average_results[metric] = np.mean(scores)
return average_results
def save_results(model_name: str, pooling_method: str, normalize_embeddings: bool, results: dict, save_path: str, eval_languages: list):
try:
results['average'] = compute_average(results)
except Exception:
results['average'] = None
pprint(results)
if not os.path.exists(os.path.dirname(save_path)):
os.makedirs(os.path.dirname(save_path))
results_dict = {
'model': model_name,
'pooling_method': pooling_method,
'normalize_embeddings': normalize_embeddings,
'results': results
}
with open(save_path, 'w', encoding='utf-8') as f:
json.dump(results_dict, f, indent=4, ensure_ascii=False)
print(f'Results of evaluating `{model_name}` on `{eval_languages}` saved at `{save_path}`')
def map_metric(metric: str):
metric, k = metric.split('@')
if metric.lower() == 'ndcg':
return k, f'ndcg_cut.{k}'
elif metric.lower() == 'recall':
return k, f'recall.{k}'
else:
raise ValueError(f"Unkown metric: {metric}")
def evaluate(script_path, qrels_path, search_result_path, metrics: list):
cmd_prefix = ['java', '-jar', script_path]
results = {}
for metric in metrics:
k, mapped_metric = map_metric(metric)
args = ['-c', '-M', str(k), '-m', mapped_metric, qrels_path, search_result_path]
cmd = cmd_prefix + args
# print(f'Running command: {cmd}')
shell = platform.system() == "Windows"
process = subprocess.Popen(cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=shell)
stdout, stderr = process.communicate()
if stderr:
print(stderr.decode("utf-8"))
result_str = stdout.decode("utf-8")
try:
results[metric] = float(result_str.split(' ')[-1].split('\t')[-1])
except Exception:
results[metric] = result_str
return results
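# trec_eval prints one line per metric, typically of the form
# `ndcg_cut_10<TAB>all<TAB>0.4123`; the parsing above keeps only the trailing
# number and falls back to the raw output string if it cannot be parsed as a float.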
def main():
parser = HfArgumentParser([EvalArgs])
eval_args = parser.parse_args_into_dataclasses()[0]
eval_args: EvalArgs
languages = check_languages(eval_args.languages)
script_path = download_evaluation_script('trec_eval')
if eval_args.encoder[-1] == '/':
eval_args.encoder = eval_args.encoder[:-1]
encoder = eval_args.encoder
if os.path.basename(encoder).startswith('checkpoint-'):
encoder = os.path.dirname(encoder) + '_' + os.path.basename(encoder)
results = {}
for lang in languages:
print("*****************************")
print(f"Start evaluating {lang} ...")
qrels_path = os.path.join(eval_args.qrels_dir, f"qrels.mldr-v1.0-{lang}-test.tsv")
search_result_save_dir = os.path.join(eval_args.search_result_save_dir, os.path.basename(encoder))
search_result_path = os.path.join(search_result_save_dir, f"{lang}.txt")
result = evaluate(script_path, qrels_path, search_result_path, eval_args.metrics)
results[lang] = result
save_results(
model_name=encoder,
pooling_method=eval_args.pooling_method,
normalize_embeddings=eval_args.normalize_embeddings,
results=results,
save_path=os.path.join(eval_args.eval_result_save_dir, f"{os.path.basename(encoder)}.json"),
eval_languages=languages
)
print("==================================================")
print("Finish generating evaluation results with model:")
print(eval_args.encoder)
if __name__ == "__main__":
main()
"""
python step0-hybrid_search_results.py \
--model_name_or_path BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--dense_search_result_save_dir ../dense_retrieval/search_results \
--sparse_search_result_save_dir ../sparse_retrieval/search_results \
--hybrid_result_save_dir ./search_results \
--top_k 1000 \
--dense_weight 0.2 --sparse_weight 0.8
"""
import os
import pandas as pd
from tqdm import tqdm
from dataclasses import dataclass, field
from transformers import HfArgumentParser
@dataclass
class EvalArgs:
model_name_or_path: str = field(
default='BAAI/bge-m3',
metadata={'help': 'Name or path of model'}
)
languages: str = field(
default="en",
metadata={'help': 'Languages to evaluate. Available languages: ar de en es fr hi it ja ko pt ru th zh',
"nargs": "+"}
)
top_k: int = field(
default=1000,
metadata={'help': 'Number of top hits to keep in the merged results'}
)
sparse_weight: float = field(
default=0.8,
metadata={'help': 'Hybrid weight of sparse score'}
)
dense_weight: float = field(
default=0.2,
metadata={'help': 'Hybrid weight of dense score'}
)
dense_search_result_save_dir: str = field(
default='../dense_retrieval/search_results',
metadata={'help': 'Directory for saving dense search results. Search results path is `dense_search_result_save_dir/{model_name_or_path}/{lang}.txt`'}
)
sparse_search_result_save_dir: str = field(
default='../sparse_retrieval/search_results',
metadata={'help': 'Directory for saving sparse search results. Search results path is `sparse_search_result_save_dir/{model_name_or_path}/{lang}.txt`'}
)
hybrid_result_save_dir: str = field(
default='./search_results',
metadata={'help': 'Directory for saving hybrid search results. Merged results will be saved to `hybrid_result_save_dir/{model_name_or_path}/{lang}.txt`'}
)
def check_languages(languages):
if isinstance(languages, str):
languages = [languages]
available_languages = ['ar', 'de', 'en', 'es', 'fr', 'hi', 'it', 'ja', 'ko', 'pt', 'ru', 'th', 'zh']
for lang in languages:
if lang not in available_languages:
raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
return languages
def get_search_result_dict(search_result_path: str, top_k: int=1000):
search_result_dict = {}
flag = True
for _, row in pd.read_csv(search_result_path, sep=' ', header=None).iterrows():
qid = str(row.iloc[0])
docid = row.iloc[2]
rank = int(row.iloc[3])
score = float(row.iloc[4])
if qid not in search_result_dict:
search_result_dict[qid] = []
flag = False
if rank > top_k:
flag = True
if flag:
continue
else:
search_result_dict[qid].append((docid, score))
return search_result_dict
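# Each row of the run file is `qid Q0 docid rank score tag`; the `flag` toggle
# above keeps at most `top_k` hits per query by skipping the remaining rows of
# a query once its rank exceeds `top_k`.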
def save_hybrid_results(sparse_search_result_dict: dict, dense_search_result_dict: dict, hybrid_result_save_path: str, top_k: int=1000, dense_weight: float=0.2, sparse_weight: float=0.8):
if not os.path.exists(os.path.dirname(hybrid_result_save_path)):
os.makedirs(os.path.dirname(hybrid_result_save_path))
qid_list = list(set(sparse_search_result_dict.keys()) | set(dense_search_result_dict.keys()))
hybrid_results_list = []
for qid in tqdm(qid_list, desc="Hybriding dense and sparse scores"):
results = {}
if qid in sparse_search_result_dict:
for docid, score in sparse_search_result_dict[qid]:
score = score / 10000.
results[docid] = score * sparse_weight
if qid in dense_search_result_dict:
for docid, score in dense_search_result_dict[qid]:
if docid in results:
results[docid] = results[docid] + score * dense_weight
else:
results[docid] = score * dense_weight
hybrid_results = [(docid, score) for docid, score in results.items()]
hybrid_results.sort(key=lambda x: x[1], reverse=True)
hybrid_results_list.append(hybrid_results[:top_k])
with open(hybrid_result_save_path, 'w', encoding='utf-8') as f:
for qid, hybrid_results in tqdm(zip(qid_list, hybrid_results_list), desc="Saving hybrid search results"):
for rank, docid_score in enumerate(hybrid_results):
docid, score = docid_score
line = f"{qid} Q0 {docid} {rank+1} {score:.6f} Faiss-&-Anserini"
f.write(line + '\n')
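# The merged score is a weighted sum over both runs, i.e. per document:
#   score(d) = sparse_weight * sparse_score(d) / 10000 + dense_weight * dense_score(d)
# Dividing by 10000 rescales the sparse scores into a range comparable to the
# normalized dense scores (an assumption about how the sparse run was scored;
# adjust the scaling if your sparse scores are already in [0, 1]).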
def main():
parser = HfArgumentParser([EvalArgs])
eval_args = parser.parse_args_into_dataclasses()[0]
eval_args: EvalArgs
languages = check_languages(eval_args.languages)
if os.path.basename(eval_args.model_name_or_path).startswith('checkpoint-'):
eval_args.model_name_or_path = os.path.dirname(eval_args.model_name_or_path) + '_' + os.path.basename(eval_args.model_name_or_path)
for lang in languages:
print("**************************************************")
print(f"Start hybrid search results of {lang} ...")
hybrid_result_save_path = os.path.join(eval_args.hybrid_result_save_dir, f"{os.path.basename(eval_args.model_name_or_path)}", f"{lang}.txt")
sparse_search_result_save_dir = os.path.join(eval_args.sparse_search_result_save_dir, os.path.basename(eval_args.model_name_or_path))
sparse_search_result_path = os.path.join(sparse_search_result_save_dir, f"{lang}.txt")
sparse_search_result_dict = get_search_result_dict(sparse_search_result_path, top_k=eval_args.top_k)
dense_search_result_save_dir = os.path.join(eval_args.dense_search_result_save_dir, os.path.basename(eval_args.model_name_or_path))
dense_search_result_path = os.path.join(dense_search_result_save_dir, f"{lang}.txt")
dense_search_result_dict = get_search_result_dict(dense_search_result_path, top_k=eval_args.top_k)
save_hybrid_results(
sparse_search_result_dict=sparse_search_result_dict,
dense_search_result_dict=dense_search_result_dict,
hybrid_result_save_path=hybrid_result_save_path,
top_k=eval_args.top_k,
sparse_weight=eval_args.sparse_weight,
dense_weight=eval_args.dense_weight
)
print("==================================================")
print("Finish generating reranked results with following model:")
print(eval_args.model_name_or_path)
if __name__ == "__main__":
main()
"""
# 1. Search Dense and Sparse Results
../dense_retrieval
../sparse_retrieval
# 2. Hybrid Dense and Sparse Search Results
python step0-hybrid_search_results.py \
--model_name_or_path BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--dense_search_result_save_dir ../dense_retrieval/search_results \
--sparse_search_result_save_dir ../sparse_retrieval/search_results \
--hybrid_result_save_dir ./search_results \
--top_k 1000 \
--dense_weight 0.2 --sparse_weight 0.8
# 3. Print and Save Evaluation Results
python step1-eval_hybrid_mldr.py \
--model_name_or_path BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--search_result_save_dir ./search_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_results \
--metrics ndcg@10 \
--pooling_method cls \
--normalize_embeddings True
"""
import os
import json
import platform
import subprocess
import numpy as np
from pprint import pprint
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from pyserini.util import download_evaluation_script
@dataclass
class EvalArgs:
languages: str = field(
default="en",
metadata={'help': 'Languages to evaluate. Available languages: ar de en es fr hi it ja ko pt ru th zh',
"nargs": "+"}
)
model_name_or_path: str = field(
default='BAAI/bge-m3',
metadata={'help': 'Name or path of model'}
)
pooling_method: str = field(
default='cls',
metadata={'help': "Pooling method. Avaliable methods: 'cls', 'mean'"}
)
normalize_embeddings: bool = field(
default=True,
metadata={'help': "Normalize embeddings or not"}
)
search_result_save_dir: str = field(
default='./output_results',
metadata={'help': 'Directory for saving search results. Search results path is `result_save_dir/{model_name_or_path}/{lang}.txt`'}
)
qrels_dir: str = field(
default='../qrels',
metadata={'help': 'Directory of qrels.'}
)
metrics: str = field(
default="ndcg@10",
metadata={'help': 'Metrics to evaluate. Available metrics: ndcg@k, recall@k',
"nargs": "+"}
)
eval_result_save_dir: str = field(
default='./eval_results',
metadata={'help': 'Directory for saving evaluation results. Evaluation results will be saved to `eval_result_save_dir/{model_name_or_path}.json`'}
)
def check_languages(languages):
if isinstance(languages, str):
languages = [languages]
available_languages = ['ar', 'de', 'en', 'es', 'fr', 'hi', 'it', 'ja', 'ko', 'pt', 'ru', 'th', 'zh']
for lang in languages:
if lang not in available_languages:
raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
return languages
def compute_average(results: dict):
average_results = {}
for _, result in results.items():
for metric, score in result.items():
if metric not in average_results:
average_results[metric] = []
average_results[metric].append(score)
for metric, scores in average_results.items():
average_results[metric] = np.mean(scores)
return average_results
def save_results(model_name: str, pooling_method: str, normalize_embeddings: bool, results: dict, save_path: str, eval_languages: list):
try:
results['average'] = compute_average(results)
except Exception:
results['average'] = None
pprint(results)
if not os.path.exists(os.path.dirname(save_path)):
os.makedirs(os.path.dirname(save_path))
results_dict = {
'model': model_name,
'pooling_method': pooling_method,
'normalize_embeddings': normalize_embeddings,
'results': results
}
with open(save_path, 'w', encoding='utf-8') as f:
json.dump(results_dict, f, indent=4, ensure_ascii=False)
print(f'Results of evaluating `{model_name}` on `{eval_languages}` saved at `{save_path}`')
def map_metric(metric: str):
metric, k = metric.split('@')
if metric.lower() == 'ndcg':
return k, f'ndcg_cut.{k}'
elif metric.lower() == 'recall':
return k, f'recall.{k}'
else:
raise ValueError(f"Unkown metric: {metric}")
def evaluate(script_path, qrels_path, search_result_path, metrics: list):
cmd_prefix = ['java', '-jar', script_path]
results = {}
for metric in metrics:
k, mapped_metric = map_metric(metric)
args = ['-c', '-M', str(k), '-m', mapped_metric, qrels_path, search_result_path]
cmd = cmd_prefix + args
# print(f'Running command: {cmd}')
shell = platform.system() == "Windows"
process = subprocess.Popen(cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=shell)
stdout, stderr = process.communicate()
if stderr:
print(stderr.decode("utf-8"))
result_str = stdout.decode("utf-8")
try:
results[metric] = float(result_str.split(' ')[-1].split('\t')[-1])
except Exception:
results[metric] = result_str
return results
def main():
parser = HfArgumentParser([EvalArgs])
eval_args = parser.parse_args_into_dataclasses()[0]
eval_args: EvalArgs
languages = check_languages(eval_args.languages)
script_path = download_evaluation_script('trec_eval')
if eval_args.model_name_or_path[-1] == '/':
eval_args.model_name_or_path = eval_args.model_name_or_path[:-1]
if os.path.basename(eval_args.model_name_or_path).startswith('checkpoint-'):
eval_args.model_name_or_path = os.path.dirname(eval_args.model_name_or_path) + '_' + os.path.basename(eval_args.model_name_or_path)
results = {}
for lang in languages:
qrels_path = os.path.join(eval_args.qrels_dir, f"qrels.mldr-v1.0-{lang}-test.tsv")
search_result_save_dir = os.path.join(eval_args.search_result_save_dir, os.path.basename(eval_args.model_name_or_path))
search_result_path = os.path.join(search_result_save_dir, f"{lang}.txt")
result = evaluate(script_path, qrels_path, search_result_path, eval_args.metrics)
results[lang] = result
save_results(
model_name=eval_args.model_name_or_path,
pooling_method=eval_args.pooling_method,
normalize_embeddings=eval_args.normalize_embeddings,
results=results,
save_path=os.path.join(eval_args.eval_result_save_dir, f"{os.path.basename(eval_args.model_name_or_path)}.json"),
eval_languages=languages
)
print("==================================================")
print("Finish generating evaluation results with following model:")
print(eval_args.model_name_or_path)
if __name__ == "__main__":
main()
"""
python3 eval_MLDR.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--results_save_path ./results \
--max_query_length 512 \
--max_passage_length 8192 \
--batch_size 256 \
--corpus_batch_size 1 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False \
--overwrite False
"""
import os
from mteb import MTEB
from pprint import pprint
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from flag_dres_model import FlagDRESModel
# from mteb.tasks import MultiLongDocRetrieval
from C_MTEB.tasks.MultiLongDocRetrieval import MultiLongDocRetrieval
@dataclass
class EvalArgs:
results_save_path: str = field(
default='./results',
metadata={'help': 'Path to save results.'}
)
languages: str = field(
default=None,
metadata={'help': 'Languages to evaluate. Available languages: ar de en es fr hi it ja ko pt ru th zh',
"nargs": "+"}
)
overwrite: bool = field(
default=False,
metadata={"help": "whether to overwrite evaluation results"}
)
@dataclass
class ModelArgs:
encoder: str = field(
default="BAAI/bge-m3",
metadata={'help': 'encoder name or path.'}
)
pooling_method: str = field(
default='cls',
metadata={'help': "Pooling method. Avaliable methods: 'cls', 'mean', 'last'"}
)
normalize_embeddings: bool = field(
default=True,
metadata={'help': "Normalize embeddings or not"}
)
add_instruction: bool = field(
default=False,
metadata={'help': 'Whether to add an instruction to queries and passages'}
)
query_instruction_for_retrieval: str = field(
default=None,
metadata={'help': 'query instruction for retrieval'}
)
passage_instruction_for_retrieval: str = field(
default=None,
metadata={'help': 'passage instruction for retrieval'}
)
max_query_length: int = field(
default=512,
metadata={'help': 'Max query length.'}
)
max_passage_length: int = field(
default=8192,
metadata={'help': 'Max passage length.'}
)
batch_size: int = field(
default=256,
metadata={'help': 'Inference batch size.'}
)
corpus_batch_size: int = field(
default=2,
metadata={'help': 'Inference batch size for corpus. If 0, then use `batch_size`.'}
)
def check_languages(languages):
if languages is None:
return None
if isinstance(languages, str):
languages = [languages]
available_languages = ['ar', 'de', 'en', 'es', 'fr', 'hi', 'it', 'ja', 'ko', 'pt', 'ru', 'th', 'zh']
for lang in languages:
if lang not in available_languages:
raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
return languages
def main():
parser = HfArgumentParser([ModelArgs, EvalArgs])
model_args, eval_args = parser.parse_args_into_dataclasses()
model_args: ModelArgs
eval_args: EvalArgs
languages = check_languages(eval_args.languages)
encoder = model_args.encoder
if encoder[-1] == '/':
encoder = encoder[:-1]
model = FlagDRESModel(
model_name_or_path=encoder,
pooling_method=model_args.pooling_method,
normalize_embeddings=model_args.normalize_embeddings,
query_instruction_for_retrieval=model_args.query_instruction_for_retrieval if model_args.add_instruction else None,
passage_instruction_for_retrieval=model_args.passage_instruction_for_retrieval if model_args.add_instruction else None,
max_query_length=model_args.max_query_length,
max_passage_length=model_args.max_passage_length,
batch_size=model_args.batch_size,
corpus_batch_size=model_args.corpus_batch_size
)
if os.path.basename(encoder).startswith('checkpoint-'):
encoder = os.path.dirname(encoder) + '_' + os.path.basename(encoder)
output_folder = os.path.join(eval_args.results_save_path, f'{os.path.basename(encoder)}_max-length-{model_args.max_passage_length}')
print("==================================================")
print("Start evaluating model:")
print(model_args.encoder)
evaluation = MTEB(tasks=[
MultiLongDocRetrieval(langs=languages)
])
results_dict = evaluation.run(model, eval_splits=["test"], output_folder=output_folder, overwrite_results=eval_args.overwrite, corpus_chunk_size=200000)
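# corpus_chunk_size bounds how many documents MTEB's retrieval evaluator encodes
# and scores per pass; 200000 is a value chosen here, not a library default.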
print(output_folder + ":")
pprint(results_dict)
print("==================================================")
print("Finish MultiLongDocRetrieval evaluation for model:")
print(model_args.encoder)
if __name__ == "__main__":
main()
import torch
import datasets
import numpy as np
from tqdm import tqdm
from mteb import DRESModel
from functools import partial
from torch.utils.data import DataLoader
from typing import cast, List, Dict, Union
from transformers import AutoModel, AutoTokenizer, is_torch_npu_available
from transformers import PreTrainedTokenizerFast, BatchEncoding, DataCollatorWithPadding
def _transform_func(examples: Dict[str, List],
tokenizer: PreTrainedTokenizerFast,
max_length: int) -> BatchEncoding:
return tokenizer(examples['text'],
max_length=max_length,
padding=True,
return_token_type_ids=False,
truncation=True,
return_tensors='pt')
def _transform_func_v2(examples: Dict[str, List],
tokenizer: PreTrainedTokenizerFast,
max_length: int=8192,
) -> BatchEncoding:
inputs = tokenizer(examples['text'],
max_length=max_length - 1,
padding=False,
return_attention_mask=False,
truncation=True)
inputs['input_ids'] = [input_ids + [tokenizer.eos_token_id] for input_ids in inputs['input_ids']]
inputs = tokenizer.pad(inputs, padding=True, return_attention_mask=True, return_tensors='pt')
return inputs
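# _transform_func_v2 truncates to max_length - 1, appends the EOS token to every
# sequence, then pads, so that with 'last' pooling the final position always
# holds a real (EOS) token.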
class FlagDRESModel(DRESModel):
def __init__(
self,
model_name_or_path: str = None,
pooling_method: str = 'cls',
normalize_embeddings: bool = True,
use_fp16: bool = True,
query_instruction_for_retrieval: str = None,
passage_instruction_for_retrieval: str = None,
max_query_length: int = 512,
max_passage_length: int = 8192,
batch_size: int = 256,
corpus_batch_size: int = 0,
**kwargs
) -> None:
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
if 'jina' in model_name_or_path:
self.model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
else:
self.model = AutoModel.from_pretrained(model_name_or_path)
self.query_instruction_for_retrieval = query_instruction_for_retrieval
self.passage_instruction_for_retrieval = passage_instruction_for_retrieval
self.normalize_embeddings = normalize_embeddings
self.pooling_method = pooling_method
self.batch_size = batch_size
self.corpus_batch_size = corpus_batch_size if corpus_batch_size > 0 else batch_size
self.max_query_length = max_query_length
self.max_passage_length = max_passage_length
if use_fp16: self.model.half()
if torch.cuda.is_available():
self.device = torch.device("cuda")
elif is_torch_npu_available():
self.device = torch.device("npu")
else:
self.device = torch.device("cpu")
self.model = self.model.to(self.device)
self.num_gpus = torch.cuda.device_count()
if self.num_gpus > 1:
self.model = torch.nn.DataParallel(self.model)
def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
'''
This function is used for the retrieval task.
If there is an instruction for queries, it is prepended to the query text.
'''
if isinstance(queries[0], dict):
if self.query_instruction_for_retrieval is not None:
input_texts = ['{}{}'.format(self.query_instruction_for_retrieval, q['text']) for q in queries]
else:
input_texts = [q['text'] for q in queries]
else:
if self.query_instruction_for_retrieval is not None:
input_texts = ['{}{}'.format(self.query_instruction_for_retrieval, q) for q in queries]
else:
input_texts = queries
return self.encode(input_texts, max_length=self.max_query_length, batch_size=self.batch_size)
def encode_corpus(self, corpus: List[Union[Dict[str, str], str]], **kwargs) -> np.ndarray:
'''
Encode the corpus for the retrieval task, prepending the passage instruction if one is given.
'''
if isinstance(corpus[0], dict):
if self.passage_instruction_for_retrieval is not None:
input_texts = ['{}{} {}'.format(self.passage_instruction_for_retrieval, doc.get('title', ''), doc['text']).strip() for doc in corpus]
else:
input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
else:
if self.passage_instruction_for_retrieval is not None:
input_texts = [f'{self.passage_instruction_for_retrieval}{p}' for p in corpus]
else:
input_texts = corpus
return self.encode(input_texts, max_length=self.max_passage_length, batch_size=self.corpus_batch_size)
@torch.no_grad()
def encode(self, sentences: List[str], max_length: int, batch_size: int, **kwargs) -> np.ndarray:
if self.num_gpus > 0:
batch_size = batch_size * self.num_gpus
self.model.eval()
input_was_string = False
if isinstance(sentences, str):
sentences = [sentences]
input_was_string = True
dataset = datasets.Dataset.from_dict({'text': sentences})
if self.pooling_method == 'last':
dataset.set_transform(partial(_transform_func_v2, tokenizer=self.tokenizer, max_length=max_length))
else:
dataset.set_transform(partial(_transform_func, tokenizer=self.tokenizer, max_length=max_length))
data_collator = DataCollatorWithPadding(self.tokenizer)
data_loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
drop_last=False,
num_workers=4,
collate_fn=data_collator,
# pin_memory=True
)
all_embeddings = []
for batch_data in tqdm(data_loader, desc='encoding', mininterval=10):
batch_data = batch_data.to(self.device)
# print(batch_data)
last_hidden_state = self.model(**batch_data, return_dict=True).last_hidden_state
# print(last_hidden_state)
embeddings = self.pooling(last_hidden_state, batch_data['attention_mask']).float()
if self.normalize_embeddings:
embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
embeddings = cast(torch.Tensor, embeddings)
all_embeddings.append(embeddings.cpu().numpy())
all_embeddings = np.concatenate(all_embeddings, axis=0)
if input_was_string:
return all_embeddings[0]
else:
return all_embeddings
def pooling(self,
last_hidden_state: torch.Tensor,
attention_mask: torch.Tensor=None):
if self.pooling_method == 'cls':
return last_hidden_state[:, 0]
elif self.pooling_method == 'mean':
s = torch.sum(last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
d = attention_mask.sum(dim=1, keepdim=True).float()
return s / d
elif self.pooling_method == 'last':
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
if left_padding:
return last_hidden_state[:, -1]
else:
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_size = last_hidden_state.shape[0]
return last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
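# Notes on the 'last' branch above: with left padding every sequence ends at the
# final position, so last_hidden_state[:, -1] is the last real token; with right
# padding the attention-mask sum locates the last non-padding position per row.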
"""
python hybrid_all_results.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--dense_search_result_save_dir ./rerank_results/dense \
--sparse_search_result_save_dir ./rerank_results/sparse \
--colbert_search_result_save_dir ./rerank_results/colbert \
--hybrid_result_save_dir ./hybrid_search_results \
--top_k 200 \
--dense_weight 0.2 --sparse_weight 0.4 --colbert_weight 0.4
"""
import os
import pandas as pd
from tqdm import tqdm
from dataclasses import dataclass, field
from transformers import HfArgumentParser
@dataclass
class EvalArgs:
encoder: str = field(
default='BAAI/bge-m3',
metadata={'help': 'Name or path of model'}
)
reranker: str = field(
default='BAAI/bge-m3',
metadata={'help': 'Name or path of reranker'}
)
languages: str = field(
default="en",
metadata={'help': 'Languages to evaluate. Available languages: ar de en es fr hi it ja ko pt ru th zh',
"nargs": "+"}
)
top_k: int = field(
default=200,
metadata={'help': 'Number of top hits to keep in the merged results'}
)
dense_weight: float = field(
default=0.15,
metadata={'help': 'Hybrid weight of dense score'}
)
sparse_weight: float = field(
default=0.5,
metadata={'help': 'Hybrid weight of sparse score'}
)
colbert_weight: float = field(
default=0.35,
metadata={'help': 'Hybrid weight of colbert score'}
)
dense_search_result_save_dir: str = field(
default='../rerank/unify_rerank_results/dense',
metadata={'help': 'Directory for saving dense search results. Search results path is `dense_search_result_save_dir/{encoder}-{reranker}/{lang}.txt`'}
)
sparse_search_result_save_dir: str = field(
default='../rerank/unify_rerank_results/sparse',
metadata={'help': 'Directory for saving sparse search results. Search results path is `sparse_search_result_save_dir/{encoder}-{reranker}/{lang}.txt`'}
)
colbert_search_result_save_dir: str = field(
default='../rerank/unify_rerank_results/colbert',
metadata={'help': 'Directory for saving colbert search results. Search results path is `colbert_search_result_save_dir/{encoder}-{reranker}/{lang}.txt`'}
)
hybrid_result_save_dir: str = field(
default='./hybrid_search_results',
metadata={'help': 'Directory for saving hybrid search results. Merged results will be saved to `hybrid_result_save_dir/{encoder}-{reranker}/{lang}.txt`'}
)
def check_languages(languages):
if isinstance(languages, str):
languages = [languages]
available_languages = ['ar', 'de', 'en', 'es', 'fr', 'hi', 'it', 'ja', 'ko', 'pt', 'ru', 'th', 'zh']
for lang in languages:
if lang not in available_languages:
raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
return languages
def get_search_result_dict(search_result_path: str, top_k: int=1000):
search_result_dict = {}
flag = True
for _, row in pd.read_csv(search_result_path, sep=' ', header=None).iterrows():
qid = str(row.iloc[0])
docid = row.iloc[2]
rank = int(row.iloc[3])
score = float(row.iloc[4])
if qid not in search_result_dict:
search_result_dict[qid] = []
flag = False
if rank > top_k:
flag = True
if flag:
continue
else:
search_result_dict[qid].append((docid, score))
return search_result_dict
def save_hybrid_results(sparse_search_result_dict: dict, dense_search_result_dict: dict, colbert_search_result_dict: dict, hybrid_result_save_path: str, top_k: int=200, dense_weight: float=0.15, sparse_weight: float=0.5, colbert_weight: float=0.35):
if not os.path.exists(os.path.dirname(hybrid_result_save_path)):
os.makedirs(os.path.dirname(hybrid_result_save_path))
qid_list = list(set(sparse_search_result_dict.keys()) | set(dense_search_result_dict.keys()) | set(colbert_search_result_dict.keys()))
hybrid_results_list = []
for qid in tqdm(qid_list, desc="Hybriding dense, sparse and colbert scores"):
results = {}
if qid in sparse_search_result_dict:
for docid, score in sparse_search_result_dict[qid]:
results[docid] = score * sparse_weight
if qid in dense_search_result_dict:
for docid, score in dense_search_result_dict[qid]:
if docid in results:
results[docid] = results[docid] + score * dense_weight
else:
results[docid] = score * dense_weight
if qid in colbert_search_result_dict:
for docid, score in colbert_search_result_dict[qid]:
if docid in results:
results[docid] = results[docid] + score * colbert_weight
else:
results[docid] = score * colbert_weight
hybrid_results = [(docid, score) for docid, score in results.items()]
hybrid_results.sort(key=lambda x: x[1], reverse=True)
hybrid_results_list.append(hybrid_results[:top_k])
with open(hybrid_result_save_path, 'w', encoding='utf-8') as f:
for qid, hybrid_results in tqdm(zip(qid_list, hybrid_results_list), desc="Saving hybrid search results"):
for rank, docid_score in enumerate(hybrid_results):
docid, score = docid_score
line = f"{qid} Q0 {docid} {rank+1} {score:.6f} Faiss-&-Anserini"
f.write(line + '\n')
def main():
parser = HfArgumentParser([EvalArgs])
eval_args = parser.parse_args_into_dataclasses()[0]
eval_args: EvalArgs
languages = check_languages(eval_args.languages)
if os.path.basename(eval_args.encoder).startswith('checkpoint-'):
eval_args.encoder = os.path.dirname(eval_args.encoder) + '_' + os.path.basename(eval_args.encoder)
if os.path.basename(eval_args.reranker).startswith('checkpoint-'):
eval_args.reranker = os.path.dirname(eval_args.reranker) + '_' + os.path.basename(eval_args.reranker)
dir_name = f"{os.path.basename(eval_args.encoder)}-{os.path.basename(eval_args.reranker)}"
for lang in languages:
print("**************************************************")
print(f"Start hybrid search results of {lang} ...")
hybrid_result_save_path = os.path.join(eval_args.hybrid_result_save_dir, dir_name, f"{lang}.txt")
sparse_search_result_save_dir = os.path.join(eval_args.sparse_search_result_save_dir, dir_name)
sparse_search_result_path = os.path.join(sparse_search_result_save_dir, f"{lang}.txt")
sparse_search_result_dict = get_search_result_dict(sparse_search_result_path, top_k=eval_args.top_k)
dense_search_result_save_dir = os.path.join(eval_args.dense_search_result_save_dir, dir_name)
dense_search_result_path = os.path.join(dense_search_result_save_dir, f"{lang}.txt")
dense_search_result_dict = get_search_result_dict(dense_search_result_path, top_k=eval_args.top_k)
colbert_search_result_save_dir = os.path.join(eval_args.colbert_search_result_save_dir, dir_name)
colbert_search_result_path = os.path.join(colbert_search_result_save_dir, f"{lang}.txt")
colbert_search_result_dict = get_search_result_dict(colbert_search_result_path, top_k=eval_args.top_k)
save_hybrid_results(
sparse_search_result_dict=sparse_search_result_dict,
dense_search_result_dict=dense_search_result_dict,
colbert_search_result_dict=colbert_search_result_dict,
hybrid_result_save_path=hybrid_result_save_path,
top_k=eval_args.top_k,
sparse_weight=eval_args.sparse_weight,
dense_weight=eval_args.dense_weight,
colbert_weight=eval_args.colbert_weight
)
print("==================================================")
print("Finish generating reranked results with following model and reranker:")
print(eval_args.encoder)
print(eval_args.reranker)
if __name__ == "__main__":
main()
"""
python step0-rerank_results.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--search_result_save_dir ../dense_retrieval/search_results \
--rerank_result_save_dir ./rerank_results \
--top_k 200 \
--batch_size 4 \
--max_query_length 512 \
--max_passage_length 8192 \
--pooling_method cls \
--normalize_embeddings True \
--dense_weight 0.15 --sparse_weight 0.5 --colbert_weight 0.35 \
--num_shards 1 --shard_id 0 --cuda_id 0
"""
import os
import copy
import datasets
import pandas as pd
from tqdm import tqdm
from FlagEmbedding import BGEM3FlagModel
from dataclasses import dataclass, field
from transformers import HfArgumentParser
@dataclass
class ModelArgs:
reranker: str = field(
default='BAAI/bge-m3',
metadata={'help': 'Name or path of reranker'}
)
fp16: bool = field(
default=True,
metadata={'help': 'Use fp16 in inference?'}
)
pooling_method: str = field(
default='cls',
metadata={'help': "Pooling method. Avaliable methods: 'cls', 'mean'"}
)
normalize_embeddings: bool = field(
default=True,
metadata={'help': "Normalize embeddings or not"}
)
@dataclass
class EvalArgs:
languages: str = field(
default="en",
metadata={'help': 'Languages to evaluate. Available languages: ar de en es fr hi it ja ko pt ru th zh',
"nargs": "+"}
)
max_query_length: int = field(
default=512,
metadata={'help': 'Max query length.'}
)
max_passage_length: int = field(
default=8192,
metadata={'help': 'Max passage length.'}
)
batch_size: int = field(
default=256,
metadata={'help': 'Inference batch size.'}
)
top_k: int = field(
default=100,
metadata={'help': 'Use reranker to rerank top-k retrieval results'}
)
encoder: str = field(
default='BAAI/bge-m3',
metadata={'help': 'Name or path of encoder'}
)
search_result_save_dir: str = field(
default='./output_results',
metadata={'help': 'Directory for saving search results. Search results path is `result_save_dir/{encoder}/{lang}.txt`'}
)
rerank_result_save_dir: str = field(
default='./rerank_results',
metadata={'help': 'Directory for saving reranked results. Reranked results will be saved to `rerank_result_save_dir/{encoder}-{reranker}/{lang}.txt`'}
)
num_shards: int = field(
default=1,
metadata={'help': "num of shards"}
)
shard_id: int = field(
default=0,
metadata={'help': 'Id of the shard, starting from 0'}
)
cuda_id: int = field(
default=0,
metadata={'help': 'CUDA ID to use. -1 means only use CPU.'}
)
dense_weight: float = field(
default=0.15,
metadata={'help': 'The weight of the dense score when combining all scores'}
)
sparse_weight: float = field(
default=0.5,
metadata={'help': 'The weight of the sparse score when combining all scores'}
)
colbert_weight: float = field(
default=0.35,
metadata={'help': 'The weight of the colbert score when combining all scores'}
)
def check_languages(languages):
if isinstance(languages, str):
languages = [languages]
available_languages = ['ar', 'de', 'en', 'es', 'fr', 'hi', 'it', 'ja', 'ko', 'pt', 'ru', 'th', 'zh']
for lang in languages:
if lang not in available_languages:
raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
return languages
def get_reranker(model_args: ModelArgs, device: str=None):
reranker = BGEM3FlagModel(
model_name_or_path=model_args.reranker,
pooling_method=model_args.pooling_method,
normalize_embeddings=model_args.normalize_embeddings,
device=device
)
return reranker
def get_search_result_dict(search_result_path: str, top_k: int=200):
search_result_dict = {}
flag = True
for _, row in pd.read_csv(search_result_path, sep=' ', header=None).iterrows():
qid = str(row.iloc[0])
docid = row.iloc[2]
rank = int(row.iloc[3])
if qid not in search_result_dict:
search_result_dict[qid] = []
flag = False
if rank > top_k:
flag = True
if flag:
continue
else:
search_result_dict[qid].append(docid)
return search_result_dict
def get_queries_dict(lang: str, split: str='test'):
dataset = datasets.load_dataset('Shitao/MLDR', lang, split=split)
queries_dict = {}
for data in dataset:
qid = data['query_id']
query = data['query']
queries_dict[qid] = query
return queries_dict
def get_corpus_dict(lang: str):
corpus = datasets.load_dataset('Shitao/MLDR', f'corpus-{lang}', split='corpus')
corpus_dict = {}
for data in tqdm(corpus, desc="Generating corpus"):
docid = data['docid']
content = data['text']
corpus_dict[docid] = content
return corpus_dict
def save_rerank_results(queries_dict: dict, corpus_dict: dict, reranker: BGEM3FlagModel, search_result_dict: dict, rerank_result_save_path: dict, batch_size: int=256, max_query_length: int=512, max_passage_length: int=512, dense_weight: float=0.15, sparse_weight: float=0.5, colbert_weight: float=0.35):
qid_list = []
sentence_pairs = []
for qid, docids in search_result_dict.items():
qid_list.append(qid)
query = queries_dict[qid]
for docid in docids:
passage = corpus_dict[docid]
sentence_pairs.append((query, passage))
scores_dict = reranker.compute_score(
sentence_pairs,
batch_size=batch_size,
max_query_length=max_query_length,
max_passage_length=max_passage_length,
weights_for_different_modes=[dense_weight, sparse_weight, colbert_weight]
)
for sub_dir, _rerank_result_save_path in rerank_result_save_path.items():
if not os.path.exists(os.path.dirname(_rerank_result_save_path)):
os.makedirs(os.path.dirname(_rerank_result_save_path))
scores = scores_dict[sub_dir]
with open(_rerank_result_save_path, 'w', encoding='utf-8') as f:
i = 0
for qid in qid_list:
docids = search_result_dict[qid]
docids_scores = []
for j in range(len(docids)):
docids_scores.append((docids[j], scores[i + j]))
i += len(docids)
docids_scores.sort(key=lambda x: x[1], reverse=True)
for rank, docid_score in enumerate(docids_scores):
docid, score = docid_score
line = f"{qid} Q0 {docid} {rank+1} {score:.6f} Faiss"
f.write(line + '\n')
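# compute_score returns one flat score list per mode (keyed by the mode names
# 'colbert', 'sparse', 'dense', 'colbert+sparse+dense'); the offset `i` walks
# each list in the same (qid, docid) order used to build sentence_pairs, so the
# j-th score of a query lines up with its j-th candidate docid.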
def get_shard(search_result_dict: dict, num_shards: int, shard_id: int):
if num_shards <= 1:
return search_result_dict
keys_list = sorted(list(search_result_dict.keys()))
shard_len = len(keys_list) // num_shards
if shard_id == num_shards - 1:
shard_keys_list = keys_list[shard_id*shard_len:]
else:
shard_keys_list = keys_list[shard_id*shard_len : (shard_id + 1)*shard_len]
shard_search_result_dict = {k: search_result_dict[k] for k in shard_keys_list}
return shard_search_result_dict
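# Sharding example: with 1002 query ids and num_shards=4, shard_len is 250, so
# shards 0-2 get 250 ids each and the last shard gets the remaining 252; every
# query id lands in exactly one shard.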
def rerank_results(languages: list, eval_args: EvalArgs, model_args: ModelArgs, device: str=None):
eval_args = copy.deepcopy(eval_args)
model_args = copy.deepcopy(model_args)
num_shards = eval_args.num_shards
shard_id = eval_args.shard_id
if shard_id >= num_shards:
raise ValueError(f"shard_id >= num_shards ({shard_id} >= {num_shards})")
reranker = get_reranker(model_args=model_args, device=device)
if os.path.basename(eval_args.encoder).startswith('checkpoint-'):
eval_args.encoder = os.path.dirname(eval_args.encoder) + '_' + os.path.basename(eval_args.encoder)
if os.path.basename(model_args.reranker).startswith('checkpoint-'):
model_args.reranker = os.path.dirname(model_args.reranker) + '_' + os.path.basename(model_args.reranker)
for lang in languages:
print("**************************************************")
print(f"Start reranking results of {lang} ...")
queries_dict = get_queries_dict(lang, split='test')
search_result_save_dir = os.path.join(eval_args.search_result_save_dir, os.path.basename(eval_args.encoder))
search_result_path = os.path.join(search_result_save_dir, f"{lang}.txt")
search_result_dict = get_search_result_dict(search_result_path, top_k=eval_args.top_k)
search_result_dict = get_shard(search_result_dict, num_shards=num_shards, shard_id=shard_id)
corpus_dict = get_corpus_dict(lang)
rerank_result_save_path = {}
for sub_dir in ['colbert', 'sparse', 'dense', 'colbert+sparse+dense']:
_rerank_result_save_path = os.path.join(
eval_args.rerank_result_save_dir,
sub_dir,
f"{os.path.basename(eval_args.encoder)}-{os.path.basename(model_args.reranker)}",
f"{lang}_{shard_id}-of-{num_shards}.txt" if num_shards > 1 else f"{lang}.txt"
)
rerank_result_save_path[sub_dir] = _rerank_result_save_path
save_rerank_results(
queries_dict=queries_dict,
corpus_dict=corpus_dict,
reranker=reranker,
search_result_dict=search_result_dict,
rerank_result_save_path=rerank_result_save_path,
batch_size=eval_args.batch_size,
max_query_length=eval_args.max_query_length,
max_passage_length=eval_args.max_passage_length,
dense_weight=eval_args.dense_weight,
sparse_weight=eval_args.sparse_weight,
colbert_weight=eval_args.colbert_weight
)
def main():
parser = HfArgumentParser([EvalArgs, ModelArgs])
eval_args, model_args = parser.parse_args_into_dataclasses()
eval_args: EvalArgs
model_args: ModelArgs
languages = check_languages(eval_args.languages)
cuda_id = eval_args.cuda_id
if cuda_id < 0:
rerank_results(languages, eval_args, model_args, device='cpu')
else:
rerank_results(languages, eval_args, model_args, device=f"cuda:{cuda_id}")
print("==================================================")
print("Finish generating reranked results with following encoder and reranker:")
print(eval_args.encoder)
print(model_args.reranker)
if __name__ == "__main__":
main()
"""
# 1. Rerank Search Results
python step0-rerank_results.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--search_result_save_dir ../dense_retrieval/search_results \
--rerank_result_save_dir ./rerank_results \
--top_k 200 \
--batch_size 4 \
--max_query_length 512 \
--max_passage_length 8192 \
--pooling_method cls \
--normalize_embeddings True \
--dense_weight 0.15 --sparse_weight 0.5 --colbert_weight 0.35 \
--num_shards 1 --shard_id 0 --cuda_id 0
# 2. Print and Save Evaluation Results
python step1-eval_rerank_mldr.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--search_result_save_dir ./rerank_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_results \
--metrics ndcg@10
"""
import os
import json
import platform
import subprocess
import numpy as np
from pprint import pprint
from collections import defaultdict
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from pyserini.util import download_evaluation_script
@dataclass
class EvalArgs:
languages: str = field(
default="en",
metadata={'help': 'Languages to evaluate. Available languages: ar de en es fr hi it ja ko pt ru th zh',
"nargs": "+"}
)
reranker: str = field(
default='BAAI/bge-m3',
metadata={'help': 'Name or path of reranker'}
)
encoder: str = field(
default='BAAI/bge-m3',
metadata={'help': 'Name or path of encoder'}
)
search_result_save_dir: str = field(
default='./rerank_results',
metadata={'help': 'Directory for saving search results. Search results path is `result_save_dir/{encoder}-{reranker}/{lang}.txt`'}
)
qrels_dir: str = field(
default='../qrels',
metadata={'help': 'Directory of qrels.'}
)
metrics: str = field(
default="ndcg@10",
metadata={'help': 'Metrics to evaluate. Available metrics: ndcg@k, recall@k',
"nargs": "+"}
)
eval_result_save_dir: str = field(
default='./reranker_evaluation_results',
metadata={'help': 'Directory for saving evaluation results. Evaluation results will be saved to `eval_result_save_dir/{encoder}-{reranker}.json`'}
)
def check_languages(languages):
if isinstance(languages, str):
languages = [languages]
available_languages = ['ar', 'de', 'en', 'es', 'fr', 'hi', 'it', 'ja', 'ko', 'pt', 'ru', 'th', 'zh']
for lang in languages:
if lang not in available_languages:
raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
return languages
def compute_average(results: dict):
average_results = {}
for _, result in results.items():
for metric, score in result.items():
if metric not in average_results:
average_results[metric] = []
average_results[metric].append(score)
for metric, scores in average_results.items():
average_results[metric] = np.mean(scores)
return average_results
def save_results(model_name: str, reranker_name: str, results: dict, save_path: str, eval_languages: list):
try:
results['average'] = compute_average(results)
except Exception:
results['average'] = None
pprint(results)
if not os.path.exists(os.path.dirname(save_path)):
os.makedirs(os.path.dirname(save_path))
results_dict = {
'reranker': reranker_name,
'model': model_name,
'results': results
}
with open(save_path, 'w', encoding='utf-8') as f:
json.dump(results_dict, f, indent=4, ensure_ascii=False)
print(f'Results of evaluating `{reranker_name}` on `{eval_languages}` based on `{model_name}` saved at `{save_path}`')
def map_metric(metric: str):
metric, k = metric.split('@')
if metric.lower() == 'ndcg':
return k, f'ndcg_cut.{k}'
elif metric.lower() == 'recall':
return k, f'recall.{k}'
else:
raise ValueError(f"Unkown metric: {metric}")
def evaluate(script_path: str, qrels_path, search_result_path, metrics: list):
cmd_prefix = ['java', '-jar', script_path]
results = {}
for metric in metrics:
k, mapped_metric = map_metric(metric)
args = ['-c', '-M', str(k), '-m', mapped_metric, qrels_path, search_result_path]
cmd = cmd_prefix + args
# print(f'Running command: {cmd}')
shell = platform.system() == "Windows"
process = subprocess.Popen(cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=shell)
stdout, stderr = process.communicate()
if stderr:
print(stderr.decode("utf-8"))
result_str = stdout.decode("utf-8")
try:
results[metric] = float(result_str.split(' ')[-1].split('\t')[-1])
except Exception:
results[metric] = result_str
return results
def merge_search_result(search_result_save_dir: str, lang: str):
lang_files = [file for file in os.listdir(search_result_save_dir) if f'{lang}_' in file]
shard_info_dict = defaultdict(set)
for file in lang_files:
file_name = file.split('.')[0]
shard_info = file_name.split('_')[1]
shard_id, num_shards = int(shard_info.split('-')[0]), int(shard_info.split('-')[2])
assert shard_id < num_shards
shard_info_dict[num_shards].add(shard_id)
flag = False
for num_shards, shard_ids in shard_info_dict.items():
if len(shard_ids) != num_shards:
flag = False
else:
flag = True
lang_paths = os.path.join(search_result_save_dir, f'{lang}_*-of-{num_shards}.txt')
save_path = os.path.join(search_result_save_dir, f'{lang}.txt')
cmd = f'cat {lang_paths} > {save_path}'
os.system(cmd)
break
if not flag:
raise ValueError(f"Fail to find complete search results of {lang} in {search_result_save_dir}")
def main():
parser = HfArgumentParser([EvalArgs])
eval_args = parser.parse_args_into_dataclasses()[0]
eval_args: EvalArgs
script_path = download_evaluation_script('trec_eval')
languages = check_languages(eval_args.languages)
if 'checkpoint-' in os.path.basename(eval_args.encoder):
eval_args.encoder = os.path.dirname(eval_args.encoder) + '_' + os.path.basename(eval_args.encoder)
if 'checkpoint-' in os.path.basename(eval_args.reranker):
eval_args.reranker = os.path.dirname(eval_args.reranker) + '_' + os.path.basename(eval_args.reranker)
try:
for sub_dir in ['colbert', 'sparse', 'dense', 'colbert+sparse+dense']:
results = {}
for lang in languages:
qrels_path = os.path.join(eval_args.qrels_dir, f"qrels.mldr-v1.0-{lang}-test.tsv")
search_result_save_dir = os.path.join(eval_args.search_result_save_dir, sub_dir, f"{os.path.basename(eval_args.encoder)}-{os.path.basename(eval_args.reranker)}")
search_result_path = os.path.join(search_result_save_dir, f"{lang}.txt")
if not os.path.exists(search_result_path):
merge_search_result(search_result_save_dir, lang)
assert os.path.exists(search_result_path)
result = evaluate(script_path, qrels_path, search_result_path, eval_args.metrics)
results[lang] = result
print("****************************")
print(sub_dir + ":")
save_results(
model_name=eval_args.encoder,
reranker_name=eval_args.reranker,
results=results,
save_path=os.path.join(eval_args.eval_result_save_dir, sub_dir, f"{os.path.basename(eval_args.encoder)}-{os.path.basename(eval_args.reranker)}.json"),
eval_languages=languages,
)
except Exception:
results = {}
for lang in languages:
qrels_path = os.path.join(eval_args.qrels_dir, f"qrels.mldr-v1.0-{lang}-test.tsv")
search_result_save_dir = os.path.join(eval_args.search_result_save_dir, f"{os.path.basename(eval_args.encoder)}-{os.path.basename(eval_args.reranker)}")
search_result_path = os.path.join(search_result_save_dir, f"{lang}.txt")
if not os.path.exists(search_result_path):
merge_search_result(search_result_save_dir, lang)
assert os.path.exists(search_result_path)
result = evaluate(script_path, qrels_path, search_result_path, eval_args.metrics)
results[lang] = result
save_results(
model_name=eval_args.encoder,
reranker_name=eval_args.reranker,
results=results,
save_path=os.path.join(eval_args.eval_result_save_dir, f"{os.path.basename(eval_args.encoder)}-{os.path.basename(eval_args.reranker)}.json"),
eval_languages=languages,
)
print("==================================================")
print("Finish generating evaluation results with following model and reranker:")
print(eval_args.encoder)
print(eval_args.reranker)
if __name__ == "__main__":
main()
"""
# 1. Output Search Results with BM25
python bm25_baseline.py
# 2. Print and Save Evaluation Results
python step2-eval_sparse_mldr.py \
--encoder bm25 \
--languages ar de es fr hi it ja ko pt ru th en zh \
--search_result_save_dir ./search_results \
--qrels_dir ../qrels \
--eval_result_save_dir ./eval_results \
--metrics ndcg@10
"""
import os
import datasets
from tqdm import tqdm
def generate_corpus(lang: str, corpus_save_dir: str):
corpus_save_path = os.path.join(corpus_save_dir, 'corpus.jsonl')
if os.path.exists(corpus_save_path):
return
corpus = datasets.load_dataset('Shitao/MLDR', f'corpus-{lang}', split='corpus')
corpus_list = [{'id': e['docid'], 'contents': e['text']} for e in tqdm(corpus, desc="Generating corpus")]
corpus = datasets.Dataset.from_list(corpus_list)
corpus.to_json(corpus_save_path, force_ascii=False)
def generate_queries(lang: str, queries_save_dir: str, split: str='test'):
queries_save_path = os.path.join(queries_save_dir, f"{lang}.tsv")
if os.path.exists(queries_save_path):
return
dataset = datasets.load_dataset('Shitao/MLDR', lang, split=split)
queries_list = []
for data in dataset:
queries_list.append({
'id': data['query_id'],
'content': data['query'].replace('\n', ' ').replace('\t', ' ')
})
with open(queries_save_path, 'w', encoding='utf-8') as f:
for query in queries_list:
assert '\n' not in query['content'] and '\t' not in query['content']
line = f"{query['id']}\t{query['content']}"
f.write(line + '\n')
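# Pyserini's --topics reader expects one `qid<TAB>query` pair per line, which is
# why newlines and tabs are stripped from the queries above.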
def index(lang: str, corpus_save_dir: str, index_save_dir: str):
cmd = f"python -m pyserini.index.lucene \
--language {lang} \
--collection JsonCollection \
--input {corpus_save_dir} \
--index {index_save_dir} \
--generator DefaultLuceneDocumentGenerator \
--threads 1 --optimize \
"
os.system(cmd)
def search(index_save_dir: str, queries_save_dir: str, lang: str, result_save_path: str):
queries_save_path = os.path.join(queries_save_dir, f"{lang}.tsv")
cmd = f"python -m pyserini.search.lucene \
--language {lang} \
--index {index_save_dir} \
--topics {queries_save_path} \
--output {result_save_path} \
--bm25 \
--hits 1000 \
--batch-size 128 \
--threads 16 \
"
os.system(cmd)
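# Passing `--language {lang}` selects a language-specific Lucene analyzer, so the
# corpus and the queries of each MLDR language are tokenized consistently at both
# indexing and search time.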
def main():
bm25_dir = './bm25_baseline'
result_save_dir = os.path.join('./search_results', 'bm25')
if not os.path.exists(result_save_dir):
os.makedirs(result_save_dir)
for lang in ['ar', 'de', 'en', 'es', 'fr', 'hi', 'it', 'ja', 'ko', 'pt', 'ru', 'th', 'zh']:
save_dir = os.path.join(bm25_dir, lang)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
corpus_save_dir = os.path.join(save_dir, 'corpus')
if not os.path.exists(corpus_save_dir):
os.makedirs(corpus_save_dir)
generate_corpus(lang, corpus_save_dir)
index_save_dir = os.path.join(save_dir, 'index')
if not os.path.exists(index_save_dir):
os.makedirs(index_save_dir)
index(lang, corpus_save_dir, index_save_dir)
generate_queries(lang, save_dir, split='test')
result_save_path = os.path.join(result_save_dir, f'{lang}.txt')
search(index_save_dir, save_dir, lang, result_save_path)
if __name__ == '__main__':
main()