Unverified commit 32f40a8f authored by Fengzhe Zhou, committed by GitHub

[Sync] Sync with internal codes 2023.01.08 (#777)

parent 8194199d
_longeval_2k = ['classification_en_2k', 'lines_2k', 'qa_en_2k', 'qa_zh_2k', 'stackselect_2k', 'summarization_en_2k', 'textsort_2k']
_longeval_4k = ['classification_en_4k', 'lines_4k', 'qa_en_4k', 'qa_zh_4k', 'stackselect_4k', 'summarization_en_4k', 'textsort_4k']
_longeval_8k = ['classification_en_8k', 'lines_8k', 'qa_en_8k', 'qa_zh_8k', 'stackselect_8k', 'summarization_en_8k', 'textsort_8k']
_longeval_15k = ['classification_en_15k', 'lines_15k', 'qa_en_15k', 'qa_zh_15k', 'stackselect_15k', 'summarization_en_15k', 'textsort_15k']
_longeval_30k = ['classification_en_30k', 'lines_30k', 'qa_en_30k', 'qa_zh_30k', 'stackselect_30k', 'summarization_en_30k', 'textsort_30k']
longeval_summary_groups = [
{'name': 'longeval_v2_2k', 'subsets': _longeval_2k},
{'name': 'longeval_v2_4k', 'subsets': _longeval_4k},
{'name': 'longeval_v2_8k', 'subsets': _longeval_8k},
{'name': 'longeval_v2_15k', 'subsets': _longeval_15k},
{'name': 'longeval_v2_30k', 'subsets': _longeval_30k},
{'name': 'longeval_v2', 'subsets': _longeval_2k + _longeval_4k + _longeval_8k + _longeval_15k + _longeval_30k}
]
summarizer = dict(
dataset_abbrs = [
'longeval_v2',
'longeval_v2_2k',
'longeval_v2_4k',
'longeval_v2_8k',
'longeval_v2_15k',
'longeval_v2_30k',
'classification_en_2k',
'classification_en_4k',
'classification_en_8k',
'classification_en_15k',
'classification_en_30k',
'lines_2k',
'lines_4k',
'lines_8k',
'lines_15k',
'lines_30k',
'qa_en_2k',
'qa_en_4k',
'qa_en_8k',
'qa_en_15k',
'qa_en_30k',
'qa_zh_2k',
'qa_zh_4k',
'qa_zh_8k',
'qa_zh_15k',
'qa_zh_30k',
'stackselect_2k',
'stackselect_4k',
'stackselect_8k',
'stackselect_15k',
'stackselect_30k',
'summarization_en_2k',
'summarization_en_4k',
'summarization_en_8k',
'summarization_en_15k',
'summarization_en_30k',
'textsort_2k',
'textsort_4k',
'textsort_8k',
'textsort_15k',
'textsort_30k',
],
summary_groups=longeval_summary_groups,
)
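The summary groups above only declare which subsets roll up into each aggregate row of the report. A minimal sketch of the aggregation idea, assuming equally weighted per-subset accuracies (the numbers below are hypothetical; the real summarizer may weight subsets differently):

subset_scores = {name: 50.0 for name in _longeval_2k}  # hypothetical per-subset accuracies
group = {'name': 'longeval_v2_2k', 'subsets': _longeval_2k}
group_score = sum(subset_scores[s] for s in group['subsets']) / len(group['subsets'])
print(group['name'], group_score)  # -> longeval_v2_2k 50.0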
@@ -46,9 +46,11 @@ from .hellaswag import * # noqa: F401, F403
from .huggingface import * # noqa: F401, F403
from .humaneval import * # noqa: F401, F403
from .humanevalx import * # noqa: F401, F403
from .hungarian_math import * # noqa: F401, F403
from .infinitebench import * # noqa: F401, F403
from .iwslt2017 import * # noqa: F401, F403
from .jigsawmultilingual import * # noqa: F401, F403
from .jsonl import JsonlDataset # noqa: F401, F403
from .kaoshi import KaoshiDataset, KaoshiEvaluator # noqa: F401, F403
from .lambada import * # noqa: F401, F403
from .lawbench import * # noqa: F401, F403
@@ -57,6 +59,7 @@ from .leval import * # noqa: F401, F403
from .longbench import * # noqa: F401, F403
from .mastermath2024v1 import * # noqa: F401, F403
from .math import * # noqa: F401, F403
from .math401 import * # noqa: F401, F403
from .mathbench import * # noqa: F401, F403
from .mbpp import * # noqa: F401, F403
from .medbench import * # noqa: F401, F403
......
@@ -69,13 +69,105 @@ def load_experiment(file: str) -> dict:
)
def load_experiment_template(file: str) -> dict:
"""Load single experiment file with solutions for template experiment."""
with open(file, 'r') as f:
notebook = json.load(f)
example = notebook['cells']
metadata = notebook['metadata']
modules = metadata.get('modules', [])
if modules:
# these two annotations should be the same
assert len(modules) == len(metadata.get('step_types'))
# reformat annotations
modules = [[_m.strip() for _m in _modules.split('&')]
for _modules in modules]
questions = []
source_codes = []
outputs = []
tags = []
for cell in example:
if cell['cell_type'] == 'markdown':
text = ''.join(cell['source']).strip()
if modules:
_modules = modules.pop(0)
if 'chinese' not in file:
text += f"Please use {' and '.join(_modules)} modules."
else:
text += f"请用 {' 和 '.join(_modules)} 模块."
text = text.strip() + '\n'
# append the formatted text
questions.append(text)
elif cell['cell_type'] == 'code':
source_codes.append(''.join(cell['source']))
output_flag = False
if cell['outputs']:
for _output in cell['outputs']:
if _output['output_type'] == 'display_data':
assert not output_flag
output_flag = True
tags.append('vis')
outputs.append(_output['data']['image/png'])
for _output in cell['outputs']:
if output_flag:
break
if _output['output_type'] == 'stream' and _output[
'name'] == 'stdout':
assert not output_flag
output_flag = True
tags.append('general')
outputs.append(''.join(_output['text']))
elif _output['output_type'] == 'execute_result':
assert not output_flag
output_flag = True
tags.append('general')
outputs.append(''.join(
_output['data']['text/plain']))
if not output_flag:
# no output fallback to exec
tags.append('exec')
outputs.append(None)
return dict(
experiment=file,
questions=sum(([
dict(role='user', content=question),
dict(role='assistant', content=source_code)
] for question, source_code in zip(questions, source_codes)), []),
references=dict(outputs=outputs,
tags=tags,
metadata=metadata,
experiment=file),
)
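The `questions` field returned above interleaves each markdown prompt with the code cell that answers it, flattening the list of (user, assistant) pairs with `sum(..., [])`. A minimal sketch of that flattening, with hypothetical cell contents:

questions = ['Load the CSV.', 'Plot column A.']
source_codes = ["df = pd.read_csv('x.csv')", "df['A'].plot()"]
dialogue = sum(([dict(role='user', content=q), dict(role='assistant', content=c)]
                for q, c in zip(questions, source_codes)), [])
# -> user/assistant/user/assistant turns, ready to be used as a multi-turn prompt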
def check_internet():
"""A tricky way to check internet."""
import socket
import nltk
socket.setdefaulttimeout(10)
ret = nltk.download('stopwords', quiet=True)
socket.setdefaulttimeout(None)
if not ret:
raise ConnectionError('CIBench needs internet to get response. Please '
'check your internet and proxy.')
@LOAD_DATASET.register_module()
class CIBenchDataset(BaseDataset):
"""Code Interpreter dataset."""
@staticmethod
def load(path: str, internet_check: bool = False):
"""Load whole dataset.
Args:
path(str): Path of cibench dataset.
internet_check(bool): Whether to check internet.
Defaults to False.
"""
if internet_check:
check_internet()
assert os.path.exists(path), f'Path {path} does not exist.'
data_list = []
for cwd, dirs, files in os.walk(path):
@@ -83,11 +175,36 @@ class CIBenchDataset(BaseDataset):
files.sort()
for f in files:
if '.ipynb' in f:
data = load_experiment(os.path.join(cwd, f))
data_list.append(data)
dataset = Dataset.from_list(data_list)
return dataset
@LOAD_DATASET.register_module()
class CIBenchTemplateDataset(BaseDataset):
"""Code Interpreter dataset for template dataset."""
@staticmethod
def load(path: str, internet_check: bool = False):
"""Load whole dataset.
Args:
path(str): Path of cibench dataset.
internet_check(bool): Whether to check internet.
Defaults to False.
"""
if internet_check:
check_internet()
assert os.path.exists(path), f'Path {path} does not exist.'
data_list = []
for cwd, dirs, files in os.walk(path):
dirs.sort()
files.sort()
for f in files:
if '.ipynb' in f:
data = load_experiment_template(os.path.join(cwd, f))
data_list.append(data)
dataset = Dataset.from_list(data_list)
@@ -138,7 +255,8 @@ class CIBenchEvaluator(BaseEvaluator):
def check_user_data_dir(self, user_data_dir):
if user_data_dir == 'ENV':
default_path = osp.abspath('./data/cibench_dataset/datasources')
user_data_dir = os.environ.get('USER_DATA_DIR', default_path)
user_data_dir = user_data_dir.rstrip('/')
basename = osp.basename(user_data_dir)
if basename and basename != 'data':
@@ -172,10 +290,11 @@ class CIBenchEvaluator(BaseEvaluator):
if action['result']:
try:
pred = action['result']['text']
match = re.search('execute_result:\n\n```\n(.*?)\n```',
pred, re.DOTALL)
if match:
out = match.group(1)
return out.strip() == target.strip()
except Exception:
return False
# Fall back to False
@@ -313,24 +432,24 @@ class CIBenchEvaluator(BaseEvaluator):
# numeric_correct: numerical correct
# text_score: text score
# vis_sim: visual similarity
# create empty results
result = dict()
if hard_tags:
check_tags = ['exec', 'num', 'text', 'vis']
else:
check_tags = ['exec', 'general', 'vis']
for tag in check_tags:
key = self.TAG_MAPPING[tag][0]
result[key] = []
for tag, step, output in zip(tags, steps, outputs):
# check whether this step is valid
result['executable'].append(self.valid_step(step))
if tag != 'exec':
key, func = self.TAG_MAPPING[tag]
result[key].append(func(step, output))
return result
def get_output_dir(self):
......
@@ -183,8 +183,13 @@ class CircularDatasetMeta(type):
def load(cls, circular_patterns='circular', *args, **kwargs):
circular_splits = getattr(cls, 'default_circular_splits', None)
option_keys = getattr(cls, 'default_option_keys', None)
if 'option_keys' in kwargs:
option_keys = kwargs.pop('option_keys')
assert option_keys is not None, 'option_keys cannot be None'
answer_key = getattr(cls, 'default_answer_key', None)
if 'answer_key' in kwargs:
answer_key = kwargs.pop('answer_key')
answer_key_switch_method = getattr(
cls, 'default_answer_key_switch_method', None)
dataset = cls.dataset_class.load(*args, **kwargs)
@@ -311,11 +316,11 @@ class CircularEvaluator(BaseEvaluator):
tmp_metrics.update({f'correct_{k}': 0 for k in circular_patterns})
tmp_metrics.update({f'count_{k}': 0 for k in circular_patterns})
# calculate the original accuracy
for pred, refr, origin_item in zip(predictions, references, test_set):
circular_pattern = origin_item['circular_pattern']
for k in circular_patterns:
if tuple(circular_pattern) in circular_patterns[k]:
tmp_metrics[f'correct_{k}'] += 1 if pred == refr else 0
tmp_metrics[f'count_{k}'] += 1
for k in circular_patterns:
@@ -324,13 +329,13 @@ class CircularEvaluator(BaseEvaluator):
# calculate the circular accuracy
_details = {k: {} for k in circular_patterns}
for pred, refr, origin_item in zip(predictions, references, test_set):
index = origin_item['qid']
circular_pattern = origin_item['circular_pattern']
for k in circular_patterns:
if tuple(circular_pattern) in circular_patterns[k]:
_details[k].setdefault(
index, []).append(True if pred == refr else False)
for k in _details:
_details[k] = {
index: sum(_details[k][index])
......
import copy
import csv
import json
import os
from typing import List
from datasets import Dataset
from opencompass.datasets.circular import (CircularDatasetMeta,
CircularEvaluator)
from opencompass.openicl.icl_evaluator import AccEvaluator, BaseEvaluator
from opencompass.openicl.icl_inferencer import GenInferencer, PPLInferencer
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.registry import LOAD_DATASET
from .base import BaseDataset
class OptionSimAccEvaluator(BaseEvaluator):
def __init__(self, options) -> None:
super().__init__()
if not all((isinstance(i, str) and i.isupper() and len(i) == 1)
for i in options):
raise ValueError(
f'Each option should be a single uppercase letter, got {options}')
self.options = options
def match_any_label(self, pred, test_item):
from rapidfuzz.distance import Levenshtein as L
from opencompass.utils.text_postprocessors import \
first_option_postprocess
pred = pred.strip()
if any([pred == i for i in self.options]):
parsed = pred
else:
parsed = ''
if parsed == '':
parsed = first_option_postprocess(pred,
''.join(self.options),
cushion=False)
if parsed == '':
possible_options = []
for opt in self.options:
opt_str = test_item[opt]
if opt_str is not None and opt_str.lower() in pred.lower():
possible_options.append(opt)
if len(possible_options) == 1:
parsed = possible_options[0]
if parsed == '':
dists = []
for opt in self.options:
opt_str = test_item[opt]
if opt_str is None:
continue
cands = [opt, opt_str, opt + '. ' + opt_str]
d = min(L.distance(pred, cand) for cand in cands)
dists.append((d, opt))
if len(dists) > 0:
parsed = min(dists)[1]
return parsed
def score(self, predictions: List, references: List, test_set) -> dict:
assert len(predictions) == len(references)
num_correct, num_total = 0, 0
details = {}
for index in range(len(predictions)):
pred = predictions[index]
refr = references[index]
parsed = self.match_any_label(pred, test_set[index])
num_correct += 1 if parsed == refr else 0
num_total += 1
details[str(index)] = {}
details[str(index)]['pred'] = pred
details[str(index)]['parsed'] = parsed
details[str(index)]['refr'] = refr
details[str(index)]['correct'] = parsed == refr
return {'accuracy': num_correct / num_total * 100, 'details': details}
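match_any_label falls back from an exact label, to first_option_postprocess, to an option-text substring match, and finally to the smallest Levenshtein distance over the label, the option text, and 'A. <text>'. A rough sketch of that last step with hypothetical options (not taken from any real dataset):

from rapidfuzz.distance import Levenshtein as L

options = {'A': 'cat', 'B': 'dog', 'C': 'horse'}
pred = 'I think it is a hrse'
dists = []
for opt, opt_str in options.items():
    cands = [opt, opt_str, opt + '. ' + opt_str]
    dists.append((min(L.distance(pred, c) for c in cands), opt))
print(min(dists)[1])  # -> 'C', the option whose candidate strings are closest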
# TODO: DO NOT COPY YOURSELF!!!
class CircularOptionSimAccEvaluator(OptionSimAccEvaluator):
def __init__(self, options, circular_pattern='circular'):
super().__init__(options)
self.circular_pattern = circular_pattern
def score(self, predictions, references, test_set):
from opencompass.datasets.circular import (get_all_possible_patterns,
get_circular_patterns,
get_origin_patterns)
circular_patterns = {}
circular_patterns['origin'] = get_origin_patterns(
test_set[0]['circular_pattern'])
circular_patterns['circular'] = get_circular_patterns(
test_set[0]['circular_pattern'])
if self.circular_pattern == 'all_possible':
circular_patterns['all_possible'] = get_all_possible_patterns(
test_set[0]['circular_pattern'])
metrics = {}
tmp_metrics = {}
tmp_metrics.update({f'correct_{k}': 0 for k in circular_patterns})
tmp_metrics.update({f'count_{k}': 0 for k in circular_patterns})
# calculate the original accuracy
for pred, refr, origin_item in zip(predictions, references, test_set):
parsed = self.match_any_label(pred, origin_item)
circular_pattern = origin_item['circular_pattern']
for k in circular_patterns:
if tuple(circular_pattern) in circular_patterns[k]:
tmp_metrics[f'correct_{k}'] += (1 if parsed == refr else 0)
tmp_metrics[f'count_{k}'] += 1
for k in circular_patterns:
metrics[f'acc_{k}'] = (tmp_metrics[f'correct_{k}'] /
tmp_metrics[f'count_{k}'] * 100)
# calculate the circular accuracy
_details = {k: {} for k in circular_patterns}
for pred, refr, origin_item in zip(predictions, references, test_set):
index = origin_item['qid']
parsed = self.match_any_label(pred, origin_item)
circular_pattern = origin_item['circular_pattern']
for k in circular_patterns:
if tuple(circular_pattern) in circular_patterns[k]:
_details[k].setdefault(
index, []).append(True if parsed == refr else False)
for k in _details:
_details[k] = {
index: sum(_details[k][index])
for index in _details[k]
}
for k in _details:
for j in range(1, len(circular_patterns[k]) + 1):
count = sum([_details[k][index] >= j for index in _details[k]])
total = len(_details[k])
if j != len(circular_patterns[k]):
metrics[f'more_{j}_{k}'] = count / total * 100
else:
metrics[f'perf_{k}'] = count / total * 100
# make details
details = {}
for index in range(len(predictions)):
parsed = self.match_any_label(predictions[index], test_set[index])
details[str(index)] = {}
if 'question' in test_set[index]:
details[str(index)]['question'] = test_set[index]['question']
details[str(index)]['pred'] = predictions[index]
details[str(index)]['parsed'] = parsed
details[str(index)]['refr'] = references[index]
details[str(index)]['correct'] = parsed == references[index]
metrics['details'] = details
return metrics
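After grouping by question id, `more_{j}` counts questions answered correctly under at least j of the rotated option orders, while `perf` requires all of them. A small illustration with hypothetical per-question counts for four rotations:

per_qid_correct = {'q0': 4, 'q1': 3, 'q2': 1}  # hypothetical: rotations answered correctly
total = len(per_qid_correct)
for j in range(1, 5):
    count = sum(v >= j for v in per_qid_correct.values())
    name = f'more_{j}_circular' if j != 4 else 'perf_circular'
    print(name, round(count / total * 100, 1))
# perf_circular -> 33.3: only q0 is correct under every rotation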
@LOAD_DATASET.register_module()
class CustomDataset(BaseDataset):
@staticmethod
def load(path):
if path.endswith('.jsonl'):
with open(path, 'r', encoding='utf-8-sig') as f:
data = [json.loads(line) for line in f]
elif path.endswith('.csv'):
with open(path, 'r', encoding='utf-8-sig') as f:
reader = csv.reader(f)
header = next(reader)
data = [dict(zip(header, row)) for row in reader]
@@ -33,6 +179,10 @@ class CustomDataset(BaseDataset):
return Dataset.from_list(data)
class CircularCustomDataset(CustomDataset, metaclass=CircularDatasetMeta):
dataset_class = CustomDataset
def stringfy_types(obj):
for k, v in obj.items():
if k == 'type':
@@ -69,12 +219,12 @@ def make_mcq_gen_config(meta):
inferencer=dict(type=GenInferencer),
)
eval_cfg = dict(
evaluator=dict(type=meta.get('evaluator', OptionSimAccEvaluator),
**meta.get('evaluator_kwargs',
{'options': meta['options']})),
pred_role='BOT',
)
dataset = dict(
abbr=meta['abbr'],
@@ -87,6 +237,54 @@ def make_mcq_gen_config(meta):
return dataset
def make_circular_mcq_gen_config(meta):
if meta.get('template', None) is None:
_human_prompt = 'Question: {question}' + ''.join(
[f'\n{item}. {{{item}}}' for item in meta['options']])
human_prompt = meta.get('human_prompt', _human_prompt)
_bot_prompt = f'Answer: {{{meta["output_column"]}}}'
bot_prompt = meta.get('bot_prompt', _bot_prompt)
template = dict(round=[
dict(role='HUMAN', prompt=human_prompt),
dict(role='BOT', prompt=bot_prompt),
])
else:
template = meta['template']
reader_cfg = dict(
input_columns=meta['input_columns'],
output_column=meta['output_column'],
)
infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=template,
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer),
)
eval_cfg = dict(
evaluator=dict(type=meta.get('evaluator',
CircularOptionSimAccEvaluator),
**meta.get('evaluator_kwargs',
{'options': meta['options']})),
pred_role='BOT',
)
dataset = dict(
abbr=meta['abbr'],
type=CircularCustomDataset,
option_keys=meta['options'],
answer_key=meta['output_column'],
path=meta['path'],
reader_cfg=reader_cfg,
infer_cfg=infer_cfg,
eval_cfg=eval_cfg,
)
return dataset
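These factory functions are driven by a meta dict whose keys mirror what parse_example_dataset produces. A hypothetical meta for a four-option MCQ file (all field values are made up for illustration):

meta = dict(
    abbr='my_mcq_demo',             # hypothetical dataset name
    path='data/my_mcq_demo.jsonl',  # hypothetical path
    input_columns=['question', 'A', 'B', 'C', 'D'],
    output_column='answer',
    options=['A', 'B', 'C', 'D'],
    data_type='circular-mcq',
    infer_method='gen',
)
dataset_cfg = make_circular_mcq_gen_config(meta)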
def make_qa_gen_config(meta):
if meta.get('template', None) is None:
human_prompt = meta.get('human_prompt', '{question}')
@@ -102,7 +300,6 @@ def make_qa_gen_config(meta):
])
else:
template = meta['template']
reader_cfg = dict(
input_columns=meta['input_columns'],
output_column=meta['output_column'],
@@ -117,7 +314,8 @@ def make_qa_gen_config(meta):
)
eval_cfg = dict(
evaluator=dict(type=meta.get('evaluator', AccEvaluator),
**meta.get('evaluator_kwargs', {})),
pred_role='BOT',
)
@@ -164,7 +362,8 @@ def make_mcq_ppl_config(meta):
inferencer=dict(type=PPLInferencer),
)
eval_cfg = dict(evaluator=dict(type=meta.get('evaluator', AccEvaluator),
**meta.get('evaluator_kwargs', {})))
dataset = dict(
abbr=meta['abbr'],
@@ -177,17 +376,61 @@ def make_mcq_ppl_config(meta):
return dataset
def make_circular_mcq_ppl_config(meta):
if meta.get('template', None) is None:
_human_prompt = 'Question: {question}' + ''.join(
[f'\n{item}. {{{item}}}' for item in meta['options']])
human_prompt = meta.get('human_prompt', _human_prompt)
_bot_prompt = f'Answer: {{{meta["output_column"]}}}'
bot_prompt = meta.get('bot_prompt', _bot_prompt)
template = {
answer: dict(round=[
dict(role='HUMAN', prompt=human_prompt),
dict(role='BOT',
prompt=bot_prompt.format(
**{meta['output_column']: answer})),
], )
for answer in meta['options']
}
else:
template = meta['template']
reader_cfg = dict(
input_columns=meta['input_columns'],
output_column=meta['output_column'],
)
infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=template,
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=PPLInferencer),
)
eval_cfg = dict(
evaluator=dict(type=meta.get('evaluator', CircularEvaluator),
**meta.get('evaluator_kwargs', {})))
dataset = dict(
abbr=meta['abbr'],
type=CircularCustomDataset,
option_keys=meta['options'],
answer_key=meta['output_column'],
path=meta['path'],
reader_cfg=reader_cfg,
infer_cfg=infer_cfg,
eval_cfg=eval_cfg,
)
return dataset
def parse_example_dataset(config):
# config -> .meta.jsonl -> parsed_results
path = config['path']
# load sample and get parsed_meta
parsed_meta = {}
if path.endswith('.jsonl'):
with open(path, 'r', encoding='utf-8') as f:
data_item = json.loads(f.readline())
@@ -200,11 +443,11 @@ def parse_example_dataset(config):
else:
raise ValueError(f'Unsupported ext: {path}, .jsonl or .csv required')
parsed_meta['path'] = path
input_columns = [i for i in data_item.keys() if i != 'answer']
parsed_meta['input_columns'] = input_columns
output_column = 'answer' if 'answer' in data_item else None
parsed_meta['output_column'] = output_column
options = []
for i in range(26):
i = chr(ord('A') + i)
@@ -212,19 +455,28 @@ def parse_example_dataset(config):
options.append(i)
else:
break
parsed_meta['options'] = options
abbr = os.path.basename(path).split('.')[0]
parsed_meta['abbr'] = abbr
parsed_meta['data_type'] = 'mcq' if len(options) > 1 else 'qa'
parsed_meta['infer_method'] = 'gen'
# try to read meta json
meta_path = config.get('meta_path', path + '.meta.json')
if os.path.exists(meta_path):
with open(meta_path, 'r', encoding='utf-8') as f:
read_from_file_meta = json.load(f)
else:
read_from_file_meta = {}
# get config meta
config_meta = copy.deepcopy(config)
# merge meta
meta = {}
meta.update(parsed_meta)
meta.update(read_from_file_meta)
meta.update(config_meta)
return meta
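The three meta sources are merged with increasing priority: values parsed from the data file are overridden by the optional .meta.json, which in turn is overridden by the run config. A tiny sketch of that precedence with hypothetical values:

parsed_meta = {'abbr': 'demo', 'infer_method': 'gen'}
read_from_file_meta = {'infer_method': 'ppl'}   # hypothetical .meta.json content
config_meta = {'abbr': 'demo_renamed'}          # hypothetical config override
meta = {}
meta.update(parsed_meta)
meta.update(read_from_file_meta)
meta.update(config_meta)
print(meta)  # {'abbr': 'demo_renamed', 'infer_method': 'ppl'}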
@@ -236,6 +488,8 @@ def make_custom_dataset_config(config):
('mcq', 'gen'): make_mcq_gen_config,
('mcq', 'ppl'): make_mcq_ppl_config,
('qa', 'gen'): make_qa_gen_config,
('circular-mcq', 'gen'): make_circular_mcq_gen_config,
('circular-mcq', 'ppl'): make_circular_mcq_ppl_config,
}.get((meta['data_type'], meta['infer_method']), None)
if make_config_func is None:
raise ValueError(f'Unsupported dataset data_type: {meta["data_type"]}'
......
@@ -365,7 +365,7 @@ class DS1000ServiceEvaluator(BaseEvaluator):
lib: str,
ip_address='localhost',
port=5000,
timeout=600) -> None:
assert lib in _LIBRARY_NAME_LIST, (
f' lib must be in {_LIBRARY_NAME_LIST}')
self.lib = lib
......
@@ -5,6 +5,7 @@ import os.path as osp
import re
import subprocess
import tempfile
import time
from shutil import copyfile
from typing import Dict, Iterable
@@ -73,7 +74,8 @@ class HumanevalXEvaluator(BaseEvaluator):
language,
ip_address='localhost',
port=5000,
retry=2,
timeout=600) -> None:
assert language in _LANGUAGE_NAME_DICT.keys(), (
f'language must be in {list(_LANGUAGE_NAME_DICT.keys())}')
if language == 'rust':
@@ -81,6 +83,7 @@ class HumanevalXEvaluator(BaseEvaluator):
self.language = language
self.ip_address = ip_address
self.port = port
self.retry = retry
self.timeout = timeout
super().__init__()
@@ -96,7 +99,17 @@ class HumanevalXEvaluator(BaseEvaluator):
for pred in predictions:
f.write(json.dumps(pred) + '\n')
num_retry = 0
while num_retry < self.retry:
succeed, output = self._code_eval_service(
file_path=tmp_out_path)
if not succeed and '(56) Recv failure' in output:
# only retry when connection failed
num_retry += 1
# wait a min in case the service load is too high
time.sleep(60)
else:
break
if succeed:
if isinstance(output, str):
@@ -104,7 +117,13 @@ class HumanevalXEvaluator(BaseEvaluator):
elif isinstance(output, dict):
return output
ref_url = 'https://opencompass.readthedocs.io/en/latest/advanced_guides/code_eval_service.html' # noqa
if hasattr(self, '_out_dir'):
result_file_path = re.sub('results', 'mid_results',
self._out_dir) + '.json' # noqa
if not osp.exists(osp.dirname(result_file_path)):
os.makedirs(osp.dirname(result_file_path))
else:
result_file_path = os.path.join(
'outputs', f'humanevalx_{self.language}.json')
copyfile(tmp_out_path, result_file_path)
......
import pandas as pd
from datasets import Dataset
from opencompass.registry import LOAD_DATASET
from .base import BaseDataset
@LOAD_DATASET.register_module()
class HungarianExamMathDataset(BaseDataset):
@staticmethod
def load(path):
df = pd.read_csv(path)
df.columns = ['question']
outputs = [{
'question': question
} for question in df['question'].tolist()]
dataset = Dataset.from_list(outputs)
return dataset
import json
from datasets import Dataset
from opencompass.registry import LOAD_DATASET
from .base import BaseDataset
@LOAD_DATASET.register_module()
class JsonlDataset(BaseDataset):
@staticmethod
def load(path):
data = []
with open(path, 'r', encoding='utf-8') as f:
for line in f:
data.append(json.loads(line))
return Dataset.from_list(data)
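JsonlDataset simply treats every line of the file as one JSON record. A hypothetical usage (the file name and fields below are made up for illustration):

# demo.jsonl contains one object per line, e.g. {"question": "1+1=?", "answer": "2"}
ds = JsonlDataset.load('demo.jsonl')
print(len(ds), ds[0]['answer'])  # -> number of lines, '2'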
import json
import re
from datasets import Dataset, DatasetDict
@@ -9,22 +10,7 @@ from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
from .base import BaseDataset
def last_boxed_only_string(string):
idx = string.rfind('\\boxed')
if idx < 0:
idx = string.rfind('\\fbox')
@@ -51,23 +37,34 @@ class MATHDataset(BaseDataset):
return retval
def remove_boxed(s):
left = '\\boxed{'
try:
assert s[:len(left)] == left
assert s[-1] == '}'
return s[len(left):-1]
except Exception:
return None
def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
boxed_str = last_boxed_only_string(pred_str)
if boxed_str is None:
return None
answer = remove_boxed(boxed_str)
if answer is None:
return None
if strip_double_curly_brace:
match = re.match('^\{(.*)\}$', answer) # noqa: W605
if match:
answer = match.group(1)
return answer
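extract_boxed_answer pulls the last \boxed{...} out of a solution string and strips the wrapper. An illustrative (unverified) example of the expected behaviour:

solution = r'So the probability is $\boxed{\frac{1}{2}}$.'
print(extract_boxed_answer(solution))  # expected: '\frac{1}{2}'
print(extract_boxed_answer(r'\boxed{{1, 2}}', strip_double_curly_brace=True))  # expected: '1, 2'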
def normalize_final_answer(final_answer: str) -> str:
"""Normalize a final answer to a quantitative reasoning question."""
# final_answer = final_answer.split('=')[-1]
SUBSTITUTIONS = [('an ', ''), ('a ', ''), ('.$', '$'), ('\\$', ''),
(r'\ ', ''), (' ', ''), ('mbox', 'text'),
(',\\text{and}', ','), ('\\text{and}', ','),
@@ -81,11 +78,6 @@ def math_postprocess(text: str) -> str:
'\\text{}', r'\mathrm{th}', r'^\circ', r'^{\circ}', r'\;', r',\!',
'{,}', '"', '\\dots', '\n', '\r', '\f'
]
for before, after in SUBSTITUTIONS:
final_answer = final_answer.replace(before, after)
for expr in REMOVED_EXPRESSIONS:
@@ -103,6 +95,9 @@ def math_postprocess(text: str) -> str:
if len(re.findall(r'finalansweris(.*)', final_answer)) > 0:
final_answer = re.findall(r'finalansweris(.*)', final_answer)[-1]
if len(re.findall(r'answer?is:?(.*)', final_answer)) > 0:
final_answer = re.findall(r'answer?is:?(.*)', final_answer)[-1]
if len(re.findall(r'oxed\{(.*?)\}', final_answer)) > 0:
final_answer = re.findall(r'oxed\{(.*?)\}', final_answer)[-1]
@@ -118,8 +113,7 @@ def math_postprocess(text: str) -> str:
# \fracabc -> \frac{a}{b}c
# \sqrta -> \sqrt{a}
# \sqrtab -> sqrt{a}b
final_answer = re.sub(r'(frac)([^{])(.)', 'frac{\\2}{\\3}', final_answer)
final_answer = re.sub(r'(sqrt)([^{])', 'sqrt{\\2}', final_answer)
final_answer = final_answer.replace('$', '')
@@ -129,6 +123,30 @@ def math_postprocess(text: str) -> str:
return final_answer
@LOAD_DATASET.register_module()
class MATHDataset(BaseDataset):
@staticmethod
def load(path: str):
dataset = DatasetDict()
data = json.load(open(path))
raw_data = []
for i in data.keys():
raw_data.append({
'problem':
data[i]['problem'],
'solution':
extract_boxed_answer(data[i]['solution'])
})
dataset['test'] = Dataset.from_list(raw_data)
dataset['train'] = Dataset.from_list(raw_data)
return dataset
@TEXT_POSTPROCESSORS.register_module('math_postprocess')
def math_postprocess(text: str) -> str:
for maybe_ans in text.split('.'):
if 'final answer' in maybe_ans.lower():
return normalize_final_answer(maybe_ans)
@@ -137,9 +155,27 @@ def math_postprocess(text: str) -> str:
# text.split('Final Answer: ', 1)[-1].split('\n\n')[0])
@TEXT_POSTPROCESSORS.register_module('math_postprocess_v2')
def math_postprocess_v2(text: str) -> str:
cand_ans = extract_boxed_answer(text, strip_double_curly_brace=True)
if cand_ans:
return cand_ans
for maybe_ans in text.split('.'):
# if 'final answer' in maybe_ans.lower():
if re.search('final answer|answer is', maybe_ans.lower()):
return normalize_final_answer(maybe_ans)
return normalize_final_answer(text.split('.')[0])
@ICL_EVALUATORS.register_module()
class MATHEvaluator(BaseEvaluator):
def __init__(self, version='v1'):
assert version in ['v1', 'v2']
self.version = version
def score(self, predictions, references):
if len(predictions) != len(references):
return {
@@ -166,7 +202,7 @@ class MATHEvaluator(BaseEvaluator):
substrs = substrs[1:]
for substr in substrs:
new_str += '\\frac'
if len(substr) > 0 and substr[0] == '{':
new_str += substr
else:
try:
@@ -228,6 +264,10 @@ class MATHEvaluator(BaseEvaluator):
new_string += new_substr
return new_string
def _fix_sqrt_v2(self, string):
_string = re.sub(r'\\sqrt(\w+)', r'\\sqrt{\1}', string)
return _string
def _strip_string(self, string):
# linebreaks
string = string.replace('\n', '')
@@ -295,6 +335,109 @@ class MATHEvaluator(BaseEvaluator):
return string
def _strip_string_v2(self, string):
string = str(string).strip()
# linebreaks
string = string.replace('\n', '')
# right "."
string = string.rstrip('.')
# remove inverse spaces
string = string.replace('\\!', '')
string = string.replace('\\ ', '')
# replace \\ with \
string = string.replace('\\\\', '\\')
string = string.replace('\\\\', '\\')
# replace tfrac and dfrac with frac
string = string.replace('tfrac', 'frac')
string = string.replace('dfrac', 'frac')
# remove \left and \right
string = string.replace('\\left', '')
string = string.replace('\\right', '')
# Remove unit: miles, dollars if after is not none
_string = re.sub(r'\\text{.*?}$', '', string).strip()
if _string != '' and _string != string:
string = _string
# Remove circ (degrees)
string = string.replace('^{\\circ}', '')
string = string.replace('^\\circ', '')
# remove dollar signs
string = string.replace('\\$', '')
string = string.replace('$', '')
string = string.replace('\\text', '')
string = string.replace('x\\in', '')
# remove percentage
string = string.replace('\\%', '')
string = string.replace('\%', '') # noqa: W605
string = string.replace('%', '')
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively,
# add "0" if "." is the start of the string
string = string.replace(' .', ' 0.')
string = string.replace('{.', '{0.')
# cdot
string = string.replace('\\cdot', '')
# inf
string = string.replace('infinity', '\\infty')
if '\\infty' not in string:
string = string.replace('inf', '\\infty')
string = string.replace('+\\inity', '\\infty')
# and
string = string.replace('and', '')
string = string.replace('\\mathbf', '')
# use regex to remove \mbox{...}
string = re.sub(r'\\mbox{.*?}', '', string)
# quote
string.replace("'", '')
string.replace('"', '')
# i, j
if 'j' in string and 'i' not in string:
string = string.replace('j', 'i')
# replace a.000b where b is not number or b is end, with ab, use regex
string = re.sub(r'(\d+)\.0+([^\d])', r'\1\2', string)
string = re.sub(r'(\d+)\.0+$', r'\1', string)
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == '.':
string = '0' + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split('=')) == 2:
if len(string.split('=')[0]) <= 2:
string = string.split('=')[1]
string = self._fix_sqrt_v2(string)
string = string.replace(' ', '')
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc.
# Even works with \frac1{72} (but not \frac{72}1).
# Also does a/b --> \\frac{a}{b}
string = self._fix_fracs(string)
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple
# cases fix in case the model output is X/Y
string = self._fix_a_slash_b(string)
return string
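A few illustrative normalizations that _strip_string_v2 is expected to perform, shown as a quick sketch (assuming the module is importable as opencompass.datasets.math; the outputs are my reading of the code above, not verified):

from opencompass.datasets.math import MATHEvaluator

ev = MATHEvaluator(version='v2')
print(ev._strip_string_v2('\\dfrac{1}{2}'))  # expected: '\frac{1}{2}' (dfrac/tfrac unified)
print(ev._strip_string_v2('50\\%'))          # expected: '50' (percent signs dropped)
print(ev._strip_string_v2('x = 3'))          # expected: '3' (short 'k =' prefixes removed)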
def is_equiv(self, str1, str2, verbose=False):
if str1 is None and str2 is None:
print('WARNING: Both None')
@@ -302,16 +445,24 @@ class MATHEvaluator(BaseEvaluator):
if str1 is None or str2 is None:
return False
if self.version == 'v1':
strip_string_func = self._strip_string
elif self.version == 'v2':
strip_string_func = self._strip_string_v2
else:
raise NotImplementedError
try:
ss1 = strip_string_func(str1)
ss2 = strip_string_func(str2)
if verbose:
print(ss1, ss2)
return ss1 == ss2
except Exception:
return str1 == str2
@ICL_EVALUATORS.register_module()
class MATHAgentEvaluator(MATHEvaluator):
"""math agent evaluator for soft condition.
@@ -320,8 +471,9 @@ class MATHAgentEvaluator(MATHEvaluator):
Defaults to `PythonInterpreter`.
"""
def __init__(self, action: str = 'PythonInterpreter', version='v1'):
self.action = action
super().__init__(version=version)
def soft_equal(self, pred, refer, step):
try:
......
from opencompass.openicl import BaseEvaluator
def check(a, b):
return abs(float(a) - float(b)) < 1e-3
class Math401Evaluator(BaseEvaluator):
def score(self, predictions, references):
if len(predictions) != len(references):
return {
'error': 'predictions and references have different '
'length'
}
correct = 0
count = 0
details = []
for i, j in zip(predictions, references):
detail = {'pred': i, 'answer': j, 'correct': False}
count += 1
try:
if check(i, j):
correct += 1
detail['correct'] = True
except Exception:
pass
details.append(detail)
result = {'accuracy': 100 * correct / count, 'details': details}
return result
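check() treats two answers as equal when they differ by less than 1e-3 after float conversion, so formatting differences do not matter. A quick illustration:

print(check('3.14159', '3.1416'))  # True: difference is about 1e-5
print(check('1e3', '1000'))        # True: both parse to 1000.0
print(check('2', '2.01'))          # False: difference is 0.01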
import csv
import json
import os.path as osp
from datasets import Dataset, DatasetDict
@@ -18,7 +19,7 @@ class NaturalQuestionDataset(BaseDataset):
dataset = DatasetDict()
for split in ['dev', 'test']:
filename = osp.join(path, f'nq-{split}.qa.csv')
with open(filename, 'r', encoding='utf-8') as f:
reader = csv.reader(f, delimiter='\t')
raw_data = []
for row in reader:
@@ -33,6 +34,26 @@ class NaturalQuestionDataset(BaseDataset):
return dataset
@LOAD_DATASET.register_module()
class NQOpenDataset(BaseDataset):
@staticmethod
def load(path: str):
dataset = DatasetDict()
for split in ['validation', 'train']:
filename = osp.join(path, f'nq-open-{split}.jsonl')
raw_data = []
with open(filename, 'r', encoding='utf-8') as f:
for doc in f:
doc = json.loads(doc)
if split == 'train':
doc['answer'] = doc['answer'][0]
raw_data.append(doc)
dataset[split] = Dataset.from_list(raw_data)
return dataset
@ICL_EVALUATORS.register_module()
class NQEvaluator(BaseEvaluator):
......
@@ -16,13 +16,13 @@ class ReasonBenchDataset(BaseDataset):
with open(path, 'r', encoding='utf-8') as f:
for line in f:
line = json.loads(line)
prompt = line.get('prompt', '')
prompt_ppl = line.get('prompt_ppl', '')
label = line.get('label', '')
label_ppl = line.get('label_ppl', '')
choices = line.get('choices', '')
tag = line.get('tag', '')
source = line.get('source', '')
option_content = {choice: line[choice] for choice in choices}
data = {
'prompt': prompt,
......
import csv
import json
import os.path as osp
from datasets import Dataset, DatasetDict
@@ -18,7 +19,7 @@ class TriviaQADataset(BaseDataset):
dataset = DatasetDict()
for split in ['dev', 'test']:
filename = osp.join(path, f'trivia-{split}.qa.csv')
with open(filename, 'r', encoding='utf-8') as f:
reader = csv.reader(f, delimiter='\t')
raw_data = []
for row in reader:
@@ -32,20 +33,49 @@ class TriviaQADataset(BaseDataset):
return dataset
@LOAD_DATASET.register_module()
class TriviaQADataset_V2(BaseDataset):
@staticmethod
def load(path: str):
dataset = DatasetDict()
for split in ['validation', 'train']:
filename = osp.join(path, f'triviaqa-{split}.jsonl')
raw_data = []
with open(filename, 'r', encoding='utf-8') as f:
for doc in f:
doc = json.loads(doc)
raw_data.append(doc)
dataset[split] = Dataset.from_list(raw_data)
return dataset
@LOAD_DATASET.register_module()
class TriviaQADataset_V3(BaseDataset):
@staticmethod
def load(path: str):
data_list = []
with open(path, 'r', encoding='utf-8') as f:
for doc in f:
data_list.append(json.loads(doc))
return Dataset.from_list(data_list)
@ICL_EVALUATORS.register_module()
class TriviaQAEvaluator(BaseEvaluator):
def score(self, predictions, references):
if len(predictions) != len(references):
return {'error': 'preds and refrs have different length'}
processed_predictions = []
for prediction in predictions:
prediction = prediction.strip().split('\n')[0].lower()
prediction = prediction.split('answer is')[-1]
prediction = prediction.split('a:')[-1]
prediction = prediction.split('answer:')[-1]
prediction = prediction.strip()
prediction = general_postprocess(prediction)
processed_predictions.append(prediction)
processed_answers = [[general_postprocess(j).lower() for j in i]
......
@@ -16,11 +16,14 @@ from jupyter_client import KernelManager
from lagent.actions.base_action import BaseAction
from lagent.schema import ActionReturn, ActionStatusCode
WORK_DIR = os.getenv('CODE_INTERPRETER_WORK_DIR',
f"{os.path.abspath('./output_images')}")
DEFAULT_DESCRIPTION = """启动Jupter Kernel用于执行Python代码。"""
START_CODE = """
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
def input(*args, **kwargs):
raise NotImplementedError('Python input() function is disabled.')
@@ -74,6 +77,10 @@ class IPythonInterpreter(BaseAction):
if user_data_dir:
# user_data_dir = os.path.dirname(user_data_dir)
# in case change of dirs
assert os.path.exists(user_data_dir), \
f'{user_data_dir} does not exist.'
user_data_dir = os.path.abspath(user_data_dir)
user_data_dir = f"import os\nos.chdir('{user_data_dir}')"
self.user_data_dir = user_data_dir
self._initialized = False
......
@@ -24,5 +24,6 @@ from .qwen_api import Qwen # noqa: F401
from .sensetime_api import SenseTime # noqa: F401
from .turbomind import TurboMindModel # noqa: F401
from .turbomind_tis import TurboMindTisModel # noqa: F401
from .vllm import VLLM # noqa: F401
from .xunfei_api import XunFei # noqa: F401
from .zhipuai_api import ZhiPuAI # noqa: F401
@@ -2,6 +2,9 @@ from abc import abstractmethod
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, Union
import torch
from mmengine import dist
from opencompass.utils.prompt import PromptList
PromptType = Union[PromptList, str]
@@ -21,6 +24,9 @@ class BaseModel:
wrapping of any meta instructions.
generation_kwargs (Dict, optional): The generation kwargs for the
model. Defaults to dict().
sync_rank (bool): Whether to sync inputs between ranks. Do not use this
if you are not familiar with this behavior. Check `sync_inputs`
function for more details. Defaults to False.
"""
is_api: bool = False
@@ -30,7 +36,8 @@ class BaseModel:
max_seq_len: int = 2048,
tokenizer_only: bool = False,
meta_template: Optional[Dict] = None,
generation_kwargs: Optional[Dict] = dict(),
sync_rank: bool = False):
self.path = path
self.max_seq_len = max_seq_len
self.tokenizer_only = tokenizer_only
@@ -40,6 +47,7 @@ class BaseModel:
if meta_template and 'eos_token_id' in meta_template:
self.eos_token_id = meta_template['eos_token_id']
self.generation_kwargs = generation_kwargs
self.sync_rank = sync_rank
@abstractmethod
def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
@@ -77,6 +85,34 @@ class BaseModel:
' ppl-based evaluation yet, try gen-based '
'instead.')
@abstractmethod
def encode(self, prompt: str) -> torch.Tensor:
"""Encode prompt to tokens. Not necessary for most cases.
Args:
prompt (str): Input string.
Returns:
torch.Tensor: Encoded tokens.
"""
raise NotImplementedError(
f'{self.__class__.__name__} does not implement'
'`encode` method.')
@abstractmethod
def decode(self, tokens: torch.Tensor) -> str:
"""Decode tokens to text. Not necessary for most cases.
Args:
tokens (torch.Tensor): Input tokens.
Returns:
str: Decoded text.
"""
raise NotImplementedError(
f'{self.__class__.__name__} does not implement'
'`decode` method.')
@abstractmethod
def get_token_len(self, prompt: str) -> int:
"""Get lengths of the tokenized strings.
@@ -115,20 +151,6 @@ class BaseModel:
inputs = self.parse_template(templates, mode='ppl')
return self.get_ppl(inputs, mask_length)
def generate_from_template(self, templates: List[PromptType],
max_out_len: int, **kwargs):
"""Generate completion from a list of templates.
@@ -138,6 +160,8 @@ class BaseModel:
max_out_len (int): The maximum length of the output.
"""
inputs = self.parse_template(templates, mode='gen')
if hasattr(self, 'sync_rank') and self.sync_rank:
inputs = self.sync_inputs(inputs)
return self.generate(inputs, max_out_len=max_out_len, **kwargs)
def get_token_len_from_template(
@@ -165,6 +189,39 @@ class BaseModel:
token_lens = [self.get_token_len(prompt) for prompt in prompts]
return token_lens[0] if not is_batched else token_lens
def sync_inputs(self, inputs: str) -> str:
"""For some case, when it involves multiprocessing with multiple gpus,
there might be the chance that inputs are different among different
gpus. Therefore, we need to sync inputs for rank0.
Args:
inputs (str): Inputs for each rank.
"""
rank = dist.get_rank()
if rank == 0:
tokens = self.encode(inputs)
length = self.get_token_len(inputs)
if length > 2048:
from opencompass.utils import get_logger
get_logger().info(f'Large tokens nums: {length}')
size = torch.tensor([tokens.shape], dtype=torch.long)
else:
tokens = None
size = torch.empty(2, dtype=torch.long)
# broadcast data size
dist.broadcast(size, src=0)
if rank != 0:
tokens = torch.empty(size.tolist(), dtype=torch.long)
# broadcast tokens
dist.broadcast(tokens, src=0)
# the final input might be different from original input
# due to the max sequence limitation
return self.decode(tokens)

    def to(self, device):
        self.model.to(device)
...
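The `sync_rank` / `sync_inputs` addition above relies on a two-step broadcast: rank 0 first shares the shape of its token tensor, then the payload itself, so every other rank can pre-allocate a matching buffer. Below is a minimal, self-contained sketch of that pattern; the function name and the single-process gloo group are illustrative assumptions, not OpenCompass APIs.

import os

import torch
import torch.distributed as dist


def broadcast_tokens(tokens=None):
    """Broadcast a variable-length 1-D token tensor from rank 0 to all ranks."""
    rank = dist.get_rank()
    if rank == 0:
        size = torch.tensor(tokens.shape, dtype=torch.long)
    else:
        size = torch.empty(1, dtype=torch.long)
    # step 1: every rank learns the payload shape
    dist.broadcast(size, src=0)
    if rank != 0:
        tokens = torch.empty(size.tolist(), dtype=torch.long)
    # step 2: broadcast the payload into the pre-allocated buffer
    dist.broadcast(tokens, src=0)
    return tokens


if __name__ == '__main__':
    # single-process group only to make the sketch runnable as a script
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29501')
    dist.init_process_group('gloo', rank=0, world_size=1)
    print(broadcast_tokens(torch.arange(5)))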
...@@ -251,7 +251,8 @@ class HuggingFace(BaseModel):
                                       **generation_kwargs)
                 for input_ in inputs), [])

    def _batch_generate(self,
                        inputs: List[str],
                        max_out_len: int,
                        stopping_criteria: List[str] = [],
                        **kwargs) -> List[str]:
...@@ -295,7 +296,9 @@ class HuggingFace(BaseModel):
        if stopping_criteria:
            # Construct huggingface stopping criteria
            if self.tokenizer.eos_token is not None:
                stopping_criteria = stopping_criteria + [
                    self.tokenizer.eos_token
                ]
            stopping_criteria = transformers.StoppingCriteriaList([
                *[
                    MultiTokenEOSCriteria(sequence, self.tokenizer,
...@@ -372,11 +375,12 @@ class HuggingFace(BaseModel):
                                   max_length=self.max_seq_len -
                                   max_out_len)['input_ids']
        input_ids = torch.tensor(input_ids, device=self.model.device)

        if stopping_criteria:
            # Construct huggingface stopping criteria
            if self.tokenizer.eos_token is not None:
                stopping_criteria = stopping_criteria + [
                    self.tokenizer.eos_token
                ]
            stopping_criteria = transformers.StoppingCriteriaList([
                *[
                    MultiTokenEOSCriteria(sequence, self.tokenizer,
...@@ -523,11 +527,12 @@ class HuggingFace(BaseModel):
        """
        assert mask_length is None, 'Not support mask_length yet.'
        if self.batch_padding and len(inputs) > 1:
            assert self.tokenizer.pad_token
            return self._get_loglikelihood(inputs, conts)
        else:
            return np.concatenate([
                self._get_loglikelihood(inputs=[inputs[idx]],
                                        conts=[conts[idx]])
                for idx in range(len(inputs))
            ])
...@@ -540,32 +545,76 @@ class HuggingFace(BaseModel):
        Returns:
            float: loglikelihood scores.
        """
        input_tokenizer_out = self.tokenizer(inputs,
                                             padding=True,
                                             truncation=False,
                                             return_length=True,
                                             return_tensors='pt').to(
                                                 self.model.device)

        input_ids = input_tokenizer_out['input_ids'][:, :self.max_seq_len]
        input_length = input_tokenizer_out['length']
        attention_mask = input_tokenizer_out['attention_mask']
        context_ids = [
            self.tokenizer(inputs[i].replace(conts[i], ''),
                           padding=False,
                           truncation=True,
                           max_length=self.max_seq_len)['input_ids']
            for i in range(len(inputs))
        ]
        # forward
        outputs = self.model(input_ids, attention_mask)['logits']
        outputs = torch.nn.functional.log_softmax(outputs, dim=-1)
        # calculate loglikelihood
        answer = np.zeros(len(inputs))
        for i in range(len(inputs)):
            if self.tokenizer.padding_side == 'right':
                cont_ids = input_ids[i, len(context_ids[i]):input_length[i]]
                logits = outputs[i,
                                 len(context_ids[i]) - 1:input_length[i] -
                                 1, :]  # noqa
            else:
                cont_ids = input_ids[i, len(context_ids[i]) - input_length[i]:]
                logits = outputs[i,
                                 len(context_ids[i]) - input_length[i] - 1:-1]
            # Reducing the dimension will lead to a wrong outcome
            logits_gather = torch.gather(
                logits.unsqueeze(0), 2,
                cont_ids.unsqueeze(0).unsqueeze(-1))  # [1, seq]
            # Answer: sum the likelihood of each token in continuation
            answer[i] = float(logits_gather.detach().cpu().sum())
        return answer

def get_mink_percent(self, inputs: List[str], k: int = 20) -> List[float]:
"""https://swj0419.github.io/detect-pretrain.github.io/"""
if self.batch_padding and len(inputs) > 1:
assert self.tokenizer.pad_token
return self._get_mink_percent(inputs, k=k)
else:
return np.concatenate([
self._get_mink_percent(inputs=[text], k=k) for text in inputs
])
def _get_mink_percent(self, inputs: List[str], k: int = 20) -> List[float]:
outputs, inputs = self.get_logits(inputs)
shift_logits = outputs[:, :-1, :].contiguous().float()
shift_labels = inputs['tokens']['input_ids'][:, 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss(
reduction='none', ignore_index=self.tokenizer.pad_token_id)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)).view(shift_labels.size())
lens = (inputs['tokens']['input_ids'] !=
self.tokenizer.pad_token_id).sum(-1).cpu().numpy()
mink_percent = []
for nloss, nlen in zip(loss, lens):
nlen = max(int(nlen) * k // 100, 1)
            # keep only this sample's nlen largest per-token losses
            nloss = torch.topk(nloss, nlen, dim=-1)[0]
nloss = -nloss.mean().cpu().detach().numpy()
mink_percent.append(nloss)
return np.array(mink_percent)

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized strings.
...@@ -710,17 +759,15 @@ class HuggingFaceChatGLM3(HuggingFace):
                responses.append('')
                continue

            response, history = self.model.chat(self.tokenizer,
                                                user_content,
                                                history=history,
                                                **generation_kwargs)
            # response will be dict sometime
            if isinstance(response, dict):
                response = response.get('content', '')
            responses.append(response)
        return responses

    def get_token_len(self, prompt: str) -> int:
...
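The new `get_mink_percent` helper implements the Min-K% Prob membership signal (https://swj0419.github.io/detect-pretrain.github.io/): average the log-probabilities of the k% least likely tokens of a sample, i.e. the tokens with the largest per-token loss. Below is a minimal sketch of that scoring step, assuming the per-token cross-entropy losses of one sample are already available; the names are illustrative, not OpenCompass APIs.

import torch


def min_k_percent_score(per_token_loss, k=20):
    """Mean log-prob of the k% least likely tokens (largest losses)."""
    n = max(per_token_loss.numel() * k // 100, 1)
    worst = torch.topk(per_token_loss, n)[0]  # k% largest losses
    return float(-worst.mean())               # back to log-prob scale


# toy usage: 10 tokens with random losses
print(min_k_percent_score(torch.rand(10), k=20))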
...@@ -100,6 +100,42 @@ class Llama2(BaseModel):
        ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
        return ce_loss

def get_loglikelihood(
self,
inputs: List[str],
conts: List[str],
mask_length: Optional[List[int]] = None) -> List[float]:
assert mask_length is None, 'mask_length is not supported'
bsz = len(inputs)
params = self.model.params
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
# tokenize
input_tokens = [self.tokenizer.encode(x, True, False) for x in inputs]
max_prompt_size = max([len(t) for t in input_tokens])
total_len = min(params.max_seq_len, max_prompt_size)
tokens = torch.zeros((bsz, total_len)).cuda().long()
num_token_list = []
cont_tokens = []
for k, t in enumerate(input_tokens):
num_token = min(total_len, len(t))
num_token_list.append(num_token - 1)
tokens[k, :num_token] = torch.tensor(t[-num_token:]).long()
context_ids = self.tokenizer.encode(
inputs[k].replace(conts[k], ''), True, False)
cont_tokens.append(tokens[k, len(context_ids):num_token])
# forward
outputs = self.model.forward(tokens, 0)[:, :-1, :]
outputs = torch.nn.functional.log_softmax(outputs, dim=-1)
loglikelihood_sum = torch.zeros(bsz).cuda()
for idx in range(bsz):
logits = outputs[
idx, num_token_list[idx] -
len(cont_tokens[idx]):num_token_list[idx], :].unsqueeze(0)
loglikelihood_sum[idx] = torch.gather(
logits, 2, cont_tokens[idx].unsqueeze(0).unsqueeze(-1)).sum()
loglikelihood_sum = loglikelihood_sum.cpu().detach().numpy()
return loglikelihood_sum

    def get_token_len(self, prompt: str) -> int:
        return len(self.tokenizer.encode(prompt, True, True))
...@@ -115,6 +151,7 @@ class Llama2Chat(BaseModel):
        tokenizer_only (bool): whether to load tokenizer only
        tokenizer_path (str): path to the tokenizer directory
        meta_template (dict): meta template for the model
        force_bf16 (bool): whether to force set model to `bfloat16`
    """

    def __init__(
...@@ -125,6 +162,7 @@ class Llama2Chat(BaseModel):
        tokenizer_only: bool = False,
        tokenizer_path: Optional[str] = None,
        meta_template: Optional[Dict] = None,
        force_bf16: bool = False,
    ):  # noqa
        if tokenizer_only:
            self._load_tokenizer(tokenizer_path=tokenizer_path)
...@@ -132,7 +170,8 @@ class Llama2Chat(BaseModel):
            self._load_model(path=path,
                             max_seq_len=max_seq_len,
                             max_batch_size=max_batch_size,
                             tokenizer_path=tokenizer_path,
                             force_bf16=force_bf16)
        self.max_seq_len = max_seq_len
        self.template_parser = APITemplateParser(meta_template)
        self.logger = get_logger()
...
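Both `HuggingFace._get_loglikelihood` and the new `Llama2.get_loglikelihood` score a continuation the same way: take a log-softmax over the vocabulary, slice out the positions whose outputs predict the continuation tokens, and gather and sum the log-probabilities of those tokens. The following is a toy, self-contained sketch of that gather with random logits and made-up token ids, not tied to either model class.

import torch

torch.manual_seed(0)
vocab, seq_len = 50, 8
logits = torch.randn(1, seq_len, vocab)   # fake model outputs
log_probs = torch.log_softmax(logits, dim=-1)

cont_ids = torch.tensor([[7, 3, 11]])     # pretend the last 3 tokens are the continuation
# positions seq_len-4 .. seq_len-2 predict tokens seq_len-3 .. seq_len-1
pred_slice = log_probs[:, seq_len - 4:seq_len - 1, :]
gathered = torch.gather(pred_slice, 2, cont_ids.unsqueeze(-1))  # [1, 3, 1]
print(float(gathered.sum()))              # summed continuation log-likelihood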
from typing import Dict, List, Optional
from opencompass.models.base import BaseModel
from opencompass.utils import get_logger
try:
from vllm import LLM, SamplingParams
except ImportError:
LLM, SamplingParams = None, None
DEFAULT_MODEL_KWARGS = dict(trust_remote_code=True)
class VLLM(BaseModel):
"""Model Wrapper for VLLM."""
def __init__(
self,
path: str,
max_seq_len: int = 2048,
model_kwargs: dict = None,
generation_kwargs: dict = dict(),
meta_template: Optional[Dict] = None,
mode: str = 'none',
use_fastchat_template: bool = False,
end_str: Optional[str] = None,
):
super().__init__(path=path,
max_seq_len=max_seq_len,
meta_template=meta_template)
assert LLM, ('Please install VLLM with `pip install vllm`. '
'note: torch==2.1.2 is required.')
self.logger = get_logger()
self._load_model(path, model_kwargs)
self.tokenizer = self.model.get_tokenizer()
self.generation_kwargs = generation_kwargs
self.generation_kwargs.pop('do_sample', None)
assert mode in ['none', 'mid']
self.mode = mode
self.use_fastchat_template = use_fastchat_template
self.end_str = end_str
def _load_model(self,
path: str,
add_model_kwargs: dict = None,
num_retry: int = 3):
model_kwargs = DEFAULT_MODEL_KWARGS.copy()
if add_model_kwargs is not None:
model_kwargs.update(add_model_kwargs)
self.model = LLM(path, **model_kwargs)
def generate(self, inputs: List[str], max_out_len: int,
**kwargs) -> List[str]:
"""Generate results given a list of inputs.
Args:
inputs (List[str]): A list of strings.
max_out_len (int): The maximum length of the output.
Returns:
List[str]: A list of generated strings.
"""
if self.mode == 'mid':
input_ids = self.tokenizer(inputs, truncation=False)['input_ids']
inputs = []
for input_id in input_ids:
if len(input_id) > self.max_seq_len - max_out_len:
half = int((self.max_seq_len - max_out_len) / 2)
inputs.append(
self.tokenizer.decode(input_id[:half],
skip_special_tokens=True) +
self.tokenizer.decode(input_id[-half:],
skip_special_tokens=True))
else:
inputs.append(
self.tokenizer.decode(input_id,
skip_special_tokens=True))
generation_kwargs = kwargs.copy()
generation_kwargs.update(self.generation_kwargs)
generation_kwargs.update({'max_tokens': max_out_len})
sampling_kwargs = SamplingParams(**generation_kwargs)
outputs = self.model.generate(inputs, sampling_kwargs)
prompt_list, output_strs = [], []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
if self.end_str:
generated_text = generated_text.split(self.end_str)[0]
prompt_list.append(prompt)
output_strs.append(generated_text)
return output_strs
def prompts_preproccess(self, inputs: List[str]):
if self.use_fastchat_template:
try:
from fastchat.model import get_conversation_template
except ModuleNotFoundError:
raise ModuleNotFoundError(
                    'fastchat is not installed. You can install it with '
                    "'pip install \"fschat[model_worker,webui]\"'.")
conv = get_conversation_template('vicuna')
conv.append_message(conv.roles[0], inputs[0])
conv.append_message(conv.roles[1], None)
inputs = [conv.get_prompt()]
return inputs
def get_token_len(self, prompt: str) -> int:
"""Get lengths of the tokenized strings.
Args:
prompt (str): Input string.
Returns:
int: Length of the input tokens
"""
return len(self.model.get_tokenizer().encode(prompt))
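
For reference, a hypothetical OpenCompass model-config entry for this VLLM wrapper might look like the sketch below. Only the constructor arguments (`path`, `max_seq_len`, `model_kwargs`, `generation_kwargs`, `mode`, `end_str`) come from the class above; `abbr`, `max_out_len`, `batch_size` and `run_cfg` are the usual framework-level fields, and the concrete values are placeholders, not part of this commit.

from opencompass.models import VLLM

models = [
    dict(
        type=VLLM,
        abbr='llama-2-7b-vllm',           # placeholder name
        path='meta-llama/Llama-2-7b-hf',  # placeholder model path
        max_seq_len=2048,
        model_kwargs=dict(tensor_parallel_size=1),
        generation_kwargs=dict(temperature=0),
        max_out_len=100,
        batch_size=32,
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]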