Commit fb111087 authored by yingfhu

[Feat] support opencompass

parent 7d346000
# Few-shot
import re

from datasets import load_dataset

from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS

from .base import BaseDataset


@LOAD_DATASET.register_module()
class TheoremQADataset(BaseDataset):

    @staticmethod
    def load(path: str):
        return load_dataset('csv', data_files={'test': path})


@TEXT_POSTPROCESSORS.register_module('TheoremQA')
def TheoremQA_postprocess(text: str) -> str:
    # Keep only the first line of the completion.
    text = text.strip().split('\n')[0].strip()
    matches = re.findall(r'answer is (.*)', text)
    if len(matches) == 0:
        return text
    else:
        # Drop the trailing punctuation mark after "answer is ...".
        text = matches[0].strip()[:-1]
        return text

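# Usage sketch (illustrative, not from the commit): TheoremQA_postprocess
# keeps only the first line and extracts the span after "answer is",
# dropping the trailing punctuation. The sample string is made up.
raw = 'The answer is 42.\nLet me explain the derivation...'
print(TheoremQA_postprocess(raw))  # -> '42'
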
import json

from datasets import Dataset

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class CBDataset_V2(BaseDataset):

    @staticmethod
    def load(path):
        dataset = []
        with open(path, 'r') as f:
            for line in f:
                line = json.loads(line)
                # Map the textual NLI label onto a choice letter.
                line['label'] = {
                    'contradiction': 'A',
                    'entailment': 'B',
                    'neutral': 'C',
                }[line['label']]
                dataset.append(line)
        return Dataset.from_list(dataset)

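# Minimal sketch of CBDataset_V2.load on a synthetic JSONL record (the
# fields below are made up for illustration; real CB files carry the same
# 'premise'/'hypothesis'/'label' keys):
import tempfile

with tempfile.NamedTemporaryFile('w', suffix='.jsonl', delete=False) as f:
    f.write(json.dumps({'premise': 'p', 'hypothesis': 'h',
                        'label': 'entailment'}) + '\n')
ds = CBDataset_V2.load(f.name)
assert ds[0]['label'] == 'B'  # 'entailment' maps to choice letter B
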
import json

from datasets import Dataset, load_dataset

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class CHIDDataset(BaseDataset):

    @staticmethod
    def load(**kwargs):
        dataset = load_dataset(**kwargs)

        def preprocess(example):
            content = example['content']
            # Expand the '#idiom#' placeholder once per candidate idiom.
            for i, c in enumerate(example['candidates']):
                example[f'content{i}'] = content.replace('#idiom#', c)
            return example

        dataset = dataset.map(preprocess)
        return dataset


@LOAD_DATASET.register_module()
class CHIDDataset_V2(BaseDataset):

    @staticmethod
    def load(path):
        data = []
        with open(path, 'r') as f:
            for line in f:
                line = json.loads(line)
                item = {}
                # Blank out the idiom slot and expose candidates as A-G.
                item['content'] = line['content'].replace('#idiom#', '______')
                for i, c in enumerate(line['candidates']):
                    item[chr(ord('A') + i)] = c
                item['answer'] = 'ABCDEFG'[line['answer']]
                data.append(item)
        return Dataset.from_list(data)

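# Minimal sketch of CHIDDataset_V2.load on one synthetic record (fields made
# up for illustration): the '#idiom#' slot is blanked and the integer answer
# index becomes a choice letter.
import tempfile

record = {'content': 'He felt #idiom# about it.',
          'candidates': ['c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6'],
          'answer': 2}
with tempfile.NamedTemporaryFile('w', suffix='.jsonl', delete=False) as f:
    f.write(json.dumps(record) + '\n')
ds = CHIDDataset_V2.load(f.name)
assert ds[0]['content'] == 'He felt ______ about it.'
assert ds[0]['C'] == 'c2' and ds[0]['answer'] == 'C'  # 'ABCDEFG'[2]
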
from datasets import DatasetDict, load_dataset

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class CivilCommentsDataset(BaseDataset):

    @staticmethod
    def load(**kwargs):
        train_dataset = load_dataset(**kwargs, split='train')
        test_dataset = load_dataset(**kwargs, split='test')

        def pre_process(example):
            # Binarize the continuous toxicity score at 0.5.
            example['label'] = int(example['toxicity'] >= 0.5)
            example['choices'] = ['no', 'yes']
            return example

        def remove_columns(dataset):
            return dataset.remove_columns([
                'severe_toxicity', 'obscene', 'threat', 'insult',
                'identity_attack', 'sexual_explicit'
            ])

        train_dataset = remove_columns(train_dataset)
        test_dataset = remove_columns(test_dataset)
        # Evaluate on a fixed random sample of 10k test examples.
        test_dataset = test_dataset.shuffle(seed=42)
        test_dataset = test_dataset.select(list(range(10000)))
        test_dataset = test_dataset.map(pre_process)
        return DatasetDict({
            'train': train_dataset,
            'test': test_dataset,
        })

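# The pre_process step above binarizes the continuous toxicity score at 0.5;
# a self-contained restatement of that rule (the helper name and values are
# hypothetical, for illustration only):
def toxicity_to_label(toxicity: float) -> int:
    return int(toxicity >= 0.5)

assert toxicity_to_label(0.72) == 1  # 'yes' (index into ['no', 'yes'])
assert toxicity_to_label(0.10) == 0  # 'no'
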
from datasets import load_dataset

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class commonsenseqaDataset(BaseDataset):

    @staticmethod
    def load(**kwargs):
        dataset = load_dataset(**kwargs)

        def pre_process(example):
            # Lift the five candidate texts into flat columns A-E.
            for i in range(5):
                example[chr(ord('A') + i)] = example['choices']['text'][i]
            return example

        dataset = dataset.map(pre_process).remove_columns(
            ['question_concept', 'id', 'choices'])
        return dataset

from datasets import load_dataset

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class crowspairsDataset(BaseDataset):

    @staticmethod
    def load(**kwargs):
        dataset = load_dataset(**kwargs)

        def preprocess(example):
            example['label'] = 0
            return example

        return dataset.map(preprocess)


@LOAD_DATASET.register_module()
class crowspairsDataset_V2(BaseDataset):

    @staticmethod
    def load(**kwargs):
        dataset = load_dataset(**kwargs)

        def preprocess(example):
            example['label'] = 'A'
            return example

        return dataset.map(preprocess)

import json

from datasets import Dataset

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class eprstmtDataset_V2(BaseDataset):

    @staticmethod
    def load(path):
        data = []
        with open(path, 'r') as f:
            for line in f:
                line = json.loads(line)
                item = {
                    'sentence': line['sentence'],
                    'label': {
                        'Positive': 'A',
                        'Negative': 'B',
                    }[line['label']],
                }
                data.append(item)
        return Dataset.from_list(data)

from datasets import load_dataset

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class HFDataset(BaseDataset):

    @staticmethod
    def load(**kwargs):
        return load_dataset(**kwargs)

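# HFDataset forwards its kwargs directly to datasets.load_dataset, so any
# argument that function accepts works here. A hypothetical call using the
# built-in 'json' loader on a local file (path is illustrative):
#
#   ds = HFDataset.load(path='json', data_files={'test': 'my_data.jsonl'})
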
from datasets import load_dataset

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class piqaDataset_V2(BaseDataset):

    @staticmethod
    def load(**kwargs):
        dataset = load_dataset(**kwargs)

        def preprocess(example):
            assert isinstance(example['label'], int)
            if example['label'] < 0:
                # A negative label marks an unlabeled example.
                example['answer'] = 'NULL'
            else:
                example['answer'] = 'AB'[example['label']]
            example.pop('label')
            return example

        dataset = dataset.map(preprocess)
        return dataset

from datasets import load_dataset

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class RealToxicPromptsDataset(BaseDataset):

    @staticmethod
    def load(**kwargs):
        challenging_subset = kwargs.pop('challenging_subset', False)
        dataset = load_dataset(**kwargs)

        def preprocess(example):
            # Flatten the nested 'prompt' dict into 'prompt_*' columns.
            for k, v in example['prompt'].items():
                k = 'prompt_' + k
                example[k] = v
            del example['prompt']
            return example

        dataset = dataset.map(preprocess)
        # Return only the challenging subset if requested.
        if challenging_subset:
            return dataset.filter(lambda example: example['challenging'])
        return dataset

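# The preprocess step above flattens the nested 'prompt' dict into top-level
# 'prompt_*' columns; the same transform on a made-up row:
row = {'prompt': {'text': 'Some prompt', 'toxicity': 0.1},
       'challenging': False}
for key, value in row.pop('prompt').items():
    row['prompt_' + key] = value
assert row == {'challenging': False,
               'prompt_text': 'Some prompt',
               'prompt_toxicity': 0.1}
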
from datasets import load_dataset

from opencompass.registry import LOAD_DATASET

from .base import BaseDataset


@LOAD_DATASET.register_module()
class siqaDataset_V2(BaseDataset):

    @staticmethod
    def load(**kwargs):
        dataset = load_dataset(**kwargs)

        def preprocess(example):
            # Labels are 1-based strings; map them onto letters A-C.
            example['label'] = ' ABC'[int(example['label'])]
            return example

        dataset = dataset.map(preprocess)
        return dataset

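# The ' ABC' lookup above turns the 1-based labels into letters without any
# arithmetic; a quick check of the mapping:
assert ' ABC'[int('1')] == 'A' and ' ABC'[int('3')] == 'C'
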
from .icl_aucroc_evaluator import AUCROCEvaluator
from .icl_base_evaluator import BaseEvaluator
from .icl_em_evaluator import EMEvaluator
from .icl_hf_evaluator import * # noqa
from .icl_toxic_evaluator import ToxicEvaluator
import logging

import torch.distributed as dist

LOG_LEVEL = logging.INFO
SUBPROCESS_LOG_LEVEL = logging.ERROR
LOG_FORMATTER = '[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s'


def get_logger(name, level=LOG_LEVEL, log_file=None, file_mode='w'):
    formatter = logging.Formatter(LOG_FORMATTER)
    logger = logging.getLogger(name)

    # Quieten stream handlers already attached to the root logger so
    # messages are not printed twice.
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)

    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    # Only rank 0 writes the log file.
    if rank == 0 and log_file is not None:
        file_handler = logging.FileHandler(log_file, file_mode)
        file_handler.setFormatter(formatter)
        file_handler.setLevel(level)
        logger.addHandler(file_handler)

    if rank == 0:
        logger.setLevel(level)
    else:
        # Non-zero ranks only surface errors.
        logger.setLevel(SUBPROCESS_LOG_LEVEL)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    stream_handler.setLevel(level)
    logger.addHandler(stream_handler)
    return logger

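# Usage sketch for get_logger (the logger name and file path below are
# illustrative): on rank 0 this logs INFO and above to the console and, if a
# path is given, to the file; other ranks are silenced down to ERROR.
logger = get_logger('opencompass.demo', log_file='/tmp/opencompass_demo.log')
logger.info('pipeline started')
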
from .naive import * # noqa: F401, F403
from .size import * # noqa: F401, F403
from .abbr import * # noqa
from .build import * # noqa
from .fileio import * # noqa
from .git import * # noqa
from .lark import * # noqa
from .logging import * # noqa
from .menu import * # noqa
from .prompt import * # noqa
from .summarizer import * # noqa
from .text_postprocessors import * # noqa
import copy

from mmengine.config import ConfigDict

from opencompass.registry import LOAD_DATASET, MODELS


def build_dataset_from_cfg(dataset_cfg: ConfigDict):
    dataset_cfg = copy.deepcopy(dataset_cfg)
    # Strip keys meaningful to the runner but not to the dataset itself.
    dataset_cfg.pop('infer_cfg', None)
    dataset_cfg.pop('eval_cfg', None)
    dataset_cfg.pop('abbr', None)
    return LOAD_DATASET.build(dataset_cfg)


def build_model_from_cfg(model_cfg: ConfigDict):
    model_cfg = copy.deepcopy(model_cfg)
    # Strip keys meaningful to the runner but not to the model itself.
    model_cfg.pop('run_cfg', None)
    model_cfg.pop('max_out_len', None)
    model_cfg.pop('batch_size', None)
    model_cfg.pop('abbr', None)
    return MODELS.build(model_cfg)

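# Hypothetical usage sketch for build_dataset_from_cfg; the config below is
# made up, and depending on the dataset class extra keys (e.g. a reader_cfg)
# may be expected in practice:
#
#   cfg = ConfigDict(type='HFDataset', abbr='demo',
#                    path='json', data_files={'test': 'demo.jsonl'})
#   dataset = build_dataset_from_cfg(cfg)  # 'abbr' is stripped; 'type'
#                                          # selects the registered class
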