Commit f75058c7 authored by Rayyyyy

First add.

*.memmap
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.idea/
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
Untitled.ipynb
try.py
update_model_card.py
model_card.md
pic.py
pic2.py
# Pyre type checker
.pyre/
from .tasks import *
ChineseTaskList = ['TNews', 'IFlyTek', 'MultilingualSentiment', 'JDReview', 'OnlineShopping', 'Waimai',
'CLSClusteringS2S', 'CLSClusteringP2P', 'ThuNewsClusteringS2S', 'ThuNewsClusteringP2P',
'Ocnli', 'Cmnli',
'T2Reranking', 'MMarcoReranking', 'CMedQAv1', 'CMedQAv2',
'T2Retrieval', 'MMarcoRetrieval', 'DuRetrieval', 'CovidRetrieval', 'CmedqaRetrieval', 'EcomRetrieval', 'MedicalRetrieval', 'VideoRetrieval',
'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'AFQMC', 'QBQTC']
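# A minimal usage sketch (an assumption, not part of this commit): importing this
# package registers the task classes defined below with mteb, after which the
# whole Chinese task list can be evaluated, e.g.:
#
#     from mteb import MTEB
#     from sentence_transformers import SentenceTransformer
#     import C_MTEB  # assumed package name; importing it registers the tasks
#
#     model = SentenceTransformer("BAAI/bge-large-zh")  # any embedding model
#     evaluation = MTEB(tasks=C_MTEB.ChineseTaskList)
#     evaluation.run(model, output_folder="results/zh")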
from mteb import AbsTaskClassification
class TNews(AbsTaskClassification):
@property
def description(self):
return {
'name': 'TNews',
'hf_hub_name': 'C-MTEB/TNews-classification',
'description': 'Short Text Classification for News',
"reference": "https://www.cluebenchmarks.com/introduce.html",
'type': 'Classification',
'category': 's2s',
'eval_splits': ['validation'],
'eval_langs': ['zh'],
'main_score': 'accuracy',
'samples_per_label': 32,
}
class IFlyTek(AbsTaskClassification):
@property
def description(self):
return {
'name': 'IFlyTek',
'hf_hub_name': 'C-MTEB/IFlyTek-classification',
'description': 'Long Text classification for the description of Apps',
"reference": "https://www.cluebenchmarks.com/introduce.html",
'type': 'Classification',
'category': 's2s',
'eval_splits': ['validation'],
'eval_langs': ['zh'],
'main_score': 'accuracy',
'samples_per_label': 32,
'n_experiments': 5
}
class MultilingualSentiment(AbsTaskClassification):
@property
def description(self):
return {
'name': 'MultilingualSentiment',
'hf_hub_name': 'C-MTEB/MultilingualSentiment-classification',
'description': 'A collection of multilingual sentiments datasets grouped into 3 classes -- positive, neutral, negative',
"reference": "https://github.com/tyqiangz/multilingual-sentiment-datasets",
'category': 's2s',
'type': 'Classification',
'eval_splits': ['validation'],
'eval_langs': ['zh'],
'main_score': 'accuracy',
'samples_per_label': 32,
}
class JDReview(AbsTaskClassification):
@property
def description(self):
return {
'name': 'JDReview',
'hf_hub_name': 'C-MTEB/JDReview-classification',
'description': 'User reviews of iPhones on JD.com',
'category': 's2s',
'type': 'Classification',
'eval_splits': ['test'],
'eval_langs': ['zh'],
'main_score': 'accuracy',
'samples_per_label': 32,
}
class OnlineShopping(AbsTaskClassification):
@property
def description(self):
return {
'name': 'OnlineShopping',
'hf_hub_name': 'C-MTEB/OnlineShopping-classification',
'description': 'Sentiment Analysis of User Reviews on Online Shopping Websites',
'category': 's2s',
'type': 'Classification',
'eval_splits': ['test'],
'eval_langs': ['zh'],
'main_score': 'accuracy',
'samples_per_label': 32,
}
class Waimai(AbsTaskClassification):
@property
def description(self):
return {
'name': 'Waimai',
'hf_hub_name': 'C-MTEB/waimai-classification',
'description': 'Sentiment Analysis of user reviews on takeaway platforms',
'category': 's2s',
'type': 'Classification',
'eval_splits': ['test'],
'eval_langs': ['zh'],
'main_score': 'accuracy',
'samples_per_label': 32,
}
from mteb import AbsTaskClustering
class CLSClusteringS2S(AbsTaskClustering):
@property
def description(self):
return {
"name": "CLSClusteringS2S",
"hf_hub_name": "C-MTEB/CLSClusteringS2S",
"description": (
"Clustering of titles from CLS dataset. Clustering of 13 sets, based on the main category."
),
"reference": "https://arxiv.org/abs/2209.05034",
"type": "Clustering",
"category": "s2s",
"eval_splits": ["test"],
"eval_langs": ["zh"],
"main_score": "v_measure",
}
class CLSClusteringP2P(AbsTaskClustering):
@property
def description(self):
return {
"name": "CLSClusteringP2P",
"hf_hub_name": "C-MTEB/CLSClusteringP2P",
"description": (
"Clustering of titles + abstract from CLS dataset. Clustering of 13 sets, based on the main category."
),
"reference": "https://arxiv.org/abs/2209.05034",
"type": "Clustering",
"category": "p2p",
"eval_splits": ["test"],
"eval_langs": ["zh"],
"main_score": "v_measure",
}
class ThuNewsClusteringS2S(AbsTaskClustering):
@property
def description(self):
return {
'name': 'ThuNewsClusteringS2S',
'hf_hub_name': 'C-MTEB/ThuNewsClusteringS2S',
'description': 'Clustering of titles from the THUCNews dataset',
"reference": "http://thuctc.thunlp.org/",
"type": "Clustering",
"category": "s2s",
"eval_splits": ["test"],
"eval_langs": ["zh"],
"main_score": "v_measure",
}
class ThuNewsClusteringP2P(AbsTaskClustering):
@property
def description(self):
return {
'name': 'ThuNewsClusteringP2P',
'hf_hub_name': 'C-MTEB/ThuNewsClusteringP2P',
'description': 'Clustering of titles + abstracts from the THUCNews dataset',
"reference": "http://thuctc.thunlp.org/",
"type": "Clustering",
"category": "p2p",
"eval_splits": ["test"],
"eval_langs": ["zh"],
"main_score": "v_measure",
}
import datasets
from mteb.abstasks import MultilingualTask, AbsTaskRetrieval
from mteb.abstasks.AbsTaskRetrieval import *
# from ...abstasks import MultilingualTask, AbsTaskRetrieval
# from ...abstasks.AbsTaskRetrieval import *
_LANGUAGES = ['ar', 'de', 'en', 'es', 'fr', 'hi', 'it', 'ja', 'ko', 'pt', 'ru', 'th', 'zh']
def load_mldr_data(path: str, langs: list, eval_splits: list, cache_dir: str=None):
corpus = {lang: {split: None for split in eval_splits} for lang in langs}
queries = {lang: {split: None for split in eval_splits} for lang in langs}
relevant_docs = {lang: {split: None for split in eval_splits} for lang in langs}
for lang in langs:
lang_corpus = datasets.load_dataset(path, f'corpus-{lang}', cache_dir=cache_dir)['corpus']
lang_corpus = {e['docid']: {'text': e['text']} for e in lang_corpus}
lang_data = datasets.load_dataset(path, lang, cache_dir=cache_dir)
for split in eval_splits:
corpus[lang][split] = lang_corpus
queries[lang][split] = {e['query_id']: e['query'] for e in lang_data[split]}
relevant_docs[lang][split] = {e['query_id']: {e['positive_passages'][0]['docid']: 1} for e in lang_data[split]}
corpus = datasets.DatasetDict(corpus)
queries = datasets.DatasetDict(queries)
relevant_docs = datasets.DatasetDict(relevant_docs)
return corpus, queries, relevant_docs
class MultiLongDocRetrieval(MultilingualTask, AbsTaskRetrieval):
@property
def description(self):
return {
'name': 'MultiLongDocRetrieval',
'hf_hub_name': 'Shitao/MLDR',
'reference': 'https://arxiv.org/abs/2402.03216',
'description': 'MultiLongDocRetrieval: A Multilingual Long-Document Retrieval Dataset',
'type': 'Retrieval',
'category': 's2p',
'eval_splits': ['dev', 'test'],
'eval_langs': _LANGUAGES,
'main_score': 'ndcg_at_10',
}
def load_data(self, **kwargs):
if self.data_loaded:
return
self.corpus, self.queries, self.relevant_docs = load_mldr_data(
path=self.description['hf_hub_name'],
langs=self.langs,
eval_splits=self.description['eval_splits'],
cache_dir=kwargs.get('cache_dir', None)
)
self.data_loaded = True
def evaluate(
self,
model,
split="test",
batch_size=128,
corpus_chunk_size=None,
score_function="cos_sim",
**kwargs
):
try:
from beir.retrieval.evaluation import EvaluateRetrieval
except ImportError:
raise Exception("Retrieval tasks require the beir package. Please install it with `pip install mteb[beir]`")
if not self.data_loaded:
self.load_data()
model = model if self.is_dres_compatible(model) else DRESModel(model)
if os.getenv("RANK", None) is None:
# Non-distributed
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
model = DRES(
model,
batch_size=batch_size,
corpus_chunk_size=corpus_chunk_size if corpus_chunk_size is not None else 50000,
**kwargs,
)
else:
# Distributed (multi-GPU)
from beir.retrieval.search.dense import (
DenseRetrievalParallelExactSearch as DRPES,
)
model = DRPES(
model,
batch_size=batch_size,
corpus_chunk_size=corpus_chunk_size,
**kwargs,
)
retriever = EvaluateRetrieval(model, score_function=score_function) # or "cos_sim" or "dot"
scores = {}
for lang in self.langs:
print(f"==============================\nStart evaluating {lang} ...")
corpus, queries, relevant_docs = self.corpus[lang][split], self.queries[lang][split], self.relevant_docs[lang][split]
start_time = time()
results = retriever.retrieve(corpus, queries)
end_time = time()
logger.info("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time))
ndcg, _map, recall, precision = retriever.evaluate(relevant_docs, results, retriever.k_values, ignore_identical_ids=kwargs.get("ignore_identical_ids", True))
mrr = retriever.evaluate_custom(relevant_docs, results, retriever.k_values, "mrr")
scores[lang] = {
**{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
**{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
**{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
**{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
**{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr.items()},
}
return scores
from mteb import AbsTaskPairClassification
class Ocnli(AbsTaskPairClassification):
@property
def description(self):
return {
'name': 'Ocnli',
"hf_hub_name": "C-MTEB/OCNLI",
'description': 'Original Chinese Natural Language Inference dataset',
"reference": "https://arxiv.org/abs/2010.05444",
'category': 's2s',
'type': 'PairClassification',
'eval_splits': ['validation'],
'eval_langs': ['zh'],
'main_score': 'ap',
}
class Cmnli(AbsTaskPairClassification):
@property
def description(self):
return {
'name': 'Cmnli',
"hf_hub_name": "C-MTEB/CMNLI",
'description': 'Chinese Multi-Genre NLI',
"reference": "https://huggingface.co/datasets/clue/viewer/cmnli",
'category': 's2s',
'type': 'PairClassification',
'eval_splits': ['validation'],
'eval_langs': ['zh'],
'main_score': 'ap',
}
import logging
import numpy as np
from mteb import RerankingEvaluator, AbsTaskReranking
from tqdm import tqdm
logger = logging.getLogger(__name__)
class ChineseRerankingEvaluator(RerankingEvaluator):
"""
This class evaluates a SentenceTransformer model for the task of re-ranking.
Given a query and a list of documents, it computes the score [query, doc_i] for all possible
documents and sorts them in decreasing order. Then MRR@10 and MAP are computed to measure the quality of the ranking.
:param samples: Must be a list and each element is of the form:
- {'query': '', 'positive': [], 'negative': []}. Query is the search query, positive is a list of positive
(relevant) documents, negative is a list of negative (irrelevant) documents.
- {'query': [], 'positive': [], 'negative': []}, where query is a list of strings whose embeddings we average
to get the query embedding.
"""
def compute_metrics_batched(self, model):
"""
Computes the metrics in a batched way, by batching all queries and
all documents together
"""
if hasattr(model, 'compute_score'):
return self.compute_metrics_batched_from_crossencoder(model)
else:
return self.compute_metrics_batched_from_biencoder(model)
def compute_metrics_batched_from_crossencoder(self, model):
all_mrr_scores = []
all_ap_scores = []
pairs = []
for sample in tqdm(self.samples, desc="Evaluating"):
for p in sample['positive']:
pairs.append([sample['query'], p])
for n in sample['negative']:
pairs.append([sample['query'], n])
all_scores = model.compute_score(pairs)
all_scores = np.array(all_scores)
start_inx = 0
for sample in tqdm(self.samples, desc="Evaluating"):
is_relevant = [True] * len(sample['positive']) + [False] * len(sample['negative'])
pred_scores = all_scores[start_inx:start_inx + len(is_relevant)]
start_inx += len(is_relevant)
pred_scores_argsort = np.argsort(-pred_scores) # Sort in decreasing order
mrr = self.mrr_at_k_score(is_relevant, pred_scores_argsort, self.mrr_at_k)
ap = self.ap_score(is_relevant, pred_scores)
all_mrr_scores.append(mrr)
all_ap_scores.append(ap)
mean_ap = np.mean(all_ap_scores)
mean_mrr = np.mean(all_mrr_scores)
return {"map": mean_ap, "mrr": mean_mrr}
def compute_metrics_batched_from_biencoder(self, model):
all_mrr_scores = []
all_ap_scores = []
logger.info("Encoding queries...")
if isinstance(self.samples[0]["query"], str):
if hasattr(model, 'encode_queries'):
all_query_embs = model.encode_queries(
[sample["query"] for sample in self.samples],
convert_to_tensor=True,
batch_size=self.batch_size,
)
else:
all_query_embs = model.encode(
[sample["query"] for sample in self.samples],
convert_to_tensor=True,
batch_size=self.batch_size,
)
elif isinstance(self.samples[0]["query"], list):
# In case the query is a list of strings, we get the most similar embedding to any of the queries
all_query_flattened = [q for sample in self.samples for q in sample["query"]]
if hasattr(model, 'encode_queries'):
all_query_embs = model.encode_queries(all_query_flattened, convert_to_tensor=True,
batch_size=self.batch_size)
else:
all_query_embs = model.encode(all_query_flattened, convert_to_tensor=True, batch_size=self.batch_size)
else:
raise ValueError(f"Query must be a string or a list of strings but is {type(self.samples[0]['query'])}")
logger.info("Encoding candidates...")
all_docs = []
for sample in self.samples:
all_docs.extend(sample["positive"])
all_docs.extend(sample["negative"])
all_docs_embs = model.encode(all_docs, convert_to_tensor=True, batch_size=self.batch_size)
# Compute scores
logger.info("Evaluating...")
query_idx, docs_idx = 0, 0
for instance in self.samples:
num_subqueries = len(instance["query"]) if isinstance(instance["query"], list) else 1
query_emb = all_query_embs[query_idx: query_idx + num_subqueries]
query_idx += num_subqueries
num_pos = len(instance["positive"])
num_neg = len(instance["negative"])
docs_emb = all_docs_embs[docs_idx: docs_idx + num_pos + num_neg]
docs_idx += num_pos + num_neg
if num_pos == 0 or num_neg == 0:
continue
is_relevant = [True] * num_pos + [False] * num_neg
scores = self._compute_metrics_instance(query_emb, docs_emb, is_relevant)
all_mrr_scores.append(scores["mrr"])
all_ap_scores.append(scores["ap"])
mean_ap = np.mean(all_ap_scores)
mean_mrr = np.mean(all_mrr_scores)
return {"map": mean_ap, "mrr": mean_mrr}
def evaluate(self, model, split="test", **kwargs):
if not self.data_loaded:
self.load_data()
data_split = self.dataset[split]
evaluator = ChineseRerankingEvaluator(data_split, **kwargs)
scores = evaluator(model)
return dict(scores)
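# Monkey-patch AbsTaskReranking so every reranking task below is scored with
# ChineseRerankingEvaluator, which supports cross-encoders (via `compute_score`)
# and query-side encoders (via `encode_queries`).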
AbsTaskReranking.evaluate = evaluate
class T2Reranking(AbsTaskReranking):
@property
def description(self):
return {
'name': 'T2Reranking',
'hf_hub_name': "C-MTEB/T2Reranking",
'description': 'T2Ranking: A large-scale Chinese Benchmark for Passage Ranking',
"reference": "https://arxiv.org/abs/2304.03679",
'type': 'Reranking',
'category': 's2p',
'eval_splits': ['dev'],
'eval_langs': ['zh'],
'main_score': 'map',
}
class T2RerankingZh2En(AbsTaskReranking):
@property
def description(self):
return {
'name': 'T2RerankingZh2En',
'hf_hub_name': "C-MTEB/T2Reranking_zh2en",
'description': 'T2Ranking: A large-scale Chinese Benchmark for Passage Ranking',
"reference": "https://arxiv.org/abs/2304.03679",
'type': 'Reranking',
'category': 's2p',
'eval_splits': ['dev'],
'eval_langs': ['zh2en'],
'main_score': 'map',
}
class T2RerankingEn2Zh(AbsTaskReranking):
@property
def description(self):
return {
'name': 'T2RerankingEn2Zh',
'hf_hub_name': "C-MTEB/T2Reranking_en2zh",
'description': 'T2Ranking: A large-scale Chinese Benchmark for Passage Ranking',
"reference": "https://arxiv.org/abs/2304.03679",
'type': 'Reranking',
'category': 's2p',
'eval_splits': ['dev'],
'eval_langs': ['en2zh'],
'main_score': 'map',
}
class MMarcoReranking(AbsTaskReranking):
@property
def description(self):
return {
'name': 'MMarcoReranking',
'hf_hub_name': "C-MTEB/Mmarco-reranking",
'description': 'mMARCO is a multilingual version of the MS MARCO passage ranking dataset',
"reference": "https://github.com/unicamp-dl/mMARCO",
'type': 'Reranking',
'category': 's2p',
'eval_splits': ['dev'],
'eval_langs': ['zh'],
'main_score': 'map',
}
class CMedQAv1(AbsTaskReranking):
@property
def description(self):
return {
'name': 'CMedQAv1',
"hf_hub_name": "C-MTEB/CMedQAv1-reranking",
'description': 'Chinese community medical question answering',
"reference": "https://github.com/zhangsheng93/cMedQA",
'type': 'Reranking',
'category': 's2p',
'eval_splits': ['test'],
'eval_langs': ['zh'],
'main_score': 'map',
}
class CMedQAv2(AbsTaskReranking):
@property
def description(self):
return {
'name': 'CMedQAv2',
"hf_hub_name": "C-MTEB/CMedQAv2-reranking",
'description': 'Chinese community medical question answering',
"reference": "https://github.com/zhangsheng93/cMedQA2",
'type': 'Reranking',
'category': 's2p',
'eval_splits': ['test'],
'eval_langs': ['zh'],
'main_score': 'map',
}
from collections import defaultdict
from datasets import load_dataset, DatasetDict
from mteb import AbsTaskRetrieval
def load_retrieval_data(hf_hub_name, eval_splits):
eval_split = eval_splits[0]
dataset = load_dataset(hf_hub_name)
qrels = load_dataset(hf_hub_name + '-qrels')[eval_split]
corpus = {e['id']: {'text': e['text']} for e in dataset['corpus']}
queries = {e['id']: e['text'] for e in dataset['queries']}
relevant_docs = defaultdict(dict)
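# each qrels row maps a query id ('qid') to a relevant passage id ('pid') with its relevance score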
for e in qrels:
relevant_docs[e['qid']][e['pid']] = e['score']
corpus = DatasetDict({eval_split:corpus})
queries = DatasetDict({eval_split:queries})
relevant_docs = DatasetDict({eval_split:relevant_docs})
return corpus, queries, relevant_docs
class T2Retrieval(AbsTaskRetrieval):
@property
def description(self):
return {
'name': 'T2Retrieval',
'hf_hub_name': 'C-MTEB/T2Retrieval',
'reference': 'https://arxiv.org/abs/2304.03679',
'description': 'T2Ranking: A large-scale Chinese Benchmark for Passage Ranking',
'type': 'Retrieval',
'category': 's2p',
'eval_splits': ['dev'],
'eval_langs': ['zh'],
'main_score': 'ndcg_at_10',
}
def load_data(self, **kwargs):
if self.data_loaded:
return
self.corpus, self.queries, self.relevant_docs = load_retrieval_data(self.description['hf_hub_name'],
self.description['eval_splits'])
self.data_loaded = True
class MMarcoRetrieval(AbsTaskRetrieval):
@property
def description(self):
return {
'name': 'MMarcoRetrieval',
'hf_hub_name': 'C-MTEB/MMarcoRetrieval',
'reference': 'https://github.com/unicamp-dl/mMARCO',
'description': 'mMARCO is a multilingual version of the MS MARCO passage ranking dataset',
'type': 'Retrieval',
'category': 's2p',
'eval_splits': ['dev'],
'eval_langs': ['zh'],
'main_score': 'ndcg_at_10',
}
def load_data(self, **kwargs):
if self.data_loaded:
return
self.corpus, self.queries, self.relevant_docs = load_retrieval_data(self.description['hf_hub_name'],
self.description['eval_splits'])
self.data_loaded = True
class DuRetrieval(AbsTaskRetrieval):
@property
def description(self):
return {
'name': 'DuRetrieval',
'hf_hub_name': 'C-MTEB/DuRetrieval',
'reference': 'https://aclanthology.org/2022.emnlp-main.357.pdf',
'description': 'A Large-scale Chinese Benchmark for Passage Retrieval from Web Search Engine',
'type': 'Retrieval',
'category': 's2p',
'eval_splits': ['dev'],
'eval_langs': ['zh'],
'main_score': 'ndcg_at_10',
}
def load_data(self, **kwargs):
if self.data_loaded:
return
self.corpus, self.queries, self.relevant_docs = load_retrieval_data(self.description['hf_hub_name'],
self.description['eval_splits'])
self.data_loaded = True
class CovidRetrieval(AbsTaskRetrieval):
@property
def description(self):
return {
'name': 'CovidRetrieval',
'hf_hub_name': 'C-MTEB/CovidRetrieval',
'reference': 'https://aclanthology.org/2022.emnlp-main.357.pdf',
'description': 'COVID-19 news articles',
'type': 'Retrieval',
'category': 's2p',
'eval_splits': ['dev'],
'eval_langs': ['zh'],
'main_score': 'ndcg_at_10',
}
def load_data(self, **kwargs):
if self.data_loaded:
return
self.corpus, self.queries, self.relevant_docs = load_retrieval_data(self.description['hf_hub_name'],
self.description['eval_splits'])
self.data_loaded = True
class CmedqaRetrieval(AbsTaskRetrieval):
@property
def description(self):
return {
'name': 'CmedqaRetrieval',
'hf_hub_name': 'C-MTEB/CmedqaRetrieval',
'reference': 'https://aclanthology.org/2022.emnlp-main.357.pdf',
'description': 'Online medical consultation text',
'type': 'Retrieval',
'category': 's2p',
'eval_splits': ['dev'],
'eval_langs': ['zh'],
'main_score': 'ndcg_at_10',
}
def load_data(self, **kwargs):
if self.data_loaded:
return
self.corpus, self.queries, self.relevant_docs = load_retrieval_data(self.description['hf_hub_name'],
self.description['eval_splits'])
self.data_loaded = True
class EcomRetrieval(AbsTaskRetrieval):
@property
def description(self):
return {
'name': 'EcomRetrieval',
'hf_hub_name': 'C-MTEB/EcomRetrieval',
'reference': 'https://arxiv.org/abs/2203.03367',
'description': 'Passage retrieval dataset collected from Alibaba search engine systems in ecom domain',
'type': 'Retrieval',
'category': 's2p',
'eval_splits': ['dev'],
'eval_langs': ['zh'],
'main_score': 'ndcg_at_10',
}
def load_data(self, **kwargs):
if self.data_loaded:
return
self.corpus, self.queries, self.relevant_docs = load_retrieval_data(self.description['hf_hub_name'],
self.description['eval_splits'])
self.data_loaded = True
class MedicalRetrieval(AbsTaskRetrieval):
@property
def description(self):
return {
'name': 'MedicalRetrieval',
'hf_hub_name': 'C-MTEB/MedicalRetrieval',
'reference': 'https://arxiv.org/abs/2203.03367',
'description': 'Passage retrieval dataset collected from Alibaba search engine systems in medical domain',
'type': 'Retrieval',
'category': 's2p',
'eval_splits': ['dev'],
'eval_langs': ['zh'],
'main_score': 'ndcg_at_10',
}
def load_data(self, **kwargs):
if self.data_loaded:
return
self.corpus, self.queries, self.relevant_docs = load_retrieval_data(self.description['hf_hub_name'],
self.description['eval_splits'])
self.data_loaded = True
class VideoRetrieval(AbsTaskRetrieval):
@property
def description(self):
return {
'name': 'VideoRetrieval',
'hf_hub_name': 'C-MTEB/VideoRetrieval',
'reference': 'https://arxiv.org/abs/2203.03367',
'description': 'Passage retrieval dataset collected from Alibaba search engine systems in video domain',
'type': 'Retrieval',
'category': 's2p',
'eval_splits': ['dev'],
'eval_langs': ['zh'],
'main_score': 'ndcg_at_10',
}
def load_data(self, **kwargs):
if self.data_loaded:
return
self.corpus, self.queries, self.relevant_docs = load_retrieval_data(self.description['hf_hub_name'], self.description['eval_splits'])
self.data_loaded = True
from mteb import AbsTaskSTS
class ATEC(AbsTaskSTS):
@property
def description(self):
return {
"name": "ATEC",
"hf_hub_name": "C-MTEB/ATEC",
"type": "STS",
"category": "s2s",
"eval_splits": ["test"],
"eval_langs": ["zh"],
"main_score": "cosine_spearman",
"min_score": 0,
"max_score": 1,
}
class BQ(AbsTaskSTS):
@property
def description(self):
return {
"name": "BQ",
"hf_hub_name": "C-MTEB/BQ",
"type": "STS",
"category": "s2s",
"eval_splits": ["test"],
"eval_langs": ["zh"],
"main_score": "cosine_spearman",
"min_score": 0,
"max_score": 1,
}
class LCQMC(AbsTaskSTS):
@property
def description(self):
return {
"name": "LCQMC",
"hf_hub_name": "C-MTEB/LCQMC",
"type": "STS",
"category": "s2s",
"eval_splits": ["test"],
"eval_langs": ["zh"],
"main_score": "cosine_spearman",
"min_score": 0,
"max_score": 1,
}
class PAWSX(AbsTaskSTS):
@property
def description(self):
return {
"name": "PAWSX",
"hf_hub_name": "C-MTEB/PAWSX",
"type": "STS",
"category": "s2s",
"eval_splits": ["test"],
"eval_langs": ["zh"],
"main_score": "cosine_spearman",
"min_score": 0,
"max_score": 1,
}
class STSB(AbsTaskSTS):
@property
def description(self):
return {
"name": "STSB",
"hf_hub_name": "C-MTEB/STSB",
"type": "STS",
"category": "s2s",
"eval_splits": ["test"],
"eval_langs": ["zh"],
"main_score": "cosine_spearman",
"min_score": 0,
"max_score": 5,
}
class AFQMC(AbsTaskSTS):
@property
def description(self):
return {
"name": "AFQMC",
"hf_hub_name": "C-MTEB/AFQMC",
"type": "STS",
"category": "s2s",
"eval_splits": ["validation"],
"eval_langs": ["zh"],
"main_score": "cosine_spearman",
"min_score": 0,
"max_score": 1,
}
class QBQTC(AbsTaskSTS):
@property
def description(self):
return {
"name": "QBQTC",
"hf_hub_name": "C-MTEB/QBQTC",
"reference": "https://github.com/CLUEbenchmark/QBQTC/tree/main/dataset",
"type": "STS",
"category": "s2s",
"eval_splits": ["test"],
"eval_langs": ["zh"],
"main_score": "cosine_spearman",
"min_score": 0,
"max_score": 1,
}
from .Classification import *
from .Clustering import *
from .PairClassification import *
from .Reranking import *
from .Retrieval import *
from .STS import *
# MKQA
MKQA is a cross-lingual question answering dataset covering 25 non-English languages. For more details, please refer to [here](https://github.com/apple/ml-mkqa).
We filter out questions whose types are `unanswerable`, `binary`, or `long-answer`, leaving 6,619 questions per language. To perform evaluation, first **download the test data**:
```bash
# download
wget https://huggingface.co/datasets/Shitao/bge-m3-data/resolve/main/MKQA_test-data.zip
# unzip to `qa_data` dir
unzip MKQA_test-data.zip -d qa_data
```
We use the well-processed NQ [corpus](https://huggingface.co/datasets/BeIR/nq) offered by BEIR as the candidate corpus, and evaluate with two metrics: Recall@20 and Recall@100. The definition of Recall@k here follows [RocketQA](https://aclanthology.org/2021.naacl-main.466.pdf).
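In this setting, Recall@k counts a question as a hit if any of its gold answers appears in the top-k retrieved passages. The metric itself is computed by `evaluate_recall_qa` in `utils/evaluation.py` (used by the eval scripts below); the following is only an illustrative sketch of the definition, assuming lowercased substring matching:

```python
def recall_at_k(search_results, ground_truths, k=20):
    """RocketQA-style Recall@k: fraction of questions whose top-k passages contain a gold answer.

    search_results: one ranked list of passage texts per question
    ground_truths: one list of acceptable answer strings per question
    """
    hits = 0
    for passages, answers in zip(search_results, ground_truths):
        top_k_text = " ".join(passages[:k]).lower()
        if any(answer.lower() in top_k_text for answer in answers):
            hits += 1
    return hits / len(search_results)
```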
## Dense Retrieval
If you only want to perform dense retrieval with embedding models, follow these steps:
1. Install Java, Pyserini and Faiss (CPU version or GPU version):
```bash
# install java (Linux)
apt update
apt install openjdk-11-jdk
# install pyserini
pip install pyserini
# install faiss
## CPU version
conda install -c conda-forge faiss-cpu
## GPU version
conda install -c conda-forge faiss-gpu
```
2. Dense retrieval:
```bash
cd dense_retrieval
# 1. Generate Corpus Embedding
python step0-generate_embedding.py \
--encoder BAAI/bge-m3 \
--index_save_dir ./corpus-index \
--max_passage_length 512 \
--batch_size 256 \
--fp16 \
--add_instruction False \
--pooling_method cls \
--normalize_embeddings True
# 2. Search Results
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar da de es fi fr he hu it ja km ko ms nl no pl pt ru sv th tr vi zh_cn zh_hk zh_tw \
--index_save_dir ./corpus-index \
--result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--threads 16 \
--batch_size 32 \
--hits 1000 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False
# 3. Print and Save Evaluation Results
python step2-eval_dense_mkqa.py \
--encoder BAAI/bge-m3 \
--languages ar da de es fi fr he hu it ja km ko ms nl no pl pt ru sv th tr vi zh_cn zh_hk zh_tw \
--search_result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--eval_result_save_dir ./eval_results \
--metrics recall@20 recall@100 \
--threads 32 \
--pooling_method cls \
--normalize_embeddings True
```
There are some important parameters:
- `encoder`: Name or path of the model to evaluate.
- `languages`: The languages you want to evaluate on. Available languages: `ar da de es fi fr he hu it ja km ko ms nl no pl pt ru sv th tr vi zh_cn zh_hk zh_tw`.
- `max_passage_length`: Maximum passage length when encoding.
- `batch_size`: Batch size for queries and corpus when encoding. For faster evaluation, set `batch_size` as large as your memory allows.
- `pooling_method` & `normalize_embeddings`: Follow the corresponding setting of the model you are evaluating (see the sketch after this list). For example, `BAAI/bge-m3` uses `cls` and `True`, `intfloat/multilingual-e5-large` uses `mean` and `True`, and `intfloat/e5-mistral-7b-instruct` uses `last` and `True`.
- `overwrite`: Whether to overwrite existing evaluation results.
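To make the last two flags concrete, here is a minimal sketch of what `pooling_method` and `normalize_embeddings` control, written against plain `transformers` (an illustration only; `FlagEmbedding.FlagModel` handles this internally):

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
model = AutoModel.from_pretrained("BAAI/bge-m3")

inputs = tokenizer(["how are neural retrievers evaluated?"],
                   return_tensors="pt", truncation=True, max_length=512)
with torch.no_grad():
    hidden = model(**inputs).last_hidden_state   # [batch, seq_len, dim]

# pooling_method='cls': take the first token's hidden state
# ('mean' would instead average the hidden states of non-padding tokens)
emb = hidden[:, 0]

# normalize_embeddings=True: L2-normalize, so inner-product search
# (as used by the Faiss index in the dense retrieval scripts) equals cosine similarity
emb = torch.nn.functional.normalize(emb, p=2, dim=-1)
```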
## Hybrid Retrieval (Dense & Sparse)
If you want to perform **hybrid retrieval with both dense and sparse methods**, follow these steps:
1. Install Java, Pyserini and Faiss (CPU version or GPU version):
```bash
# install java (Linux)
apt update
apt install openjdk-11-jdk
# install pyserini
pip install pyserini
# install faiss
## CPU version
conda install -c conda-forge faiss-cpu
## GPU version
conda install -c conda-forge faiss-gpu
```
2. Dense retrieval:
```bash
cd dense_retrieval
# 1. Generate Corpus Embedding
python step0-generate_embedding.py \
--encoder BAAI/bge-m3 \
--index_save_dir ./corpus-index \
--max_passage_length 512 \
--batch_size 256 \
--fp16 \
--add_instruction False \
--pooling_method cls \
--normalize_embeddings True
# 2. Search Results
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar da de es fi fr he hu it ja km ko ms nl no pl pt ru sv th tr vi zh_cn zh_hk zh_tw \
--index_save_dir ./corpus-index \
--result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--threads 16 \
--batch_size 32 \
--hits 1000 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False
# 3. Print and Save Evaluation Results
python step2-eval_dense_mkqa.py \
--encoder BAAI/bge-m3 \
--languages ar da de es fi fr he hu it ja km ko ms nl no pl pt ru sv th tr vi zh_cn zh_hk zh_tw \
--search_result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--eval_result_save_dir ./eval_results \
--metrics recall@20 recall@100 \
--threads 32 \
--pooling_method cls \
--normalize_embeddings True
```
3. Sparse Retrieval
```bash
cd sparse_retrieval
# 1. Generate Query and Corpus Sparse Vector
python step0-encode_query-and-corpus.py \
--encoder BAAI/bge-m3 \
--languages ar da de es fi fr he hu it ja km ko ms nl no pl pt ru sv th tr vi zh_cn zh_hk zh_tw \
--qa_data_dir ../qa_data \
--save_dir ./encoded_query-and-corpus \
--max_query_length 512 \
--max_passage_length 512 \
--batch_size 1024 \
--pooling_method cls \
--normalize_embeddings True
# 2. Output Search Results
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar da de es fi fr he hu it ja km ko ms nl no pl pt ru sv th tr vi zh_cn zh_hk zh_tw \
--encoded_query_and_corpus_save_dir ./encoded_query-and-corpus \
--result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--threads 16 \
--hits 1000
# 3. Print and Save Evaluation Results
python step2-eval_sparse_mkqa.py \
--encoder BAAI/bge-m3 \
--languages ar da de es fi fr he hu it ja km ko ms nl no pl pt ru sv th tr vi zh_cn zh_hk zh_tw \
--search_result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--eval_result_save_dir ./eval_results \
--metrics recall@20 recall@100 \
--threads 32 \
--pooling_method cls \
--normalize_embeddings True
```
4. Hybrid Retrieval
```bash
cd hybrid_retrieval
# 1. Search Dense and Sparse Results
## Run the "Dense Retrieval" and "Sparse Retrieval" steps above first, so that
## ../dense_retrieval/search_results and ../sparse_retrieval/search_results exist.
# 2. Hybrid Dense and Sparse Search Results
python step0-hybrid_search_results.py \
--model_name_or_path BAAI/bge-m3 \
--languages ar da de es fi fr he hu it ja km ko ms nl no pl pt ru sv th tr vi zh_cn zh_hk zh_tw \
--dense_search_result_save_dir ../dense_retrieval/search_results \
--sparse_search_result_save_dir ../sparse_retrieval/search_results \
--hybrid_result_save_dir ./search_results \
--top_k 1000 \
--dense_weight 1 --sparse_weight 0.3 \
--threads 32
# 3. Print and Save Evaluation Results
python step1-eval_hybrid_mkqa.py \
--model_name_or_path BAAI/bge-m3 \
--languages ar da de es fi fr he hu it ja km ko ms nl no pl pt ru sv th tr vi zh_cn zh_hk zh_tw \
--search_result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--eval_result_save_dir ./eval_results \
--metrics recall@20 recall@100 \
--threads 32 \
--pooling_method cls \
--normalize_embeddings True
```
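Conceptually, the hybrid step computes a weighted sum of the dense and sparse scores for each (query, passage) pair and re-sorts the candidates; the default `sparse_weight` of 0.3 down-weights the sparse scores relative to the dense ones. A minimal sketch of the fusion logic (the actual `step0-hybrid_search_results.py` reads and writes TREC-format run files and may differ in details, e.g. how it handles passages present in only one result list):

```python
from collections import defaultdict

def hybrid_search_results(dense, sparse, dense_weight=1.0, sparse_weight=0.3, top_k=1000):
    """dense/sparse: {qid: {docid: score}}; returns the fused top-k ranking per query."""
    fused = {}
    for qid in dense:
        scores = defaultdict(float)
        for docid, score in dense[qid].items():
            scores[docid] += dense_weight * score
        for docid, score in sparse.get(qid, {}).items():
            scores[docid] += sparse_weight * score
        fused[qid] = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
    return fused
```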
## MultiVector and All Rerank
If you want to perform **multi-vector reranking** or **all reranking** (combining dense, sparse and multi-vector scores) based on the search results of dense retrieval, follow these steps:
1. Install Java, Pyserini and Faiss (CPU version or GPU version):
```bash
# install java (Linux)
apt update
apt install openjdk-11-jdk
# install pyserini
pip install pyserini
# install faiss
## CPU version
conda install -c conda-forge faiss-cpu
## GPU version
conda install -c conda-forge faiss-gpu
```
2. Dense retrieval:
```bash
cd dense_retrieval
# 1. Generate Corpus Embedding
python step0-generate_embedding.py \
--encoder BAAI/bge-m3 \
--index_save_dir ./corpus-index \
--max_passage_length 512 \
--batch_size 256 \
--fp16 \
--add_instruction False \
--pooling_method cls \
--normalize_embeddings True
# 2. Search Results
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar da de es fi fr he hu it ja km ko ms nl no pl pt ru sv th tr vi zh_cn zh_hk zh_tw \
--index_save_dir ./corpus-index \
--result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--threads 16 \
--batch_size 32 \
--hits 1000 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False
# 3. Print and Save Evaluation Results
python step2-eval_dense_mkqa.py \
--encoder BAAI/bge-m3 \
--languages ar da de es fi fr he hu it ja km ko ms nl no pl pt ru sv th tr vi zh_cn zh_hk zh_tw \
--search_result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--eval_result_save_dir ./eval_results \
--metrics recall@20 recall@100 \
--threads 32 \
--pooling_method cls \
--normalize_embeddings True
```
3. Rerank search results with multi-vector scores or all scores:
```bash
cd multi_vector_rerank
# 1. Rerank Search Results
python step0-rerank_results.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar da de es fi fr he hu it ja km ko ms nl no pl pt ru sv th tr vi zh_cn zh_hk zh_tw \
--search_result_save_dir ../dense_retrieval/search_results \
--qa_data_dir ../qa_data \
--rerank_result_save_dir ./rerank_results \
--top_k 100 \
--batch_size 4 \
--max_length 512 \
--pooling_method cls \
--normalize_embeddings True \
--dense_weight 1 --sparse_weight 0.3 --colbert_weight 1 \
--num_shards 1 --shard_id 0 --cuda_id 0
# 2. Print and Save Evaluation Results
python step1-eval_rerank_mkqa.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar da de es fi fr he hu it ja km ko ms nl no pl pt ru sv th tr vi zh_cn zh_hk zh_tw \
--search_result_save_dir ./rerank_results \
--qa_data_dir ../qa_data \
--eval_result_save_dir ./eval_results \
--metrics recall@20 recall@100 \
--threads 32
```
>**Note**:
>
>- You should set `dense_weight`, `sparse_weight` and `colbert_weight` based on the downstream task scenario. If the dense method performs well while the sparse method does not, you can lower `sparse_weight` and increase `dense_weight` accordingly.
>
>- Based on our experience, dividing the sentence pairs to be reranked into several shards and computing scores for each shard on a single GPU tends to be more efficient than using multiple GPUs to compute scores for all sentence pairs directly. Therefore, if your machine has multiple GPUs, you can set `num_shards` to the number of GPUs and launch multiple terminals, running one command per shard (`shard_id` should equal `cuda_id`); a sketch follows below. Make sure that you have computed scores for all shards before proceeding to the second step.
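A hypothetical shard launcher (every path and value here is illustrative; `step0-rerank_results.py` and its flags are from step 3 above):

```python
import subprocess

NUM_GPUS = 2  # set num_shards to the number of available GPUs
procs = []
for shard_id in range(NUM_GPUS):
    procs.append(subprocess.Popen([
        "python", "step0-rerank_results.py",
        "--encoder", "BAAI/bge-m3", "--reranker", "BAAI/bge-m3",
        "--languages", "zh_cn",
        "--search_result_save_dir", "../dense_retrieval/search_results",
        "--qa_data_dir", "../qa_data",
        "--rerank_result_save_dir", "./rerank_results",
        "--num_shards", str(NUM_GPUS),
        "--shard_id", str(shard_id),   # shard_id must equal cuda_id
        "--cuda_id", str(shard_id),
    ]))
for p in procs:
    p.wait()  # all shards must finish before running step1-eval_rerank_mkqa.py
```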
4. (*Optional*) In the 3rd step, you get all three kinds of scores, saved to `rerank_result_save_dir/dense/{encoder}-{reranker}`, `rerank_result_save_dir/sparse/{encoder}-{reranker}` and `rerank_result_save_dir/colbert/{encoder}-{reranker}`. If you want to try other weights, you don't need to rerun the 3rd step. Instead, you can use [this script](./multi_vector_rerank/hybrid_all_results.py) to combine the three kinds of scores directly.
```bash
cd multi_vector_rerank
# 1. Hybrid All Search Results
python hybrid_all_results.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar da de es fi fr he hu it ja km ko ms nl no pl pt ru sv th tr vi zh_cn zh_hk zh_tw \
--dense_search_result_save_dir ./rerank_results/dense \
--sparse_search_result_save_dir ./rerank_results/sparse \
--colbert_search_result_save_dir ./rerank_results/colbert \
--hybrid_result_save_dir ./hybrid_search_results \
--top_k 200 \
--threads 32 \
--dense_weight 1 --sparse_weight 0.1 --colbert_weight 1
# 2. Print and Save Evaluation Results
python step1-eval_rerank_mkqa.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar da de es fi fr he hu it ja km ko ms nl no pl pt ru sv th tr vi zh_cn zh_hk zh_tw \
--search_result_save_dir ./hybrid_search_results \
--qa_data_dir ../qa_data \
--eval_result_save_dir ./eval_hybrid_results \
--metrics recall@20 recall@100 \
--threads 32
```
## BM25 Baseline
We provide two methods of evaluating the BM25 baseline:
1. Use the same tokenizer as [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3) (i.e., the tokenizer of [XLM-Roberta](https://huggingface.co/FacebookAI/xlm-roberta-large)); a sketch of the idea follows the commands below:
```bash
cd sparse_retrieval
# 1. Output Search Results with BM25
python bm25_baseline_same_tokenizer.py
# 2. Print and Save Evaluation Results
python step2-eval_sparse_mkqa.py \
--encoder bm25_same_tokenizer \
--languages ar da de es fi fr he hu it ja km ko ms nl no pl pt ru sv th tr vi zh_cn zh_hk zh_tw \
--search_result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--eval_result_save_dir ./eval_results \
--metrics recall@20 recall@100 \
--threads 32
```
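Method (1) gives BM25 the same vocabulary the neural model operates on. A hypothetical illustration of the pre-tokenization idea (the actual `bm25_baseline_same_tokenizer.py` may differ):

```python
from transformers import AutoTokenizer

# bge-m3 uses the XLM-RoBERTa tokenizer
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")

def to_bm25_text(text: str) -> str:
    # Space-joined subword tokens, so BM25 indexes the same units the model scores.
    return " ".join(tokenizer.tokenize(text.lower()))

print(to_bm25_text("How are neural retrievers evaluated?"))
```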
2. Use the language analyzer provided by [Anserini](https://github.com/castorini/anserini/blob/master/src/main/java/io/anserini/analysis/AnalyzerMap.java) ([Lucene Tokenizer](https://github.com/apache/lucene/tree/main/lucene/analysis/common/src/java/org/apache/lucene/analysis)):
```bash
cd sparse_retrieval
# 1. Output Search Results with BM25
python bm25_baseline.py
# 2. Print and Save Evaluation Results
python step2-eval_sparse_mkqa.py \
--encoder bm25 \
--languages ar da de es fi fr he hu it ja km ko ms nl no pl pt ru sv th tr vi zh_cn zh_hk zh_tw \
--search_result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--eval_result_save_dir ./eval_results \
--metrics recall@20 recall@100 \
--threads 32
```
"""
python step0-generate_embedding.py \
--encoder BAAI/bge-m3 \
--index_save_dir ./corpus-index \
--max_passage_length 512 \
--batch_size 256 \
--fp16 \
--pooling_method cls \
--normalize_embeddings True
"""
import os
import sys
import faiss
import datasets
import numpy as np
from tqdm import tqdm
from pprint import pprint
from FlagEmbedding import FlagModel
from dataclasses import dataclass, field
from transformers import HfArgumentParser
sys.path.append("..")
from utils.normalize_text import normalize
@dataclass
class ModelArgs:
encoder: str = field(
default="BAAI/bge-m3",
metadata={'help': 'Name or path of encoder'}
)
fp16: bool = field(
default=True,
metadata={'help': 'Use fp16 in inference?'}
)
pooling_method: str = field(
default='cls',
metadata={'help': "Pooling method. Avaliable methods: 'cls', 'mean'"}
)
normalize_embeddings: bool = field(
default=True,
metadata={'help': "Normalize embeddings or not"}
)
@dataclass
class EvalArgs:
index_save_dir: str = field(
default='./corpus-index',
metadata={'help': 'Dir to save index. Corpus index will be saved to `index_save_dir/{encoder_name}/index`. Corpus ids will be saved to `index_save_dir/{encoder_name}/docid` .'}
)
max_passage_length: int = field(
default=512,
metadata={'help': 'Max passage length.'}
)
batch_size: int = field(
default=256,
metadata={'help': 'Inference batch size.'}
)
overwrite: bool = field(
default=False,
metadata={'help': 'Whether to overwrite embedding'}
)
def get_model(model_args: ModelArgs):
model = FlagModel(
model_args.encoder,
pooling_method=model_args.pooling_method,
normalize_embeddings=model_args.normalize_embeddings,
use_fp16=model_args.fp16
)
return model
def parse_corpus(corpus: datasets.Dataset):
corpus_list = []
for data in tqdm(corpus, desc="Generating corpus"):
_id = str(data['_id'])
content = f"{data['title']}\n{data['text']}".lower()
content = normalize(content)
corpus_list.append({"id": _id, "content": content})
corpus = datasets.Dataset.from_list(corpus_list)
return corpus
def generate_index(model: FlagModel, corpus: datasets.Dataset, max_passage_length: int=512, batch_size: int=256):
corpus_embeddings = model.encode_corpus(corpus["content"], batch_size=batch_size, max_length=max_passage_length)
dim = corpus_embeddings.shape[-1]
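# Flat index = exhaustive inner-product search; with L2-normalized embeddings this
# equals cosine similarity. train() below is a no-op for a flat index.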
faiss_index = faiss.index_factory(dim, "Flat", faiss.METRIC_INNER_PRODUCT)
corpus_embeddings = corpus_embeddings.astype(np.float32)
faiss_index.train(corpus_embeddings)
faiss_index.add(corpus_embeddings)
return faiss_index, list(corpus["id"])
def save_result(index: faiss.Index, docid: list, index_save_dir: str):
docid_save_path = os.path.join(index_save_dir, 'docid')
index_save_path = os.path.join(index_save_dir, 'index')
with open(docid_save_path, 'w', encoding='utf-8') as f:
for _id in docid:
f.write(str(_id) + '\n')
faiss.write_index(index, index_save_path)
def main():
parser = HfArgumentParser([ModelArgs, EvalArgs])
model_args, eval_args = parser.parse_args_into_dataclasses()
model_args: ModelArgs
eval_args: EvalArgs
if model_args.encoder[-1] == '/':
model_args.encoder = model_args.encoder[:-1]
model = get_model(model_args=model_args)
encoder = model_args.encoder
if os.path.basename(encoder).startswith('checkpoint-'):
encoder = os.path.dirname(encoder) + '_' + os.path.basename(encoder)
print("==================================================")
print("Start generating embedding with model:")
print(model_args.encoder)
index_save_dir = os.path.join(eval_args.index_save_dir, os.path.basename(encoder))
if not os.path.exists(index_save_dir):
os.makedirs(index_save_dir)
if os.path.exists(os.path.join(index_save_dir, 'index')) and not eval_args.overwrite:
print('Embedding already exists. Skip...')
return
corpus = datasets.load_dataset("BeIR/nq", 'corpus')['corpus']
corpus = parse_corpus(corpus=corpus)
index, docid = generate_index(
model=model,
corpus=corpus,
max_passage_length=eval_args.max_passage_length,
batch_size=eval_args.batch_size
)
save_result(index, docid, index_save_dir)
print("==================================================")
print("Finish generating embeddings with following model:")
pprint(model_args.encoder)
if __name__ == "__main__":
main()
"""
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--index_save_dir ./corpus-index \
--result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--threads 16 \
--batch_size 32 \
--hits 1000 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False
"""
import os
import sys
import torch
import datasets
from tqdm import tqdm
from pprint import pprint
from dataclasses import dataclass, field
from transformers import HfArgumentParser, is_torch_npu_available
from pyserini.search.faiss import FaissSearcher, AutoQueryEncoder
from pyserini.output_writer import get_output_writer, OutputFormat
@dataclass
class ModelArgs:
encoder: str = field(
default="BAAI/bge-m3",
metadata={'help': 'Name or path of encoder'}
)
add_instruction: bool = field(
default=False,
metadata={'help': 'Add query-side instruction?'}
)
query_instruction_for_retrieval: str = field(
default=None,
metadata={'help': 'query instruction for retrieval'}
)
pooling_method: str = field(
default='cls',
metadata={'help': "Pooling method. Avaliable methods: 'cls', 'mean'"}
)
normalize_embeddings: bool = field(
default=True,
metadata={'help': "Normalize embeddings or not"}
)
@dataclass
class EvalArgs:
languages: str = field(
default="en",
metadata={'help': 'Languages to evaluate. Available languages: en ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw',
"nargs": "+"}
)
index_save_dir: str = field(
default='./corpus-index',
metadata={'help': 'Dir to index and docid. Corpus index path is `index_save_dir/{encoder_name}/index`. Corpus ids path is `index_save_dir/{encoder_name}/docid` .'}
)
result_save_dir: str = field(
default='./search_results',
metadata={'help': 'Dir to save search results. Search results will be saved to `result_save_dir/{encoder_name}/{lang}.txt`'}
)
qa_data_dir: str = field(
default='../qa_data',
metadata={'help': 'Dir to qa data.'}
)
threads: int = field(
default=1,
metadata={'help': 'Maximum threads to use during search'}
)
batch_size: int = field(
default=32,
metadata={'help': 'Search batch size.'}
)
hits: int = field(
default=1000,
metadata={'help': 'Number of hits'}
)
overwrite: bool = field(
default=False,
metadata={'help': 'Whether to overwrite existing search results'}
)
def get_query_encoder(model_args: ModelArgs):
if torch.cuda.is_available():
device = torch.device("cuda")
elif is_torch_npu_available():
device = torch.device("npu")
else:
device = torch.device("cpu")
model = AutoQueryEncoder(
encoder_dir=model_args.encoder,
device=device,
pooling=model_args.pooling_method,
l2_norm=model_args.normalize_embeddings
)
return model
def check_languages(languages):
if isinstance(languages, str):
languages = [languages]
avaliable_languages = ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
for lang in languages:
if lang not in avaliable_languages:
raise ValueError(f"Language `{lang}` is not supported. Avaliable languages: {avaliable_languages}")
return languages
def get_queries_and_qids(qa_data_dir: str, lang: str, add_instruction: bool=False, query_instruction_for_retrieval: str=None):
topics_path = os.path.join(qa_data_dir, f"{lang}.jsonl")
if not os.path.exists(topics_path):
raise FileNotFoundError(f"{topics_path} not found")
dataset = datasets.load_dataset('json', data_files=topics_path)['train']
queries = []
qids = []
for data in dataset:
qids.append(str(data['id']))
queries.append(str(data['question']))
if add_instruction and query_instruction_for_retrieval is not None:
queries = [f"{query_instruction_for_retrieval}{query}" for query in queries]
return queries, qids
def save_result(search_results, result_save_path: str, qids: list, max_hits: int):
output_writer = get_output_writer(result_save_path, OutputFormat(OutputFormat.TREC.value), 'w',
max_hits=max_hits, tag='Faiss', topics=qids,
use_max_passage=False,
max_passage_delimiter='#',
max_passage_hits=1000)
with output_writer:
for topic, hits in search_results:
output_writer.write(topic, hits)
def main():
parser = HfArgumentParser([ModelArgs, EvalArgs])
model_args, eval_args = parser.parse_args_into_dataclasses()
model_args: ModelArgs
eval_args: EvalArgs
languages = check_languages(eval_args.languages)
if model_args.encoder[-1] == '/':
model_args.encoder = model_args.encoder[:-1]
query_encoder = get_query_encoder(model_args=model_args)
encoder = model_args.encoder
if os.path.basename(encoder).startswith('checkpoint-'):
encoder = os.path.dirname(encoder) + '_' + os.path.basename(encoder)
index_save_dir = os.path.join(eval_args.index_save_dir, os.path.basename(encoder))
if not os.path.exists(index_save_dir):
raise FileNotFoundError(f"{index_save_dir} not found")
searcher = FaissSearcher(
index_dir=index_save_dir,
query_encoder=query_encoder
)
print("==================================================")
print("Start generating search results with model:", encoder)
print('Generate search results of the following languages: ', languages)
for lang in languages:
print("**************************************************")
print(f"Start searching results of {lang} ...")
result_save_path = os.path.join(eval_args.result_save_dir, os.path.basename(encoder), f"{lang}.txt")
if not os.path.exists(os.path.dirname(result_save_path)):
os.makedirs(os.path.dirname(result_save_path))
if os.path.exists(result_save_path) and not eval_args.overwrite:
print(f'Search results of {lang} already exist. Skip...')
continue
queries, qids = get_queries_and_qids(eval_args.qa_data_dir, lang=lang, add_instruction=model_args.add_instruction, query_instruction_for_retrieval=model_args.query_instruction_for_retrieval)
search_results = []
for start_idx in tqdm(range(0, len(queries), eval_args.batch_size), desc="Searching"):
batch_queries = queries[start_idx : start_idx+eval_args.batch_size]
batch_qids = qids[start_idx : start_idx+eval_args.batch_size]
batch_search_results = searcher.batch_search(
queries=batch_queries,
q_ids=batch_qids,
k=eval_args.hits,
threads=eval_args.threads
)
search_results.extend([(_id, batch_search_results[_id]) for _id in batch_qids])
save_result(
search_results=search_results,
result_save_path=result_save_path,
qids=qids,
max_hits=eval_args.hits
)
print("==================================================")
print("Finish generating search results with following model:")
pprint(model_args.encoder)
if __name__ == "__main__":
main()
"""
# 1. Generate Corpus Embedding
python step0-generate_embedding.py \
--encoder BAAI/bge-m3 \
--index_save_dir ./corpus-index \
--max_passage_length 512 \
--batch_size 256 \
--fp16 \
--add_instruction False \
--pooling_method cls \
--normalize_embeddings True
# 2. Search Results
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--index_save_dir ./corpus-index \
--result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--threads 16 \
--batch_size 32 \
--hits 1000 \
--pooling_method cls \
--normalize_embeddings True \
--add_instruction False
# 3. Print and Save Evaluation Results
python step2-eval_dense_mkqa.py \
--encoder BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--search_result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--eval_result_save_dir ./eval_results \
--metrics recall@20 recall@100 \
--threads 32 \
--pooling_method cls \
--normalize_embeddings True
"""
import os
import sys
import json
import datasets
import numpy as np
import pandas as pd
from tqdm import tqdm
import multiprocessing
from pprint import pprint
from dataclasses import dataclass, field
from transformers import HfArgumentParser
sys.path.append("..")
from utils.normalize_text import normalize
from utils.evaluation import evaluate_recall_qa
@dataclass
class EvalArgs:
languages: str = field(
default="en",
metadata={'help': 'Languages to evaluate. Available languages: en ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw',
"nargs": "+"}
)
encoder: str = field(
default='BAAI/bge-m3',
metadata={'help': 'Name or path of encoder'}
)
pooling_method: str = field(
default='cls',
metadata={'help': "Pooling method. Avaliable methods: 'cls', 'mean'"}
)
normalize_embeddings: bool = field(
default=True,
metadata={'help': "Normalize embeddings or not"}
)
search_result_save_dir: str = field(
default='./output_results',
metadata={'help': 'Dir containing search results. Search results path is `search_result_save_dir/{encoder}/{lang}.txt`'}
)
qa_data_dir: str = field(
default='../qa_data',
metadata={'help': 'Dir to qa data.'}
)
metrics: str = field(
default="recall@20",
metadata={'help': 'Metrics to evaluate. Available metrics: recall@k',
"nargs": "+"}
)
eval_result_save_dir: str = field(
default='./eval_results',
metadata={'help': 'Dir to save evaluation results. Evaluation results will be saved to `eval_result_save_dir/{encoder}.json`'}
)
threads: int = field(
default=1,
metadata={"help": "num of evaluation threads. <= 1 means single thread"}
)
def check_languages(languages):
if isinstance(languages, str):
languages = [languages]
avaliable_languages = ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
for lang in languages:
if lang not in avaliable_languages:
raise ValueError(f"Language `{lang}` is not supported. Available languages: {avaliable_languages}")
return languages
def compute_average(results: dict):
average_results = {}
for _, result in results.items():
for metric, score in result.items():
if metric not in average_results:
average_results[metric] = []
average_results[metric].append(score)
for metric, scores in average_results.items():
average_results[metric] = np.mean(scores)
return average_results
def save_results(model_name: str, pooling_method: str, normalize_embeddings: bool, results: dict, save_path: str, eval_languages: list):
try:
results['average'] = compute_average(results)
except Exception:
results['average'] = None
pprint(results)
if not os.path.exists(os.path.dirname(save_path)):
os.makedirs(os.path.dirname(save_path))
results_dict = {
'model': model_name,
'pooling_method': pooling_method,
'normalize_embeddings': normalize_embeddings,
'results': results
}
with open(save_path, 'w', encoding='utf-8') as f:
json.dump(results_dict, f, indent=4, ensure_ascii=False)
print(f'Results of evaluating `{model_name}` on `{eval_languages}` saved at `{save_path}`')
def get_corpus_dict():
corpus_dict = {}
corpus = datasets.load_dataset('BeIR/nq', 'corpus')['corpus']
for data in tqdm(corpus, desc="Loading corpus"):
_id = str(data['_id'])
content = f"{data['title']}\n{data['text']}".lower()
content = normalize(content)
corpus_dict[_id] = content
return corpus_dict
def get_qa_dict(qa_path: str):
qa_dict = {}
dataset = datasets.load_dataset('json', data_files=qa_path)['train']
for data in dataset:
qid = str(data['id'])
answers = data['answers']
qa_dict[qid] = answers
return qa_dict
def get_search_result_dict(search_result_path: str, top_k: int=100):
search_result_dict = {}
skip = True
# Each row of the TREC-format result file is: qid Q0 docid rank score tag
for _, row in tqdm(pd.read_csv(search_result_path, sep=' ', header=None).iterrows(), desc="Loading search results"):
qid = str(row.iloc[0])
docid = str(row.iloc[2])
rank = int(row.iloc[3])
if qid not in search_result_dict:
search_result_dict[qid] = []
skip = False
# Rows are rank-ordered within each query, so once rank exceeds top_k
# the remaining rows of the current query can be skipped.
if rank > top_k:
skip = True
if not skip:
search_result_dict[qid].append(docid)
return search_result_dict
def evaluate(corpus_dict: dict, qa_dict: dict, search_result_path: str, metrics: list):
top_k = max([int(metric.split('@')[-1]) for metric in metrics])
search_result_dict = get_search_result_dict(search_result_path, top_k=int(top_k))
search_results = []
ground_truths = []
for qid, docid_list in tqdm(search_result_dict.items(), desc="Preparing to evaluate"):
answers = qa_dict[qid]
doc_list = [corpus_dict[docid] for docid in docid_list]
search_results.append(doc_list)
ground_truths.append(answers)
results = {}
metrics = sorted([metric.lower() for metric in metrics])
for metric in metrics:
metric, k = metric.split('@')
k = int(k)
assert metric in ['recall'], f"Metric `{metric}` is not supported."
if metric == 'recall':
results[f'Recall@{k}'] = evaluate_recall_qa(search_results, ground_truths, k=k)
return results
def main():
parser = HfArgumentParser([EvalArgs])
eval_args = parser.parse_args_into_dataclasses()[0]
eval_args: EvalArgs
corpus_dict = get_corpus_dict()
languages = check_languages(eval_args.languages)
if eval_args.encoder[-1] == '/':
eval_args.encoder = eval_args.encoder[:-1]
if 'checkpoint-' in os.path.basename(eval_args.encoder):
eval_args.encoder = os.path.dirname(eval_args.encoder) + '_' + os.path.basename(eval_args.encoder)
results = {}
if eval_args.threads > 1:
threads = min(len(languages), eval_args.threads)
pool = multiprocessing.Pool(processes=threads)
results_list = []
for lang in languages:
print("*****************************")
print(f"Start evaluating {lang} ...")
qa_path = os.path.join(eval_args.qa_data_dir, f"{lang}.jsonl")
qa_dict = get_qa_dict(qa_path)
search_result_save_dir = os.path.join(eval_args.search_result_save_dir, os.path.basename(eval_args.encoder))
search_result_path = os.path.join(search_result_save_dir, f"{lang}.txt")
results_list.append(pool.apply_async(evaluate, args=(corpus_dict, qa_dict, search_result_path, eval_args.metrics)))
pool.close()
pool.join()
for i, lang in enumerate(languages):
results[lang] = results_list[i].get()
else:
for lang in languages:
print("*****************************")
print(f"Start evaluating {lang} ...")
qa_path = os.path.join(eval_args.qa_data_dir, f"{lang}.jsonl")
qa_dict = get_qa_dict(qa_path)
search_result_save_dir = os.path.join(eval_args.search_result_save_dir, os.path.basename(eval_args.encoder))
search_result_path = os.path.join(search_result_save_dir, f"{lang}.txt")
result = evaluate(corpus_dict, qa_dict, search_result_path, eval_args.metrics)
results[lang] = result
save_results(
model_name=eval_args.encoder,
pooling_method=eval_args.pooling_method,
normalize_embeddings=eval_args.normalize_embeddings,
results=results,
save_path=os.path.join(eval_args.eval_result_save_dir, f"{os.path.basename(eval_args.encoder)}.json"),
eval_languages=languages
)
print("==================================================")
print("Finish generating evaluation results with following model:")
print(eval_args.encoder)
if __name__ == "__main__":
main()
"""
python step0-hybrid_search_results.py \
--model_name_or_path BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--dense_search_result_save_dir ../dense_retrieval/search_results \
--sparse_search_result_save_dir ../sparse_retrieval/search_results \
--hybrid_result_save_dir ./search_results \
--top_k 1000 \
--dense_weight 1 --sparse_weight 0.3 \
--threads 32
"""
import os
import pandas as pd
from tqdm import tqdm
import multiprocessing
from dataclasses import dataclass, field
from transformers import HfArgumentParser
@dataclass
class EvalArgs:
model_name_or_path: str = field(
default='BAAI/bge-m3',
metadata={'help': 'Name or path of model'}
)
languages: str = field(
default="en",
metadata={'help': 'Languages to evaluate. Available languages: en ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw',
"nargs": "+"}
)
top_k: int = field(
default=1000,
metadata={'help': 'Number of top-ranked results to load from each source and keep after hybridization'}
)
dense_weight: float = field(
default=1,
metadata={'help': 'Hybrid weight of dense score'}
)
sparse_weight: float = field(
default=0.3,
metadata={'help': 'Hybrid weight of sparse score'}
)
dense_search_result_save_dir: str = field(
default='../dense_retrieval/search_results',
metadata={'help': 'Dir for saving dense search results. Search results path is `dense_search_result_save_dir/{model_name_or_path}/{lang}.txt`'}
)
sparse_search_result_save_dir: str = field(
default='../sparse_retrieval/search_results',
metadata={'help': 'Dir for saving sparse search results. Search results path is `sparse_search_result_save_dir/{model_name_or_path}/{lang}.txt`'}
)
hybrid_result_save_dir: str = field(
default='./search_results',
metadata={'help': 'Dir for saving hybrid search results. Hybrid results will be saved to `hybrid_result_save_dir/{model_name_or_path}/{lang}.txt`'}
)
threads: int = field(
default=1,
metadata={"help": "Number of worker processes; <= 1 means single-process"}
)
def check_languages(languages):
if isinstance(languages, str):
languages = [languages]
available_languages = ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
for lang in languages:
if lang not in available_languages:
raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
return languages
def get_search_result_dict(search_result_path: str, top_k: int=1000):
search_result_dict = {}
flag = True
for _, row in pd.read_csv(search_result_path, sep=' ', header=None).iterrows():
qid = str(row.iloc[0])
docid = row.iloc[2]
rank = int(row.iloc[3])
score = float(row.iloc[4])
if qid not in search_result_dict:
search_result_dict[qid] = []
flag = False
if rank > top_k:
flag = True
if flag:
continue
else:
search_result_dict[qid].append((docid, score))
return search_result_dict
def get_queries_dict(queries_path: str):
queries_dict = {}
for _, row in pd.read_csv(queries_path, sep='\t', header=None).iterrows():
qid = str(row.iloc[0])
query = row.iloc[1]
queries_dict[qid] = query
return queries_dict
def save_hybrid_results(sparse_search_result_path: str, dense_search_result_path: str, hybrid_result_save_path: str, top_k: int=1000, dense_weight: float=0.2, sparse_weight: float=0.5):
sparse_search_result_dict = get_search_result_dict(sparse_search_result_path, top_k=top_k)
dense_search_result_dict = get_search_result_dict(dense_search_result_path, top_k=top_k)
if not os.path.exists(os.path.dirname(hybrid_result_save_path)):
os.makedirs(os.path.dirname(hybrid_result_save_path))
qid_list = list(set(sparse_search_result_dict.keys()) | set(dense_search_result_dict.keys()))
hybrid_results_list = []
for qid in tqdm(qid_list, desc="Hybriding dense and sparse scores"):
results = {}
if qid in sparse_search_result_dict:
for docid, score in sparse_search_result_dict[qid]:
score = score / 10000.  # saved sparse scores were scaled up by 10000; restore the raw value
results[docid] = score * sparse_weight
if qid in dense_search_result_dict:
for docid, score in dense_search_result_dict[qid]:
if docid in results:
results[docid] = results[docid] + score * dense_weight
else:
results[docid] = score * dense_weight
hybrid_results = [(docid, score) for docid, score in results.items()]
hybrid_results.sort(key=lambda x: x[1], reverse=True)
hybrid_results_list.append(hybrid_results[:top_k])
with open(hybrid_result_save_path, 'w', encoding='utf-8') as f:
for qid, hybrid_results in tqdm(zip(qid_list, hybrid_results_list), desc="Saving hybrid search results"):
for rank, docid_score in enumerate(hybrid_results):
docid, score = docid_score
line = f"{qid} Q0 {docid} {rank+1} {score:.6f} Faiss-&-Anserini"
f.write(line + '\n')
def main():
parser = HfArgumentParser([EvalArgs])
eval_args = parser.parse_args_into_dataclasses()[0]
eval_args: EvalArgs
languages = check_languages(eval_args.languages)
if os.path.basename(eval_args.model_name_or_path).startswith('checkpoint-'):
eval_args.model_name_or_path = os.path.dirname(eval_args.model_name_or_path) + '_' + os.path.basename(eval_args.model_name_or_path)
if eval_args.threads > 1:
threads = min(len(languages), eval_args.threads)
pool = multiprocessing.Pool(processes=threads)
for lang in languages:
print("**************************************************")
print(f"Start hybrid search results of {lang} ...")
hybrid_result_save_path = os.path.join(eval_args.hybrid_result_save_dir, f"{os.path.basename(eval_args.model_name_or_path)}", f"{lang}.txt")
sparse_search_result_save_dir = os.path.join(eval_args.sparse_search_result_save_dir, os.path.basename(eval_args.model_name_or_path))
sparse_search_result_path = os.path.join(sparse_search_result_save_dir, f"{lang}.txt")
dense_search_result_save_dir = os.path.join(eval_args.dense_search_result_save_dir, os.path.basename(eval_args.model_name_or_path))
dense_search_result_path = os.path.join(dense_search_result_save_dir, f"{lang}.txt")
pool.apply_async(save_hybrid_results, args=(
sparse_search_result_path,
dense_search_result_path,
hybrid_result_save_path,
eval_args.top_k,
eval_args.dense_weight,
eval_args.sparse_weight
))
pool.close()
pool.join()
else:
for lang in languages:
print("**************************************************")
print(f"Start hybrid search results of {lang} ...")
hybrid_result_save_path = os.path.join(eval_args.hybrid_result_save_dir, f"{os.path.basename(eval_args.model_name_or_path)}", f"{lang}.txt")
sparse_search_result_save_dir = os.path.join(eval_args.sparse_search_result_save_dir, os.path.basename(eval_args.model_name_or_path))
sparse_search_result_path = os.path.join(sparse_search_result_save_dir, f"{lang}.txt")
dense_search_result_save_dir = os.path.join(eval_args.dense_search_result_save_dir, os.path.basename(eval_args.model_name_or_path))
dense_search_result_path = os.path.join(dense_search_result_save_dir, f"{lang}.txt")
save_hybrid_results(
sparse_search_result_path=sparse_search_result_path,
dense_search_result_path=dense_search_result_path,
hybrid_result_save_path=hybrid_result_save_path,
top_k=eval_args.top_k,
dense_weight=eval_args.dense_weight,
sparse_weight=eval_args.sparse_weight
)
print("==================================================")
print("Finish generating reranked results with following model:", eval_args.model_name_or_path)
if __name__ == "__main__":
main()
"""
# Ref: https://github.com/texttron/tevatron/tree/main/examples/unicoil
# 1. Search Dense and Sparse Results
Dense Retrieval: see `../dense_retrieval`
Sparse Retrieval: see `../sparse_retrieval`
# 2. Hybrid Dense and Sparse Search Results
python step0-hybrid_search_results.py \
--model_name_or_path BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--dense_search_result_save_dir ../dense_retrieval/search_results \
--sparse_search_result_save_dir ../sparse_retrieval/search_results \
--hybrid_result_save_dir ./search_results \
--top_k 1000 \
--dense_weight 1 --sparse_weight 0.3 \
--threads 32
# 3. Print and Save Evaluation Results
python step1-eval_hybrid_mkqa.py \
--model_name_or_path BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--search_result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--eval_result_save_dir ./eval_results \
--metrics recall@20 recall@100 \
--threads 32 \
--pooling_method cls \
--normalize_embeddings True
"""
import os
import sys
import json
import datasets
import numpy as np
import pandas as pd
from tqdm import tqdm
import multiprocessing
from pprint import pprint
from dataclasses import dataclass, field
from transformers import HfArgumentParser
sys.path.append("..")
from utils.normalize_text import normalize
from utils.evaluation import evaluate_recall_qa
@dataclass
class EvalArgs:
languages: str = field(
default="en",
metadata={'help': 'Languages to evaluate. Available languages: en ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw',
"nargs": "+"}
)
model_name_or_path: str = field(
default='BAAI/bge-m3',
metadata={'help': 'Name or path of model'}
)
pooling_method: str = field(
default='cls',
metadata={'help': "Pooling method. Avaliable methods: 'cls', 'mean'"}
)
normalize_embeddings: bool = field(
default=True,
metadata={'help': "Normalize embeddings or not"}
)
search_result_save_dir: str = field(
default='./output_results',
metadata={'help': 'Dir for saving search results. Search results path is `result_save_dir/{model_name_or_path}/{lang}.txt`'}
)
qa_data_dir: str = field(
default='../qa_data',
metadata={'help': 'Dir containing qa data.'}
)
metrics: str = field(
default="recall@20",
metadata={'help': 'Metrics to evaluate. Available metrics: recall@k',
"nargs": "+"}
)
eval_result_save_dir: str = field(
default='./eval_results',
metadata={'help': 'Dir for saving evaluation results. Evaluation results will be saved to `eval_result_save_dir/{model_name_or_path}.json`'}
)
threads: int = field(
default=1,
metadata={"help": "num of evaluation threads. <= 1 means single thread"}
)
def check_languages(languages):
if isinstance(languages, str):
languages = [languages]
available_languages = ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
for lang in languages:
if lang not in available_languages:
raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
return languages
def compute_average(results: dict):
average_results = {}
for _, result in results.items():
for metric, score in result.items():
if metric not in average_results:
average_results[metric] = []
average_results[metric].append(score)
for metric, scores in average_results.items():
average_results[metric] = np.mean(scores)
return average_results
def save_results(model_name: str, pooling_method: str, normalize_embeddings: bool, results: dict, save_path: str, eval_languages: list):
try:
results['average'] = compute_average(results)
except Exception:
results['average'] = None
pprint(results)
if not os.path.exists(os.path.dirname(save_path)):
os.makedirs(os.path.dirname(save_path))
results_dict = {
'model': model_name,
'pooling_method': pooling_method,
'normalize_embeddings': normalize_embeddings,
'results': results
}
with open(save_path, 'w', encoding='utf-8') as f:
json.dump(results_dict, f, indent=4, ensure_ascii=False)
print(f'Results of evaluating `{model_name}` on `{eval_languages}` saved at `{save_path}`')
def get_corpus_dict():
corpus_dict = {}
corpus = datasets.load_dataset('BeIR/nq', 'corpus')['corpus']
for data in tqdm(corpus, desc="Loading corpus"):
_id = str(data['_id'])
content = f"{data['title']}\n{data['text']}".lower()
content = normalize(content)
corpus_dict[_id] = content
return corpus_dict
def get_qa_dict(qa_path: str):
qa_dict = {}
dataset = datasets.load_dataset('json', data_files=qa_path)['train']
for data in dataset:
qid = str(data['id'])
answers = data['answers']
qa_dict[qid] = answers
return qa_dict
def get_search_result_dict(search_result_path: str, top_k: int=100):
search_result_dict = {}
flag = True
for _, row in pd.read_csv(search_result_path, sep=' ', header=None).iterrows():
qid = str(row.iloc[0])
docid = str(row.iloc[2])
rank = int(row.iloc[3])
if qid not in search_result_dict:
search_result_dict[qid] = []
flag = False
if rank > top_k:
flag = True
if flag:
continue
else:
search_result_dict[qid].append(docid)
return search_result_dict
def evaluate(corpus_dict: dict, qa_dict: dict, search_result_path: str, metrics: list):
top_k = max([int(metric.split('@')[-1]) for metric in metrics])
search_result_dict = get_search_result_dict(search_result_path, top_k=int(top_k))
search_results = []
ground_truths = []
for qid, docid_list in search_result_dict.items():
answers = qa_dict[qid]
doc_list = [corpus_dict[docid] for docid in docid_list]
search_results.append(doc_list)
ground_truths.append(answers)
results = {}
metrics = sorted([metric.lower() for metric in metrics])
for metric in metrics:
metric, k = metric.split('@')
k = int(k)
assert metric in ['recall'], f"Metric `{metric}` is not supported."
if metric == 'recall':
results[f'Recall@{k}'] = evaluate_recall_qa(search_results, ground_truths, k=k)
return results
def main():
parser = HfArgumentParser([EvalArgs])
eval_args = parser.parse_args_into_dataclasses()[0]
eval_args: EvalArgs
corpus_dict = get_corpus_dict()
languages = check_languages(eval_args.languages)
eval_args.model_name_or_path = eval_args.model_name_or_path.rstrip('/')
if os.path.basename(eval_args.model_name_or_path).startswith('checkpoint-'):
eval_args.model_name_or_path = os.path.dirname(eval_args.model_name_or_path) + '_' + os.path.basename(eval_args.model_name_or_path)
results = {}
if eval_args.threads > 1:
threads = min(len(languages), eval_args.threads)
pool = multiprocessing.Pool(processes=threads)
results_list = []
for lang in languages:
print("*****************************")
print(f"Start evaluating {lang} ...")
qa_path = os.path.join(eval_args.qa_data_dir, f"{lang}.jsonl")
qa_dict = get_qa_dict(qa_path)
search_result_save_dir = os.path.join(eval_args.search_result_save_dir, os.path.basename(eval_args.model_name_or_path))
search_result_path = os.path.join(search_result_save_dir, f"{lang}.txt")
results_list.append(pool.apply_async(evaluate, args=(corpus_dict, qa_dict, search_result_path, eval_args.metrics)))
pool.close()
pool.join()
for i, lang in enumerate(languages):
results[lang] = results_list[i].get()
else:
for lang in languages:
print("*****************************")
print(f"Start evaluating {lang} ...")
qa_path = os.path.join(eval_args.qa_data_dir, f"{lang}.jsonl")
qa_dict = get_qa_dict(qa_path)
search_result_save_dir = os.path.join(eval_args.search_result_save_dir, os.path.basename(eval_args.model_name_or_path))
search_result_path = os.path.join(search_result_save_dir, f"{lang}.txt")
result = evaluate(corpus_dict, qa_dict, search_result_path, eval_args.metrics)
results[lang] = result
save_results(
model_name=eval_args.model_name_or_path,
pooling_method=eval_args.pooling_method,
normalize_embeddings=eval_args.normalize_embeddings,
results=results,
save_path=os.path.join(eval_args.eval_result_save_dir, f"{os.path.basename(eval_args.model_name_or_path)}.json"),
eval_languages=languages
)
print("==================================================")
print("Finish generating evaluation results with following model:")
print(eval_args.model_name_or_path)
if __name__ == "__main__":
main()
"""
python hybrid_all_results.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--dense_search_result_save_dir ./rerank_results/dense \
--sparse_search_result_save_dir ./rerank_results/sparse \
--colbert_search_result_save_dir ./rerank_results/colbert \
--hybrid_result_save_dir ./hybrid_search_results \
--top_k 200 \
--threads 32 \
--dense_weight 1 --sparse_weight 0.1 --colbert_weight 1
"""
import os
import pandas as pd
from tqdm import tqdm
import multiprocessing
from dataclasses import dataclass, field
from transformers import HfArgumentParser
@dataclass
class EvalArgs:
encoder: str = field(
default='BAAI/bge-m3',
metadata={'help': 'Name or path of model'}
)
reranker: str = field(
default='BAAI/bge-m3',
metadata={'help': 'Name or path of reranker'}
)
languages: str = field(
default="en",
metadata={'help': 'Languages to evaluate. Available languages: en ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw',
"nargs": "+"}
)
top_k: int = field(
default=200,
metadata={'help': 'Number of top-ranked results to load from each source and keep after hybridization'}
)
dense_weight: float = field(
default=1,
metadata={'help': 'Hybrid weight of dense score'}
)
sparse_weight: float = field(
default=0.3,
metadata={'help': 'Hybrid weight of sparse score'}
)
colbert_weight: float = field(
default=1,
metadata={'help': 'Hybrid weight of colbert score'}
)
dense_search_result_save_dir: str = field(
default='../rerank/unify_rerank_results/dense',
metadata={'help': 'Dir for saving dense search results. Search results path is `dense_search_result_save_dir/{encoder}-{reranker}/{lang}.txt`'}
)
sparse_search_result_save_dir: str = field(
default='../rerank/unify_rerank_results/sparse',
metadata={'help': 'Dir for saving sparse search results. Search results path is `sparse_search_result_save_dir/{encoder}-{reranker}/{lang}.txt`'}
)
colbert_search_result_save_dir: str = field(
default='../rerank/unify_rerank_results/colbert',
metadata={'help': 'Dir for saving colbert search results. Search results path is `colbert_search_result_save_dir/{encoder}-{reranker}/{lang}.txt`'}
)
hybrid_result_save_dir: str = field(
default='./search_results',
metadata={'help': 'Dir for saving hybrid search results. Hybrid results will be saved to `hybrid_result_save_dir/{encoder}-{reranker}/{lang}.txt`'}
)
threads: int = field(
default=1,
metadata={"help": "Number of worker processes; <= 1 means single-process"}
)
def check_languages(languages):
if isinstance(languages, str):
languages = [languages]
available_languages = ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
for lang in languages:
if lang not in available_languages:
raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
return languages
def get_search_result_dict(search_result_path: str, top_k: int=1000):
search_result_dict = {}
flag = True
for _, row in pd.read_csv(search_result_path, sep=' ', header=None).iterrows():
qid = str(row.iloc[0])
docid = row.iloc[2]
rank = int(row.iloc[3])
score = float(row.iloc[4])
if qid not in search_result_dict:
search_result_dict[qid] = []
flag = False
if rank > top_k:
flag = True
if flag:
continue
else:
search_result_dict[qid].append((docid, score))
return search_result_dict
def get_queries_dict(queries_path: str):
queries_dict = {}
for _, row in pd.read_csv(queries_path, sep='\t', header=None).iterrows():
qid = str(row.iloc[0])
query = row.iloc[1]
queries_dict[qid] = query
return queries_dict
def save_hybrid_results(sparse_search_result_dict: dict, dense_search_result_dict: dict, colbert_search_result_dict: dict, hybrid_result_save_path: str, top_k: int=1000, dense_weight: float=1, sparse_weight: float=0.3, colbert_weight: float=1):
if not os.path.exists(os.path.dirname(hybrid_result_save_path)):
os.makedirs(os.path.dirname(hybrid_result_save_path))
qid_list = list(set(sparse_search_result_dict.keys()) | set(dense_search_result_dict.keys()) | set(colbert_search_result_dict.keys()))
hybrid_results_list = []
for qid in tqdm(qid_list, desc="Hybriding dense, sparse and colbert scores"):
results = {}
if qid in sparse_search_result_dict:
for docid, score in sparse_search_result_dict[qid]:
score = score / 0.3  # the saved sparse scores carry a 0.3 weighting factor; divide to restore the raw value
results[docid] = score * sparse_weight
if qid in dense_search_result_dict:
for docid, score in dense_search_result_dict[qid]:
if docid in results:
results[docid] = results[docid] + score * dense_weight
else:
results[docid] = score * dense_weight
if qid in colbert_search_result_dict:
for docid, score in colbert_search_result_dict[qid]:
if docid in results:
results[docid] = results[docid] + score * colbert_weight
else:
results[docid] = score * colbert_weight
hybrid_results = [(docid, score) for docid, score in results.items()]
hybrid_results.sort(key=lambda x: x[1], reverse=True)
hybrid_results_list.append(hybrid_results[:top_k])
with open(hybrid_result_save_path, 'w', encoding='utf-8') as f:
for qid, hybrid_results in tqdm(zip(qid_list, hybrid_results_list), desc="Saving hybrid search results"):
for rank, docid_score in enumerate(hybrid_results):
docid, score = docid_score
line = f"{qid} Q0 {docid} {rank+1} {score:.6f} Faiss-&-Anserini"
f.write(line + '\n')
def main():
parser = HfArgumentParser([EvalArgs])
eval_args = parser.parse_args_into_dataclasses()[0]
eval_args: EvalArgs
languages = check_languages(eval_args.languages)
if os.path.basename(eval_args.encoder).startswith('checkpoint-'):
eval_args.encoder = os.path.dirname(eval_args.encoder) + '_' + os.path.basename(eval_args.encoder)
if os.path.basename(eval_args.reranker).startswith('checkpoint-'):
eval_args.reranker = os.path.dirname(eval_args.reranker) + '_' + os.path.basename(eval_args.reranker)
dir_name = f"{os.path.basename(eval_args.encoder)}-{os.path.basename(eval_args.reranker)}"
if eval_args.threads > 1:
threads = min(len(languages), eval_args.threads)
pool = multiprocessing.Pool(processes=threads)
for lang in languages:
print("**************************************************")
print(f"Start hybrid search results of {lang} ...")
hybrid_result_save_path = os.path.join(eval_args.hybrid_result_save_dir, dir_name, f"{lang}.txt")
sparse_search_result_save_dir = os.path.join(eval_args.sparse_search_result_save_dir, dir_name)
sparse_search_result_path = os.path.join(sparse_search_result_save_dir, f"{lang}.txt")
sparse_search_result_dict = get_search_result_dict(sparse_search_result_path, top_k=eval_args.top_k)
dense_search_result_save_dir = os.path.join(eval_args.dense_search_result_save_dir, dir_name)
dense_search_result_path = os.path.join(dense_search_result_save_dir, f"{lang}.txt")
dense_search_result_dict = get_search_result_dict(dense_search_result_path, top_k=eval_args.top_k)
colbert_search_result_save_dir = os.path.join(eval_args.colbert_search_result_save_dir, dir_name)
colbert_search_result_path = os.path.join(colbert_search_result_save_dir, f"{lang}.txt")
colbert_search_result_dict = get_search_result_dict(colbert_search_result_path, top_k=eval_args.top_k)
pool.apply_async(save_hybrid_results, args=(
sparse_search_result_dict,
dense_search_result_dict,
colbert_search_result_dict,
hybrid_result_save_path,
eval_args.top_k,
eval_args.dense_weight,
eval_args.sparse_weight,
eval_args.colbert_weight
))
pool.close()
pool.join()
else:
for lang in languages:
print("**************************************************")
print(f"Start hybrid search results of {lang} ...")
hybrid_result_save_path = os.path.join(eval_args.hybrid_result_save_dir, dir_name, f"{lang}.txt")
sparse_search_result_save_dir = os.path.join(eval_args.sparse_search_result_save_dir, dir_name)
sparse_search_result_path = os.path.join(sparse_search_result_save_dir, f"{lang}.txt")
sparse_search_result_dict = get_search_result_dict(sparse_search_result_path, top_k=eval_args.top_k)
dense_search_result_save_dir = os.path.join(eval_args.dense_search_result_save_dir, dir_name)
dense_search_result_path = os.path.join(dense_search_result_save_dir, f"{lang}.txt")
dense_search_result_dict = get_search_result_dict(dense_search_result_path, top_k=eval_args.top_k)
colbert_search_result_save_dir = os.path.join(eval_args.colbert_search_result_save_dir, dir_name)
colbert_search_result_path = os.path.join(colbert_search_result_save_dir, f"{lang}.txt")
colbert_search_result_dict = get_search_result_dict(colbert_search_result_path, top_k=eval_args.top_k)
save_hybrid_results(
sparse_search_result_dict=sparse_search_result_dict,
dense_search_result_dict=dense_search_result_dict,
colbert_search_result_dict=colbert_search_result_dict,
hybrid_result_save_path=hybrid_result_save_path,
top_k=eval_args.top_k,
dense_weight=eval_args.dense_weight,
sparse_weight=eval_args.sparse_weight,
colbert_weight=eval_args.colbert_weight
)
print("==================================================")
print("Finish generating reranked results with following model and reranker:")
print(eval_args.encoder)
print(eval_args.reranker)
if __name__ == "__main__":
main()
"""
python step0-rerank_results.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--search_result_save_dir ../dense_retrieval/search_results \
--qa_data_dir ../qa_data \
--rerank_result_save_dir ./rerank_results \
--top_k 100 \
--batch_size 4 \
--max_length 512 \
--pooling_method cls \
--normalize_embeddings True \
--dense_weight 1 --sparse_weight 0.3 --colbert_weight 1 \
--num_shards 1 --shard_id 0 --cuda_id 0
"""
import os
import sys
import copy
import datasets
import pandas as pd
from tqdm import tqdm
from FlagEmbedding import BGEM3FlagModel
from dataclasses import dataclass, field
from transformers import HfArgumentParser
sys.path.append("..")
from utils.normalize_text import normalize
@dataclass
class ModelArgs:
reranker: str = field(
default='BAAI/bge-m3',
metadata={'help': 'Name or path of reranker'}
)
fp16: bool = field(
default=False,
metadata={'help': 'Whether to use fp16 for inference'}
)
pooling_method: str = field(
default='cls',
metadata={'help': "Pooling method. Avaliable methods: 'cls', 'mean'"}
)
normalize_embeddings: bool = field(
default=True,
metadata={'help': "Normalize embeddings or not"}
)
@dataclass
class EvalArgs:
languages: str = field(
default="en",
metadata={'help': 'Languages to evaluate. Available languages: en ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw',
"nargs": "+"}
)
max_length: int = field(
default=512,
metadata={'help': 'Max text length.'}
)
batch_size: int = field(
default=256,
metadata={'help': 'Inference batch size.'}
)
top_k: int = field(
default=100,
metadata={'help': 'Use reranker to rerank top-k retrieval results'}
)
encoder: str = field(
default='BAAI/bge-m3',
metadata={'help': 'Name or path of encoder'}
)
search_result_save_dir: str = field(
default='./output_results',
metadata={'help': 'Dir for saving search results. Search results path is `result_save_dir/{encoder}/{lang}.txt`'}
)
qa_data_dir: str = field(
default='./qa_data',
metadata={'help': 'Dir containing qa data.'}
)
rerank_result_save_dir: str = field(
default='./rerank_results',
metadata={'help': 'Dir for saving reranked results. Reranked results will be saved to `rerank_result_save_dir/{encoder}-{reranker}/{lang}.txt`'}
)
num_shards: int = field(
default=1,
metadata={'help': "Number of shards"}
)
shard_id: int = field(
default=0,
metadata={'help': 'Shard id, starting from 0'}
)
cuda_id: int = field(
default=0,
metadata={'help': 'CUDA ID to use. -1 means only use CPU.'}
)
dense_weight: float = field(
default=1,
metadata={'help': 'The weight of dense score when combining all scores'}
)
sparse_weight: float = field(
default=0.3,
metadata={'help': 'The weight of sparse score when combining all scores'}
)
colbert_weight: float = field(
default=1,
metadata={'help': 'The weight of colbert score when combining all scores'}
)
def check_languages(languages):
if isinstance(languages, str):
languages = [languages]
available_languages = ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
for lang in languages:
if lang not in available_languages:
raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
return languages
def get_reranker(model_args: ModelArgs, device: str=None):
reranker = BGEM3FlagModel(
model_name_or_path=model_args.reranker,
pooling_method=model_args.pooling_method,
normalize_embeddings=model_args.normalize_embeddings,
device=device
)
return reranker
def get_search_result_dict(search_result_path: str, top_k: int=100):
search_result_dict = {}
flag = True
for _, row in pd.read_csv(search_result_path, sep=' ', header=None).iterrows():
qid = str(row.iloc[0])
docid = row.iloc[2]
rank = int(row.iloc[3])
if qid not in search_result_dict:
search_result_dict[qid] = []
flag = False
if rank > top_k:
flag = True
if flag:
continue
else:
search_result_dict[qid].append(docid)
return search_result_dict
def get_queries_dict(queries_path: str):
queries_dict = {}
dataset = datasets.load_dataset('json', data_files=queries_path)['train']
for data in dataset:
qid = str(data['id'])
query = data['question']
queries_dict[qid] = query
return queries_dict
def get_corpus_dict(corpus: datasets.Dataset):
corpus_dict = {}
for data in tqdm(corpus, desc="Loading corpus"):
_id = str(data['_id'])
content = f"{data['title']}\n{data['text']}".lower()
content = normalize(content)
corpus_dict[_id] = content
return corpus_dict
def save_rerank_results(queries_dict: dict, corpus_dict: dict, reranker: BGEM3FlagModel, search_result_dict: dict, rerank_result_save_path: dict, batch_size: int=256, max_length: int=512, dense_weight: float=1, sparse_weight: float=0.3, colbert_weight: float=1):
qid_list = []
sentence_pairs = []
for qid, docids in search_result_dict.items():
qid_list.append(qid)
query = queries_dict[qid]
for docid in docids:
passage = corpus_dict[docid]
sentence_pairs.append((query, passage))
scores_dict = reranker.compute_score(
sentence_pairs,
batch_size=batch_size,
max_query_length=max_length,
max_passage_length=max_length,
weights_for_different_modes=[dense_weight, sparse_weight, colbert_weight]
)
for sub_dir, _rerank_result_save_path in rerank_result_save_path.items():
if not os.path.exists(os.path.dirname(_rerank_result_save_path)):
os.makedirs(os.path.dirname(_rerank_result_save_path))
scores = scores_dict[sub_dir]
with open(_rerank_result_save_path, 'w', encoding='utf-8') as f:
i = 0
for qid in qid_list:
docids = search_result_dict[qid]
docids_scores = []
for j in range(len(docids)):
docids_scores.append((docids[j], scores[i + j]))
i += len(docids)
docids_scores.sort(key=lambda x: x[1], reverse=True)
for rank, docid_score in enumerate(docids_scores):
docid, score = docid_score
line = f"{qid} Q0 {docid} {rank+1} {score:.6f} Faiss"
f.write(line + '\n')
def get_shard(search_result_dict: dict, num_shards: int, shard_id: int):
if num_shards <= 1:
return search_result_dict
keys_list = sorted(list(search_result_dict.keys()))
shard_len = len(keys_list) // num_shards
if shard_id == num_shards - 1:
shard_keys_list = keys_list[shard_id*shard_len:]
else:
shard_keys_list = keys_list[shard_id*shard_len : (shard_id + 1)*shard_len]
shard_search_result_dict = {k: search_result_dict[k] for k in shard_keys_list}
return shard_search_result_dict
def rerank_results(corpus_dict: dict, languages: list, eval_args: EvalArgs, model_args: ModelArgs, device: str=None):
eval_args = copy.deepcopy(eval_args)
model_args = copy.deepcopy(model_args)
num_shards = eval_args.num_shards
shard_id = eval_args.shard_id
if shard_id >= num_shards:
raise ValueError(f"shard_id >= num_shards ({shard_id} >= {num_shards})")
reranker = get_reranker(model_args=model_args, device=device)
if os.path.basename(eval_args.encoder).startswith('checkpoint-'):
eval_args.encoder = os.path.dirname(eval_args.encoder) + '_' + os.path.basename(eval_args.encoder)
if os.path.basename(model_args.reranker).startswith('checkpoint-'):
model_args.reranker = os.path.dirname(model_args.reranker) + '_' + os.path.basename(model_args.reranker)
for lang in languages:
print("**************************************************")
print(f"Start reranking results of {lang} ...")
queries_path = os.path.join(eval_args.qa_data_dir, f"{lang}.jsonl")
queries_dict = get_queries_dict(queries_path)
search_result_save_dir = os.path.join(eval_args.search_result_save_dir, os.path.basename(eval_args.encoder))
search_result_path = os.path.join(search_result_save_dir, f"{lang}.txt")
search_result_dict = get_search_result_dict(search_result_path, top_k=eval_args.top_k)
search_result_dict = get_shard(search_result_dict, num_shards=num_shards, shard_id=shard_id)
rerank_result_save_path = {}
for sub_dir in ['colbert', 'sparse', 'dense', 'colbert+sparse+dense']:
_rerank_result_save_path = os.path.join(
eval_args.rerank_result_save_dir,
sub_dir,
f"{os.path.basename(eval_args.encoder)}-{os.path.basename(model_args.reranker)}",
f"{lang}.txt")
rerank_result_save_path[sub_dir] = _rerank_result_save_path
save_rerank_results(
queries_dict=queries_dict,
corpus_dict=corpus_dict,
reranker=reranker,
search_result_dict=search_result_dict,
rerank_result_save_path=rerank_result_save_path,
batch_size=eval_args.batch_size,
max_length=eval_args.max_length,
dense_weight=eval_args.dense_weight,
sparse_weight=eval_args.sparse_weight,
colbert_weight=eval_args.colbert_weight
)
def main():
parser = HfArgumentParser([EvalArgs, ModelArgs])
eval_args, model_args = parser.parse_args_into_dataclasses()
eval_args: EvalArgs
model_args: ModelArgs
languages = check_languages(eval_args.languages)
corpus = datasets.load_dataset("BeIR/nq", 'corpus')['corpus']
corpus_dict = get_corpus_dict(corpus=corpus)
cuda_id = eval_args.cuda_id
if cuda_id < 0:
rerank_results(corpus_dict, languages, eval_args, model_args, device='cpu')
else:
rerank_results(corpus_dict, languages, eval_args, model_args, device=f"cuda:{cuda_id}")
print("==================================================")
print("Finish generating reranked results with following model and reranker:")
print(eval_args.encoder)
print(model_args.reranker)
if __name__ == "__main__":
main()
"""
# 1. Rerank Search Results
python step0-rerank_results.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--search_result_save_dir ../dense_retrieval/search_results \
--qa_data_dir ../qa_data \
--rerank_result_save_dir ./rerank_results \
--top_k 100 \
--batch_size 4 \
--max_length 512 \
--pooling_method cls \
--normalize_embeddings True \
--dense_weight 1 --sparse_weight 0.3 --colbert_weight 1 \
--num_shards 1 --shard_id 0 --cuda_id 0
# 2. Print and Save Evaluation Results
python step1-eval_rerank_mkqa.py \
--encoder BAAI/bge-m3 \
--reranker BAAI/bge-m3 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--search_result_save_dir ./rerank_results \
--qa_data_dir ../qa_data \
--eval_result_save_dir ./eval_results \
--metrics recall@20 recall@100 \
--threads 32
"""
import os
import sys
import json
import datasets
import numpy as np
import pandas as pd
from tqdm import tqdm
import multiprocessing
from pprint import pprint
from dataclasses import dataclass, field
from transformers import HfArgumentParser
sys.path.append("..")
from utils.normalize_text import normalize
from utils.evaluation import evaluate_recall_qa
@dataclass
class EvalArgs:
languages: str = field(
default="en",
metadata={'help': 'Languages to evaluate. Available languages: en ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw',
"nargs": "+"}
)
reranker: str = field(
default='BAAI/bge-m3',
metadata={'help': 'Name or path of reranker'}
)
encoder: str = field(
default='BAAI/bge-m3',
metadata={'help': 'Name or path of encoder'}
)
search_result_save_dir: str = field(
default='./rerank_results',
metadata={'help': 'Dir for saving search results. Search results path is `result_save_dir/{encoder}-{reranker}/{lang}.txt`'}
)
qa_data_dir: str = field(
default='../qa_data',
metadata={'help': 'Dir containing qa data.'}
)
metrics: str = field(
default="recall@20",
metadata={'help': 'Metrics to evaluate. Available metrics: recall@k',
"nargs": "+"}
)
eval_result_save_dir: str = field(
default='./reranker_evaluation_results',
metadata={'help': 'Dir for saving evaluation results. Evaluation results will be saved to `eval_result_save_dir/{encoder}-{reranker}.json`'}
)
threads: int = field(
default=1,
metadata={"help": "Number of worker processes; <= 1 means single-process"}
)
def check_languages(languages):
if isinstance(languages, str):
languages = [languages]
available_languages = ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
for lang in languages:
if lang not in available_languages:
raise ValueError(f"Language `{lang}` is not supported. Available languages: {available_languages}")
return languages
def compute_average(results: dict):
average_results = {}
for _, result in results.items():
for metric, score in result.items():
if metric not in average_results:
average_results[metric] = []
average_results[metric].append(score)
for metric, scores in average_results.items():
average_results[metric] = np.mean(scores)
return average_results
def save_results(model_name: str, reranker_name: str, results: dict, save_path: str, eval_languages: list):
try:
results['average'] = compute_average(results)
except Exception:
results['average'] = None
pprint(results)
if not os.path.exists(os.path.dirname(save_path)):
os.makedirs(os.path.dirname(save_path))
results_dict = {
'reranker': reranker_name,
'model': model_name,
'results': results
}
with open(save_path, 'w', encoding='utf-8') as f:
json.dump(results_dict, f, indent=4, ensure_ascii=False)
print(f'Results of evaluating `{reranker_name}` on `{eval_languages}` based on `{model_name}` saved at `{save_path}`')
def get_corpus_dict():
corpus_dict = {}
corpus = datasets.load_dataset('BeIR/nq', 'corpus')['corpus']
for data in tqdm(corpus, desc="Loading corpus"):
_id = str(data['_id'])
content = f"{data['title']}\n{data['text']}".lower()
content = normalize(content)
corpus_dict[_id] = content
return corpus_dict
def get_qa_dict(qa_path: str):
qa_dict = {}
dataset = datasets.load_dataset('json', data_files=qa_path)['train']
for data in dataset:
qid = str(data['id'])
answers = data['answers']
qa_dict[qid] = answers
return qa_dict
def get_search_result_dict(search_result_path: str, top_k: int=100):
search_result_dict = {}
flag = True
for _, row in tqdm(pd.read_csv(search_result_path, sep=' ', header=None).iterrows(), desc="Loading search results"):
qid = str(row.iloc[0])
docid = str(row.iloc[2])
rank = int(row.iloc[3])
if qid not in search_result_dict:
search_result_dict[qid] = []
flag = False
if rank > top_k:
flag = True
if flag:
continue
else:
search_result_dict[qid].append(docid)
return search_result_dict
def evaluate(corpus_dict: dict, qa_dict: dict, search_result_path: str, metrics: list):
top_k = max([int(metric.split('@')[-1]) for metric in metrics])
search_result_dict = get_search_result_dict(search_result_path, top_k=int(top_k))
search_results = []
ground_truths = []
for qid, docid_list in tqdm(search_result_dict.items(), desc="Preparing to evaluate"):
answers = qa_dict[qid]
doc_list = [corpus_dict[docid] for docid in docid_list]
search_results.append(doc_list)
ground_truths.append(answers)
results = {}
metrics = sorted([metric.lower() for metric in metrics])
for metric in metrics:
metric, k = metric.split('@')
k = int(k)
assert metric in ['recall'], f"Metric `{metric}` is not supported."
if metric == 'recall':
results[f'Recall@{k}'] = evaluate_recall_qa(search_results, ground_truths, k=k)
return results
def main():
parser = HfArgumentParser([EvalArgs])
eval_args = parser.parse_args_into_dataclasses()[0]
eval_args: EvalArgs
corpus_dict = get_corpus_dict()
languages = check_languages(eval_args.languages)
if 'checkpoint-' in os.path.basename(eval_args.encoder):
eval_args.encoder = os.path.dirname(eval_args.encoder) + '_' + os.path.basename(eval_args.encoder)
if 'checkpoint-' in os.path.basename(eval_args.reranker):
eval_args.reranker = os.path.dirname(eval_args.reranker) + '_' + os.path.basename(eval_args.reranker)
try:
for sub_dir in ['colbert', 'sparse', 'dense', 'colbert+sparse+dense']:
results = {}
if eval_args.threads > 1:
threads = min(len(languages), eval_args.threads)
pool = multiprocessing.Pool(processes=threads)
results_list = []
for lang in languages:
print("*****************************")
print(f"Start evaluating {lang} ...")
qa_path = os.path.join(eval_args.qa_data_dir, f"{lang}.jsonl")
qa_dict = get_qa_dict(qa_path)
search_result_save_dir = os.path.join(eval_args.search_result_save_dir, sub_dir, f"{os.path.basename(eval_args.encoder)}-{os.path.basename(eval_args.reranker)}")
search_result_path = os.path.join(search_result_save_dir, f"{lang}.txt")
results_list.append(pool.apply_async(evaluate, args=(corpus_dict, qa_dict, search_result_path, eval_args.metrics)))
pool.close()
pool.join()
for i, lang in enumerate(languages):
results[lang] = results_list[i].get()
else:
for lang in languages:
print("*****************************")
print(f"Start evaluating {lang} ...")
qa_path = os.path.join(eval_args.qa_data_dir, f"{lang}.jsonl")
qa_dict = get_qa_dict(qa_path)
search_result_save_dir = os.path.join(eval_args.search_result_save_dir, sub_dir, f"{os.path.basename(eval_args.encoder)}-{os.path.basename(eval_args.reranker)}")
search_result_path = os.path.join(search_result_save_dir, f"{lang}.txt")
result = evaluate(corpus_dict, qa_dict, search_result_path, eval_args.metrics)
results[lang] = result
print("****************************")
print(sub_dir + ":")
save_results(
model_name=eval_args.encoder,
reranker_name=eval_args.reranker,
results=results,
save_path=os.path.join(eval_args.eval_result_save_dir, sub_dir, f"{os.path.basename(eval_args.encoder)}-{os.path.basename(eval_args.reranker)}.json"),
eval_languages=languages
)
except Exception:
# the per-mode sub_dir layout was not found; fall back to a flat result layout
results = {}
if eval_args.threads > 1:
threads = min(len(languages), eval_args.threads)
pool = multiprocessing.Pool(processes=threads)
results_list = []
for lang in languages:
print("*****************************")
print(f"Start evaluating {lang} ...")
qa_path = os.path.join(eval_args.qa_data_dir, f"{lang}.jsonl")
qa_dict = get_qa_dict(qa_path)
search_result_save_dir = os.path.join(eval_args.search_result_save_dir, f"{os.path.basename(eval_args.encoder)}-{os.path.basename(eval_args.reranker)}")
search_result_path = os.path.join(search_result_save_dir, f"{lang}.txt")
results_list.append(pool.apply_async(evaluate, args=(corpus_dict, qa_dict, search_result_path, eval_args.metrics)))
pool.close()
pool.join()
for i, lang in enumerate(languages):
results[lang] = results_list[i].get()
else:
for lang in languages:
print("*****************************")
print(f"Start evaluating {lang} ...")
qa_path = os.path.join(eval_args.qa_data_dir, f"{lang}.jsonl")
qa_dict = get_qa_dict(qa_path)
search_result_save_dir = os.path.join(eval_args.search_result_save_dir, f"{os.path.basename(eval_args.encoder)}-{os.path.basename(eval_args.reranker)}")
search_result_path = os.path.join(search_result_save_dir, f"{lang}.txt")
result = evaluate(corpus_dict, qa_dict, search_result_path, eval_args.metrics)
results[lang] = result
save_results(
model_name=eval_args.encoder,
reranker_name=eval_args.reranker,
results=results,
save_path=os.path.join(eval_args.eval_result_save_dir, f"{os.path.basename(eval_args.encoder)}-{os.path.basename(eval_args.reranker)}.json"),
eval_languages=languages
)
print("==================================================")
print("Finish generating evaluation results with following model and reranker:")
print(eval_args.encoder)
print(eval_args.reranker)
if __name__ == "__main__":
main()