Unverified Commit f480b727 authored by Tong Gao's avatar Tong Gao Committed by GitHub
Browse files

[Feature] Support model-bound prediction postprocessor, use it in Claude (#268)



* [Feature] Support model-bound text postprocessor, add claude as an example

* update

* update

* minor fix

---------
Co-authored-by: default avatarzhoufengzhe <zhoufengzhe@pjlab.org.cn>
parent 6df124d4
from mmengine.config import read_base from mmengine.config import read_base
from opencompass.models.claude_api import Claude
from opencompass.partitioners import NaivePartitioner from opencompass.partitioners import NaivePartitioner
from opencompass.runners import LocalRunner from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask from opencompass.tasks import OpenICLInferTask
...@@ -9,15 +8,7 @@ with read_base(): ...@@ -9,15 +8,7 @@ with read_base():
from .datasets.collections.chat_medium import datasets from .datasets.collections.chat_medium import datasets
# and output the results in a choosen format # and output the results in a choosen format
from .summarizers.medium import summarizer from .summarizers.medium import summarizer
from .models.claude import models
models = [
dict(abbr='Claude2',
type=Claude,
path='claude-2',
key='YOUR_CLAUDE_KEY',
query_per_second=1,
max_out_len=2048, max_seq_len=2048, batch_size=2),
]
infer = dict( infer = dict(
partitioner=dict(type=NaivePartitioner), partitioner=dict(type=NaivePartitioner),
......
from opencompass.models.claude_api.claude_api import Claude
from opencompass.utils.text_postprocessors import last_option_postprocess
from opencompass.models.claude_api.postprocessors import gsm8k_postprocess, humaneval_postprocess, lcsts_postprocess, mbpp_postprocess, strategyqa_pred_postprocess
agieval_single_choice_sets = [
'gaokao-chinese',
'gaokao-english',
'gaokao-geography',
'gaokao-history',
'gaokao-biology',
'gaokao-chemistry',
'gaokao-mathqa',
'logiqa-zh',
'lsat-ar',
'lsat-lr',
'lsat-rc',
'logiqa-en',
'sat-math',
'sat-en',
'sat-en-without-passage',
'aqua-rat',
]
agieval_multiple_choices_sets = [
'gaokao-physics',
'jec-qa-kd',
'jec-qa-ca',
]
claude_postprocessors = {
'ceval-*': dict(type=last_option_postprocess, options='ABCD'),
'bustm-*': dict(type=last_option_postprocess, options='AB'),
'hellaswag': dict(type=last_option_postprocess, options='ABCD'),
'lukaemon_mmlu_*': dict(type=last_option_postprocess, options='ABCD'),
'openbookqa*': dict(type=last_option_postprocess, options='ABCD'),
'piqa': dict(type=last_option_postprocess, options='AB'),
'race-*': dict(type=last_option_postprocess, options='ABCD'),
'summedits': dict(type=last_option_postprocess, options='AB'),
'BoolQ': dict(type=last_option_postprocess, options='AB'),
'CB': dict(type=last_option_postprocess, options='ABC'),
'MultiRC': dict(type=last_option_postprocess, options='AB'),
'RTE': dict(type=last_option_postprocess, options='AB'),
'WiC': dict(type=last_option_postprocess, options='AB'),
'WSC': dict(type=last_option_postprocess, options='AB'),
'winogrande': dict(type=last_option_postprocess, options='AB'),
'gsm8k': dict(type=gsm8k_postprocess),
'openai_humaneval': dict(type=humaneval_postprocess),
'lcsts': dict(type=lcsts_postprocess),
'mbpp': dict(type=mbpp_postprocess),
'strategyqa': dict(type=strategyqa_pred_postprocess),
}
for _name in agieval_multiple_choices_sets + agieval_single_choice_sets:
claude_postprocessors[f'agieval-{_name}'] = dict(type=last_option_postprocess, options='ABCDE')
models = [
dict(abbr='Claude',
type=Claude,
path='claude-1',
key='YOUR_CLAUDE_KEY',
query_per_second=1,
max_out_len=2048, max_seq_len=2048, batch_size=2,
pred_postprocessor=claude_postprocessors,
),
]
from opencompass.models.claude_api.claude_api import Claude
from opencompass.utils.text_postprocessors import last_option_postprocess
from opencompass.models.claude_api.postprocessors import gsm8k_postprocess, humaneval_postprocess, lcsts_postprocess, mbpp_postprocess, strategyqa_pred_postprocess
agieval_single_choice_sets = [
'gaokao-chinese',
'gaokao-english',
'gaokao-geography',
'gaokao-history',
'gaokao-biology',
'gaokao-chemistry',
'gaokao-mathqa',
'logiqa-zh',
'lsat-ar',
'lsat-lr',
'lsat-rc',
'logiqa-en',
'sat-math',
'sat-en',
'sat-en-without-passage',
'aqua-rat',
]
agieval_multiple_choices_sets = [
'gaokao-physics',
'jec-qa-kd',
'jec-qa-ca',
]
claude_postprocessors = {
'ceval-*': dict(type=last_option_postprocess, options='ABCD'),
'bustm-*': dict(type=last_option_postprocess, options='AB'),
'hellaswag': dict(type=last_option_postprocess, options='ABCD'),
'lukaemon_mmlu_*': dict(type=last_option_postprocess, options='ABCD'),
'openbookqa*': dict(type=last_option_postprocess, options='ABCD'),
'piqa': dict(type=last_option_postprocess, options='AB'),
'race-*': dict(type=last_option_postprocess, options='ABCD'),
'summedits': dict(type=last_option_postprocess, options='AB'),
'BoolQ': dict(type=last_option_postprocess, options='AB'),
'CB': dict(type=last_option_postprocess, options='ABC'),
'MultiRC': dict(type=last_option_postprocess, options='AB'),
'RTE': dict(type=last_option_postprocess, options='AB'),
'WiC': dict(type=last_option_postprocess, options='AB'),
'WSC': dict(type=last_option_postprocess, options='AB'),
'winogrande': dict(type=last_option_postprocess, options='AB'),
'gsm8k': dict(type=gsm8k_postprocess),
'openai_humaneval': dict(type=humaneval_postprocess),
'lcsts': dict(type=lcsts_postprocess),
'mbpp': dict(type=mbpp_postprocess),
'strategyqa': dict(type=strategyqa_pred_postprocess),
}
for _name in agieval_multiple_choices_sets + agieval_single_choice_sets:
claude_postprocessors[f'agieval-{_name}'] = dict(type=last_option_postprocess, options='ABCDE')
models = [
dict(abbr='Claude2',
type=Claude,
path='claude-2',
key='YOUR_CLAUDE_KEY',
query_per_second=1,
max_out_len=2048, max_seq_len=2048, batch_size=2,
pred_postprocessor=claude_postprocessors,
),
]
from .base import BaseModel, LMTemplateParser # noqa from .base import BaseModel, LMTemplateParser # noqa
from .base_api import APITemplateParser, BaseAPIModel # noqa from .base_api import APITemplateParser, BaseAPIModel # noqa
from .claude_api import Claude # noqa: F401
from .glm import GLM130B # noqa: F401, F403 from .glm import GLM130B # noqa: F401, F403
from .huggingface import HuggingFace # noqa: F401, F403 from .huggingface import HuggingFace # noqa: F401, F403
from .huggingface import HuggingFaceCausalLM # noqa: F401, F403 from .huggingface import HuggingFaceCausalLM # noqa: F401, F403
......
from .claude_api import Claude
__all__ = ['Claude']
...@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union ...@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union
from opencompass.registry import MODELS from opencompass.registry import MODELS
from opencompass.utils import PromptList from opencompass.utils import PromptList
from .base_api import BaseAPIModel from ..base_api import BaseAPIModel
PromptType = Union[PromptList, str] PromptType = Union[PromptList, str]
......
import re
def gsm8k_postprocess(text: str) -> str:
text = text.split(' ')[::-1]
flag = False
ret = ''
for i in range(len(text)):
s = text[i]
for i in range(len(s)):
if s[i].isdigit():
flag = True
ret = s
break
if flag:
break
ret1 = ''
for i in range(len(ret)):
if ret[i].isdigit():
ret1 += ret[i]
return ret1
def humaneval_postprocess(text: str) -> str:
text = '\n'.join(text.split('\n')[1:]).strip()
if '```' in text:
blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
if len(blocks) == 0:
text = text.split('```')[1] # fall back to default strategy
else:
text = blocks[0] # fetch the first code block
if not text.startswith('\n'): # in case starting with ```python
text = text[max(text.find('\n') + 1, 0):]
if text.strip().startswith('from') or text.strip().startswith('import'):
def_idx = text.find('def')
if def_idx != -1:
text = text[max(text.find('\n', def_idx) + 1, 0):]
if text.strip().startswith('def'):
text = '\n'.join(text.split('\n')[1:])
if not text.startswith(' '):
if text.startswith(' '):
text = ' ' + text.lstrip()
else:
text = '\n'.join([' ' + line for line in text.split('\n')])
return text
def lcsts_postprocess(text: str) -> str:
text = text.strip()
text = text.replace('1. ', '') if text.startswith('1. ') else text
text = text.replace('- ', '') if text.startswith('- ') else text
text = text.strip('“,。!”')
return text
def mbpp_postprocess(text: str) -> str:
if text.startswith('Here'):
text = '\n'.join(text.split('\n')[1:]).strip()
if '```' in text:
blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
if len(blocks) == 0:
text = text.split('```')[1] # fall back to default strategy
else:
text = blocks[0] # fetch the first code block
if not text.startswith('\n'): # in case starting with ```python
text = text[max(text.find('\n') + 1, 0):]
return text
def strategyqa_pred_postprocess(text: str) -> str:
if text.startswith('Here'):
text = '\n'.join(text.split('\n')[1:]).strip()
text = text.split('answer is ')[-1]
match = re.search(r'(yes|no)', text.lower())
if match:
return match.group(1)
return ''
import argparse import argparse
import fnmatch
import os.path as osp import os.path as osp
import time import time
from collections import Counter from collections import Counter
...@@ -11,8 +12,9 @@ from mmengine.utils import mkdir_or_exist ...@@ -11,8 +12,9 @@ from mmengine.utils import mkdir_or_exist
from opencompass.registry import (ICL_EVALUATORS, MODELS, TASKS, from opencompass.registry import (ICL_EVALUATORS, MODELS, TASKS,
TEXT_POSTPROCESSORS) TEXT_POSTPROCESSORS)
from opencompass.tasks.base import BaseTask from opencompass.tasks.base import BaseTask
from opencompass.utils import (build_dataset_from_cfg, get_infer_output_path, from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
get_logger, task_abbr_from_cfg) get_infer_output_path, get_logger,
task_abbr_from_cfg)
@TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run @TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run
...@@ -47,6 +49,17 @@ class OpenICLEvalTask(BaseTask): ...@@ -47,6 +49,17 @@ class OpenICLEvalTask(BaseTask):
self.eval_cfg = self.dataset_cfg.get('eval_cfg') self.eval_cfg = self.dataset_cfg.get('eval_cfg')
self.output_column = dataset_cfg['reader_cfg']['output_column'] self.output_column = dataset_cfg['reader_cfg']['output_column']
# overwrite postprocessor if the model has specified one
ds_abbr = dataset_abbr_from_cfg(self.dataset_cfg)
model_postprocessors = self.model_cfg.get(
'pred_postprocessor', {})
for pattern in model_postprocessors.keys():
if fnmatch.fnmatch(ds_abbr, pattern):
self.eval_cfg[
'pred_postprocessor'] = model_postprocessors[
pattern] # noqa
break
out_path = get_infer_output_path( out_path = get_infer_output_path(
self.model_cfg, self.dataset_cfg, self.model_cfg, self.dataset_cfg,
osp.join(self.work_dir, 'results')) osp.join(self.work_dir, 'results'))
......
...@@ -19,4 +19,5 @@ def build_model_from_cfg(model_cfg: ConfigDict) -> ConfigDict: ...@@ -19,4 +19,5 @@ def build_model_from_cfg(model_cfg: ConfigDict) -> ConfigDict:
model_cfg.pop('max_out_len', None) model_cfg.pop('max_out_len', None)
model_cfg.pop('batch_size', None) model_cfg.pop('batch_size', None)
model_cfg.pop('abbr', None) model_cfg.pop('abbr', None)
model_cfg.pop('pred_postprocessor', None)
return MODELS.build(model_cfg) return MODELS.build(model_cfg)
...@@ -79,3 +79,10 @@ def first_capital_postprocess_multi(text: str) -> str: ...@@ -79,3 +79,10 @@ def first_capital_postprocess_multi(text: str) -> str:
if match: if match:
return match.group(1) return match.group(1)
return '' return ''
def last_option_postprocess(text: str, options: str) -> str:
match = re.findall(rf'([{options}])', text)
if match:
return match[-1]
return ''
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment