Commit cbe9fe2c authored by Ezra-Yu's avatar Ezra-Yu Committed by gaotong
Browse files

Add Release Contribution

parent 36f11110
"""BM25 Retriever."""
from typing import List, Optional
import numpy as np
from nltk.tokenize import word_tokenize
from rank_bm25 import BM25Okapi
from tqdm import trange
from opencompass.openicl.icl_retriever import BaseRetriever
from opencompass.openicl.utils.logging import get_logger
from opencompass.registry import ICL_RETRIEVERS
logger = get_logger(__name__)
@ICL_RETRIEVERS.register_module()
class BM25Retriever(BaseRetriever):
    """Retriever that picks in-context examples with Okapi BM25.

    Okapi BM25 ("best matching") is a ranking function used by search
    engines to estimate how relevant a document is to a query; see
    https://en.wikipedia.org/wiki/Okapi_BM25. Each test example is used as a
    query against the tokenized index split, and the highest-scoring index
    entries become its in-context examples.

    Args:
        dataset (`BaseDataset`): Any BaseDataset instance. Forwarded
            unchanged to ``BaseRetriever.__init__``.
        ice_separator (`Optional[str]`): The separator between each
            in-context example template when origin `PromptTemplate` is
            provided. Defaults to ``'\\n'``.
        ice_eos_token (`Optional[str]`): The end of sentence token for
            in-context example template when origin `PromptTemplate` is
            provided. Defaults to ``'\\n'``.
        ice_num (`Optional[int]`): The number of in-context examples to
            retrieve per test example. Defaults to 1.
    """

    # Class-level placeholders; each is replaced per instance in __init__.
    bm25 = None
    index_corpus = None
    test_corpus = None

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1) -> None:
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)

        # NOTE(review): dataset_reader, index_ds and test_ds are not assigned
        # in this class; presumably BaseRetriever.__init__ sets them.
        def tokenize_split(split):
            """Tokenize every input-field text of ``split`` into words."""
            texts = self.dataset_reader.generate_input_field_corpus(split)
            return [word_tokenize(text) for text in texts]

        # Build the searchable index first, then tokenize the queries.
        self.index_corpus = tokenize_split(self.index_ds)
        self.bm25 = BM25Okapi(self.index_corpus)
        self.test_corpus = tokenize_split(self.test_ds)

    def retrieve(self) -> List[List]:
        """Retrieve the in-context example index for each test example.

        Returns:
            List[List]: One list per test example, holding up to ``ice_num``
            index-corpus positions as plain ``int``s, ordered from the
            highest BM25 score downwards.
        """
        logger.info('Retrieving data for test set...')
        top_k = self.ice_num
        selections = []
        progress = trange(len(self.test_corpus),
                          disable=not self.is_main_process)
        for position in progress:
            scores = self.bm25.get_scores(self.test_corpus[position])
            # argsort is ascending, so reverse it to rank best matches first.
            ranked = np.argsort(scores)[::-1]
            # Convert numpy integers to built-in ints for the caller.
            selections.append([int(doc_id) for doc_id in ranked[:top_k]])
        return selections
"""Random Retriever."""
from typing import Optional
import numpy as np
from tqdm import trange
from opencompass.openicl.icl_retriever import BaseRetriever
from opencompass.openicl.utils.logging import get_logger
logger = get_logger(__name__)
class RandomRetriever(BaseRetriever):
    """Random Retriever. Each in-context example of the test prompts is
    retrieved in a random way.

    **WARNING**: This class has not been tested thoroughly. Please use it with
    caution.

    Args:
        dataset: Dataset object forwarded unchanged to
            ``BaseRetriever.__init__``.
        ice_separator (`Optional[str]`): Forwarded to the base class.
            Defaults to ``'\\n'``.
        ice_eos_token (`Optional[str]`): Forwarded to the base class.
            Defaults to ``'\\n'``.
        ice_num (`Optional[int]`): Number of in-context examples drawn for
            each test example. Defaults to 1.
        seed (`Optional[int]`): Seed of the random generator used by
            :meth:`retrieve`. Defaults to 43.
    """

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1,
                 seed: Optional[int] = 43) -> None:
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
        self.seed = seed

    def retrieve(self):
        """Draw ``ice_num`` distinct random index positions per test example.

        Returns:
            list: One list of ``int`` positions into the index split for
            every entry of the test split.

        Raises:
            ValueError: Raised by NumPy when ``ice_num`` exceeds the size of
                the index split, because sampling is without replacement.
        """
        # Use a private generator instead of np.random.seed(): it yields the
        # same sequence for a given seed but leaves NumPy's process-wide
        # random state untouched for every other consumer.
        rng = np.random.RandomState(self.seed)
        num_idx = len(self.index_ds)
        rtr_idx_list = []
        logger.info('Retrieving data for test set...')
        for _ in trange(len(self.test_ds), disable=not self.is_main_process):
            idx_list = rng.choice(num_idx, self.ice_num,
                                  replace=False).tolist()
            rtr_idx_list.append(idx_list)
        return rtr_idx_list
"""Zeroshot Retriever."""
from typing import List, Optional
from opencompass.openicl.icl_retriever import BaseRetriever
from opencompass.registry import ICL_RETRIEVERS
@ICL_RETRIEVERS.register_module()
class ZeroRetriever(BaseRetriever):
    """Zero-shot retriever that never selects any in-context example.

    Args:
        dataset (`BaseDataset`): Any BaseDataset instance. Forwarded
            unchanged to ``BaseRetriever.__init__``.
        ice_eos_token (`Optional[str]`): The end of sentence token for
            in-context example template when origin `PromptTemplate` is
            provided. Defaults to ''.
    """

    def __init__(self, dataset, ice_eos_token: Optional[str] = '') -> None:
        # The separator is fixed to '' and the example count to 0.
        super().__init__(dataset, '', ice_eos_token, 0)

    def retrieve(self) -> List[List]:
        """Return one empty index list for every entry of the test split."""
        num_queries = len(self.test_ds)
        empty_selections: List[List] = []
        for _ in range(num_queries):
            # Append a fresh list each time so the entries are not aliased.
            empty_selections.append([])
        return empty_selections
from mmengine.logging import MMLogger
def get_logger(log_level='INFO') -> MMLogger:
    """Return the ``MMLogger`` instance registered as 'OpenCompass'.

    Args:
        log_level (str): The log level. Default: 'INFO'. Choices are 'DEBUG',
            'INFO', 'WARNING', 'ERROR', 'CRITICAL'.

    Returns:
        MMLogger: The result of ``MMLogger.get_instance`` for the
        'OpenCompass' name.
    """
    # The same identifier serves as both instance name and logger name.
    name = 'OpenCompass'
    return MMLogger.get_instance(name, logger_name=name, log_level=log_level)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment