Unverified commit b4afe3e7, authored by Fengzhe Zhou, committed by GitHub

[Sync] Add InternLM2 Keyset Evaluation Demo (#807)


Co-authored-by: zhangyifan1 <zhangyifan1@pjlab.org.cn>
parent acae5609
 summarizer = dict(
     dataset_abbrs = [
         '--------- LongBench Single-Document QA ---------',  # category
-        "LongBench_narrativeqa",
+        'LongBench_narrativeqa',
         'LongBench_qasper',
         'LongBench_multifieldqa_en',
-        "LongBench_multifieldqa_zh",
+        'LongBench_multifieldqa_zh',
         '--------- LongBench Multi-Document QA ---------',  # category
         'LongBench_hotpotqa',
         'LongBench_2wikimqa',
@@ -28,5 +28,5 @@ summarizer = dict(
         'LongBench_lcc',
         'LongBench_repobench-p',
     ],
-    summary_groups=sum([v for k, v in locals().items() if k.endswith("_summary_groups")], []),
+    summary_groups=sum([v for k, v in locals().items() if k.endswith('_summary_groups')], []),
 )
@@ -13,7 +13,9 @@ from .commonsenseqa import commonsenseqaDataset
 from .hellaswag import hellaswagDataset_V2
 from .mmlu import MMLUDataset
 from .obqa import OBQADataset
+from .piqa import piqaDataset_V2
 from .race import RaceDataset
+from .siqa import siqaDataset_V3
 from .xiezhi import XiezhiDataset
@@ -273,6 +275,24 @@ class CircularXiezhiDataset(XiezhiDataset, metaclass=CircularDatasetMeta):
     default_answer_key = 'answer'

+
+class CircularsiqaDataset(siqaDataset_V3, metaclass=CircularDatasetMeta):
+    dataset_class = siqaDataset_V3
+    default_circular_splits = ['validation']
+    default_option_keys = ['A', 'B', 'C']
+    default_answer_key = 'answer'
+
+
+class CircularpiqaDataset(piqaDataset_V2, metaclass=CircularDatasetMeta):
+    dataset_class = piqaDataset_V2
+    default_circular_splits = ['validation']
+    default_option_keys = ['sol1', 'sol2']
+
+    def default_answer_key_switch_method(item, circular_pattern):
+        circular_pattern = tuple(int(i[-1]) - 1 for i in circular_pattern)
+        item['answer'] = 'AB'[circular_pattern['AB'.index(item['answer'])]]
+        return item
+
+
 class CircularEvaluator(BaseEvaluator):
     """This Evaluator assesses datasets post-Circular processing, generating
     the following evaluation metrics:
...
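As context for the new `default_answer_key_switch_method`: when circular evaluation permutes the two piqa options, the answer letter has to be remapped to follow them. A minimal sketch of that remapping; the item fields and the `('sol2', 'sol1')` pattern are illustrative assumptions, not taken from this diff:

    # illustrative piqa-style item: two options plus a letter answer
    item = {'sol1': 'use a towel', 'sol2': 'use sandpaper', 'answer': 'A'}
    circular_pattern = ('sol2', 'sol1')  # assumed: options swapped by CircularDatasetMeta
    order = tuple(int(key[-1]) - 1 for key in circular_pattern)   # -> (1, 0)
    item['answer'] = 'AB'[order['AB'.index(item['answer'])]]      # 'A' -> 'B'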
@@ -378,6 +378,8 @@ class DS1000ServiceEvaluator(BaseEvaluator):
         processed_predictions = {}
         assert len(predictions) == len(references)
         for i, (pred, gold) in enumerate(zip(predictions, references)):
+            if len(pred) > 10000:
+                pred = ''
             processed_predictions[str(i)] = {'prediction': pred, 'gold': gold}
         with tempfile.TemporaryDirectory() as tmp_dir:
...
@@ -155,6 +155,11 @@ def humaneval_postprocess(text: str) -> str:
 def humaneval_postprocess_v2(text: str) -> str:
     """This is an advanced version of previous postprocess to handle more
     situations, better to use this one."""
+    try:
+        # for chatGLM raw text
+        text = eval(text)
+    except Exception:
+        pass
     text = text.lstrip('\n')
     if '```' in text:
         blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
@@ -173,11 +178,11 @@ def humaneval_postprocess_v2(text: str) -> str:
     text = text.lstrip('\n')
     if text.strip().startswith('def'):
         text = '\n'.join(text.split('\n')[1:])
-    if not text.startswith('    '):
+    # deal with the indentation error
     if text.startswith(' '):
         text = '    ' + text.lstrip()
     else:
         text = '\n'.join(['    ' + line for line in text.split('\n')])
     text = text.split('\n')
     # If number of leading space reduces, we assume that the code block ends.
...
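The `eval(text)` guard added above targets completions that arrive as a repr-style quoted string (the code comment attributes this to chatGLM). A minimal sketch of the behaviour it relies on, with an assumed example string:

    raw = "'def add(a, b):\\n    return a + b'"  # assumed: completion wrapped in quotes
    try:
        raw = eval(raw)   # unquotes to the actual code string with a real newline
    except Exception:
        pass              # ordinary code is not a valid expression, so it is left untouched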
@@ -14,6 +14,7 @@ from datasets import Dataset
 from opencompass.openicl.icl_evaluator import BaseEvaluator

 from .base import BaseDataset
+from .humaneval import humaneval_postprocess_v2

 _LANGUAGE_NAME_DICT = {
     'cpp': 'CPP',
@@ -89,9 +90,11 @@ class HumanevalXEvaluator(BaseEvaluator):
     def score(self, predictions, references):
         predictions = [{
-            'task_id': f'{_LANGUAGE_NAME_DICT[self.language]}/{i}',
-            'generation': _clean_up_code(pred, self.language),
-        } for i, pred in enumerate(predictions)]
+            'task_id':
+            f'{_LANGUAGE_NAME_DICT[self.language]}/{i}',
+            'generation':
+            _clean_up_code(pred, self.language, refer),
+        } for i, (pred, refer) in enumerate(zip(predictions, references))]
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_out_path = osp.join(tmp_dir,
                                     f'humanevalx_{self.language}.json')
@@ -161,15 +164,28 @@ class HumanevalXEvaluator(BaseEvaluator):
         return False, err


-def _clean_up_code(text: str, language_type: str) -> str:
+def _clean_up_code(text: str, language_type: str, reference) -> str:
     """Cleans up the generated code."""
+    try:
+        # for chatGLM related text
+        text = eval(text)
+    except Exception:
+        pass
+    # extract code from code block
+    text = text.lstrip('\n')
+    if '```' in text:
+        blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
+        if len(blocks) == 0:
+            text = text.split('```')[1]  # fall back to default strategy
+        else:
+            text = blocks[0]  # fetch the first code block
+            if not text.startswith('\n'):  # in case starting with ```xxx
+                text = text[max(text.find('\n') + 1, 0):]
     if language_type.lower() == 'python':
+        text = humaneval_postprocess_v2(text)
         # we need to take care of the first line
         # append extra space for first line for correct indentation
-        for c_index, c in enumerate(text[:5]):
-            if c != ' ':
-                text = ' ' * (4 - c_index) + text
-                break
+        text = '    ' + text.lstrip()

         text_splits = text.split('\n')
         is_empty_line = False
@@ -189,7 +205,13 @@ def _clean_up_code(text: str, language_type: str) -> str:
         for w in end_words:
             if w in text:
                 text = text[:text.rfind(w)]
-    elif language_type.lower() == 'java':
+    # strip function head for all other language
+    func_name = reference.strip().split('\n')[-1]
+    if func_name:
+        func_name = func_name.strip().strip('{')
+        if func_name in text:
+            text = '\n'.join(text[text.find(func_name):].split('\n')[1:])
+    if language_type.lower() == 'java':
         main_pos = text.find('public static void main')
         if main_pos != -1:
             text = text[:main_pos] + '}'
...
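The new `reference` argument is used above to strip a repeated function head from the generation: the last line of the reference prompt is taken as the signature, and if the model echoes it, everything up to and including that line is dropped. A small sketch with assumed strings:

    reference = 'int add(int a, int b) {'                     # assumed: prompt ends with the signature
    generation = 'int add(int a, int b) {\n    return a + b;\n}'
    func_name = reference.strip().split('\n')[-1].strip().strip('{')
    if func_name and func_name in generation:
        # keep only what follows the echoed signature line
        generation = '\n'.join(generation[generation.find(func_name):].split('\n')[1:])
    # generation is now '    return a + b;\n}'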
@@ -200,30 +200,28 @@ class MBPPEvaluator(BaseEvaluator):
     def score(self, predictions, references):
         assert len(predictions) == len(references)
-        predictions = [self._process_answer(pred) for pred in predictions]
         if self.metric == 'MBPP':
             result = {'pass': 0, 'timeout': 0, 'failed': 0, 'wrong_answer': 0}
             details = {}
-            for index, (test_case,
-                        pred) in enumerate(zip(references, predictions)):
-                programs = self._process_test(test_case, pred)
-                try:
-                    # Add exec globals to prevent the exec to raise
-                    # unnecessary NameError for correct answer
-                    exec_globals = {}
-                    with swallow_io():
-                        with time_limit(2):
-                            exec(programs, exec_globals)
-                    r = 'pass'
-                except TimeOutException:
-                    r = 'timeout'
-                except AssertionError:
-                    r = 'wrong_answer'
-                except BaseException:
-                    r = 'failed'
-                result[r] += 1
-                details[str(index)] = {'programs': programs, 'result': r}
+            # change to thread pool for better killing blocked instance
+            with ThreadPoolExecutor() as executor:
+                futures = []
+                for i, (refer, pred) in enumerate(zip(references,
+                                                      predictions)):
+                    pred = self._process_answer(pred)
+                    programs = self._process_test(refer, pred)
+                    future = executor.submit(execution, programs, i, 3)
+                    futures.append(future)
+
+                from tqdm import tqdm
+                for future in tqdm(as_completed(futures), total=len(futures)):
+                    index, key = future.result()
+                    result[key] += 1
+                    details[str(index)] = {
+                        'programs': predictions[index],
+                        'result': key
+                    }
             result['score'] = result['pass'] / len(predictions) * 100
             result['details'] = details
@@ -263,6 +261,20 @@ class MBPPEvaluator(BaseEvaluator):
         return {f'mbpp_plus_{k}': score[k] * 100 for k in score}

     def _process_answer(self, text):
+        try:
+            # for chatGLM related text
+            text = eval(text)
+        except Exception:
+            pass
+        # deal with code block
+        if '```' in text:
+            blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
+            if len(blocks) == 0:
+                text = text.split('```')[1]  # fall back to default strategy
+            else:
+                text = blocks[0]  # fetch the first code block
+                if not text.startswith('\n'):  # in case starting with ```xxx
+                    text = text[max(text.find('\n') + 1, 0):]
         text = text.strip()
         match = re.search(r"('\s*|)(\[DONE\]|DONE)", text)
         if match:
@@ -275,6 +287,10 @@ class MBPPEvaluator(BaseEvaluator):
             text = text[1:]
         if text.endswith("'"):
             text = text[:-1]
+        text = text.replace('\\', '')
+        match = re.search(r'```python(.*)```', text, re.DOTALL)
+        if match:
+            text = match.group(1).strip().split('```')[0].strip()
         return text

     def _process_test(self, test_case, pred):
...
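The `execution` helper submitted to the pool is defined elsewhere in the file and is not shown in this hunk. The sketch below is only a hypothetical illustration of the contract the new loop assumes (return the sample index plus a status key matching the `result` dict), reusing the exception taxonomy and the `swallow_io`/`time_limit`/`TimeOutException` utilities already imported by the module from the removed inline code:

    def execution(programs, task_id, timeout):
        # hypothetical sketch, not the actual helper
        try:
            exec_globals = {}
            with swallow_io():
                with time_limit(timeout):
                    exec(programs, exec_globals)
            return task_id, 'pass'
        except TimeOutException:
            return task_id, 'timeout'
        except AssertionError:
            return task_id, 'wrong_answer'
        except BaseException:
            return task_id, 'failed'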
@@ -78,3 +78,37 @@ class siqaDataset_V2(BaseDataset):
         val_dataset = siqaDataset_V2.load_single(path, 'dev.jsonl',
                                                  'dev-labels.lst')
         return DatasetDict({'train': train_dataset, 'validation': val_dataset})
+
+
+@LOAD_DATASET.register_module()
+class siqaDataset_V3(BaseDataset):
+    """Disconnect from HuggingFace version of HFDataset."""
+
+    @staticmethod
+    def load_single(path, data_filename, label_filename):
+        data_path = os.path.join(path, data_filename)
+        label_path = os.path.join(path, label_filename)
+        dataset = []
+        with open(data_path, 'r', encoding='utf-8') as f:
+            data_lines = f.readlines()
+        with open(label_path, 'r', encoding='utf-8') as f:
+            label_lines = f.readlines()
+        assert len(data_lines) == len(label_lines)
+        for data, label in zip(data_lines, label_lines):
+            i = json.loads(data.strip())
+            i['A'] = i.pop('answerA')
+            i['B'] = i.pop('answerB')
+            i['C'] = i.pop('answerC')
+            i['answer'] = 'ABC'[int(label.strip()) - 1]
+            dataset.append(i)
+
+        return Dataset.from_list(dataset)
+
+    @staticmethod
+    def load(path):
+        train_dataset = siqaDataset_V3.load_single(path, 'train.jsonl',
+                                                   'train-labels.lst')
+        val_dataset = siqaDataset_V3.load_single(path, 'dev.jsonl',
+                                                 'dev-labels.lst')
+        return DatasetDict({'train': train_dataset, 'validation': val_dataset})
...
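For orientation, `siqaDataset_V3.load_single` pairs a socialIQA-style jsonl file with a one-label-per-line file; the row below is an illustrative assumption (only the `answerA/B/C` field names are taken from the code), showing how one record is transformed:

    import json

    data_line = ('{"context": "Tom helped Sam move.", "question": "Why did Tom do this?", '
                 '"answerA": "to be kind", "answerB": "to be paid", "answerC": "to show off"}')
    label_line = '1\n'   # 1-based label from the matching *-labels.lst line

    item = json.loads(data_line.strip())
    item['A'] = item.pop('answerA')
    item['B'] = item.pop('answerB')
    item['C'] = item.pop('answerC')
    item['answer'] = 'ABC'[int(label_line.strip()) - 1]   # -> 'A', consumed by CircularsiqaDataset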
@@ -57,6 +57,8 @@ class IPythonInterpreter(BaseAction):
         user_data_dir (str): Specified the user data directory for files
             loading. If set to `ENV`, use `USER_DATA_DIR` environment variable.
             Defaults to `ENV`.
+        force_user_data (bool): Whether to force use user data.
+            Defaults to True.
     """

     _KERNEL_CLIENTS = {}
@@ -68,7 +70,8 @@ class IPythonInterpreter(BaseAction):
                  disable_description: Optional[str] = None,
                  timeout: int = 20,
                  trim_output: Optional[int] = 1024,
-                 user_data_dir: str = 'ENV') -> None:
+                 user_data_dir: str = 'ENV',
+                 force_user_data: bool = True) -> None:
         super().__init__(description, name, enable, disable_description)

         self.timeout = timeout
@@ -82,6 +85,11 @@ class IPythonInterpreter(BaseAction):
                 f'{user_data_dir} does not exist.'
             user_data_dir = os.path.abspath(user_data_dir)
             user_data_dir = f"import os\nos.chdir('{user_data_dir}')"
+        else:
+            if force_user_data:
+                raise ValueError('user_data_dir is not set. Please '
+                                 'set force_user_data to False if '
+                                 'no extra data needed.')
         self.user_data_dir = user_data_dir
         self._initialized = False
         self.trim_output = trim_output
...
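The intended effect of the new `force_user_data` flag, sketched under the assumption that an empty `user_data_dir` falls into the guarded `else` branch:

    # with the new default, an unset user_data_dir is rejected at construction time
    IPythonInterpreter(user_data_dir='')                         # raises ValueError
    # opting out is explicit when no extra data is needed
    IPythonInterpreter(user_data_dir='', force_user_data=False)  # accepted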
@@ -225,6 +225,7 @@ class HuggingFace(BaseModel):
     def generate(self,
                  inputs: List[str],
                  max_out_len: int,
+                 min_out_len: Optional[int] = None,
                  stopping_criteria: List[str] = [],
                  **kwargs) -> List[str]:
         """Generate results given a list of inputs.
@@ -232,6 +233,7 @@ class HuggingFace(BaseModel):
         Args:
             inputs (List[str]): A list of strings.
             max_out_len (int): The maximum length of the output.
+            min_out_len (Optional[int]): The minimum length of the output.

         Returns:
             List[str]: A list of generated strings.
@@ -241,12 +243,14 @@ class HuggingFace(BaseModel):
         if self.batch_padding and len(inputs) > 1:
             return self._batch_generate(inputs=inputs,
                                         max_out_len=max_out_len,
+                                        min_out_len=min_out_len,
                                         stopping_criteria=stopping_criteria,
                                         **generation_kwargs)
         else:
             return sum(
                 (self._single_generate(inputs=[input_],
                                        max_out_len=max_out_len,
+                                       min_out_len=min_out_len,
                                        stopping_criteria=stopping_criteria,
                                        **generation_kwargs)
                  for input_ in inputs), [])
@@ -254,6 +258,7 @@ class HuggingFace(BaseModel):
     def _batch_generate(self,
                         inputs: List[str],
                         max_out_len: int,
+                        min_out_len: Optional[int] = None,
                         stopping_criteria: List[str] = [],
                         **kwargs) -> List[str]:
         """Support for batch prompts inference.
@@ -308,6 +313,9 @@ class HuggingFace(BaseModel):
             ])
             kwargs['stopping_criteria'] = stopping_criteria

+        if min_out_len is not None:
+            kwargs['min_new_tokens'] = min_out_len
+
         # step-2: conduct model forward to generate output
         outputs = self.model.generate(**tokens,
                                       max_new_tokens=max_out_len,
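`min_new_tokens` is the standard `transformers` generation argument that keeps the model from stopping before the requested number of new tokens (available in recent transformers releases). A standalone sketch, with an arbitrary small model as a placeholder:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained('gpt2')            # placeholder model
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    inputs = tok('The capital of France is', return_tensors='pt')
    out = model.generate(**inputs, max_new_tokens=32, min_new_tokens=8)
    print(tok.decode(out[0], skip_special_tokens=True))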
@@ -331,6 +339,7 @@ class HuggingFace(BaseModel):
     def _single_generate(self,
                          inputs: List[str],
                          max_out_len: int,
+                         min_out_len: Optional[int] = None,
                          stopping_criteria: List[str] = [],
                          **kwargs) -> List[str]:
         """Support for single prompt inference.
@@ -390,6 +399,9 @@ class HuggingFace(BaseModel):
             ])
             kwargs['stopping_criteria'] = stopping_criteria

+        if min_out_len is not None:
+            kwargs['min_new_tokens'] = min_out_len
+
         # To accommodate the PeftModel, parameters should be passed in
         # key-value format for generate.
         outputs = self.model.generate(input_ids=input_ids,
@@ -502,7 +514,7 @@ class HuggingFace(BaseModel):
                 self.tokenizer.pad_token_id).sum(-1).cpu().numpy()
         if mask_length is not None:
             lens -= np.array(mask_length)
-        ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
+        ce_loss = loss.float().sum(-1).cpu().detach().numpy() / lens
         return ce_loss

     def get_loglikelihood(
@@ -554,7 +566,6 @@ class HuggingFace(BaseModel):
         input_ids = input_tokenizer_out['input_ids'][:, :self.max_seq_len]
         input_length = input_tokenizer_out['length']
-        attention_mask = input_tokenizer_out['attention_mask']
         context_ids = [
             self.tokenizer(inputs[i].replace(conts[i], ''),
                            padding=False,
@@ -563,7 +574,7 @@ class HuggingFace(BaseModel):
             for i in range(len(inputs))
         ]
         # forward
-        outputs = self.model(input_ids, attention_mask)['logits']
+        outputs = self.model(input_ids)['logits']
         outputs = torch.nn.functional.log_softmax(outputs, dim=-1)
         # calculate loglikelihood
         answer = np.zeros(len(inputs))
@@ -609,9 +620,10 @@ class HuggingFace(BaseModel):
                 self.tokenizer.pad_token_id).sum(-1).cpu().numpy()
         mink_percent = []
         for nloss, nlen in zip(loss, lens):
-            nlen = max(int(nlen) * k // 100, 1)
-            nloss = torch.topk(loss, nlen, dim=-1)[0]
-            nloss = -nloss.mean().cpu().detach().numpy()
+            nlen = int(nlen)
+            minklen = max(nlen * k // 100, 1)
+            nloss = torch.topk(loss[-nlen:], minklen, dim=-1)[0]
+            nloss = -nloss.float().mean().cpu().detach().numpy()
             mink_percent.append(nloss)
         return np.array(mink_percent)
...
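The last hunk corrects the Min-K% computation: top-k is now taken over the sample's own non-padded losses (`loss[-nlen:]`) instead of the whole tensor, with the percentage applied to the valid length and the mean cast to float first. A toy illustration with assumed numbers:

    import torch

    loss = torch.tensor([0.1, 2.3, 0.5, 1.7, 0.2, 3.1])  # per-token NLL for one sample
    nlen, k = 6, 50                                        # 6 valid tokens, keep the worst 50%
    minklen = max(nlen * k // 100, 1)                      # -> 3
    worst = torch.topk(loss[-nlen:], minklen, dim=-1)[0]   # three highest losses: 3.1, 2.3, 1.7
    score = -worst.float().mean().item()                   # Min-K% score for this sample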
@@ -29,6 +29,8 @@ class GenInferencer(BaseInferencer):
         model (:obj:`BaseModelWrapper`, optional): The module to inference.
         max_seq_len (:obj:`int`, optional): Maximum number of tokenized words
             allowed by the LM.
+        min_out_len (:obj:`int`, optional): Minimum number of generated tokens
+            by the LM.
         batch_size (:obj:`int`, optional): Batch size for the
             :obj:`DataLoader`.
         output_json_filepath (:obj:`str`, optional): File path for output
@@ -49,6 +51,7 @@ class GenInferencer(BaseInferencer):
                  max_out_len: int,
                  stopping_criteria: List[str] = [],
                  max_seq_len: Optional[int] = None,
+                 min_out_len: Optional[int] = None,
                  batch_size: Optional[int] = 1,
                  gen_field_replace_token: Optional[str] = '',
                  output_json_filepath: Optional[str] = './icl_inference_output',
@@ -66,6 +69,7 @@ class GenInferencer(BaseInferencer):
         self.gen_field_replace_token = gen_field_replace_token
         self.max_out_len = max_out_len
+        self.min_out_len = min_out_len
         self.stopping_criteria = stopping_criteria

         if self.model.is_api and save_every is None:
@@ -135,6 +139,8 @@ class GenInferencer(BaseInferencer):
             sig = inspect.signature(self.model.generate)
             if 'stopping_criteria' in sig.parameters:
                 extra_gen_kwargs['stopping_criteria'] = self.stopping_criteria
+            if 'min_out_len' in sig.parameters:
+                extra_gen_kwargs['min_out_len'] = self.min_out_len
             with torch.no_grad():
                 parsed_entries = self.model.parse_template(entry, mode='gen')
                 results = self.model.generate_from_template(
...
@@ -116,7 +116,7 @@ class DLCRunner(BaseRunner):
                     ' --worker_count 1'
                     f' --worker_cpu {max(num_gpus * 6, 8)}'
                     f' --worker_gpu {num_gpus}'
-                    f' --worker_memory {max(num_gpus * 32, 48)}'
+                    f' --worker_memory {max(num_gpus * 64, 48)}'
                     f" --worker_image {self.aliyun_cfg['worker_image']}"
                     ' --interactive')
         get_cmd = partial(task.get_command,
...
@@ -61,6 +61,7 @@ class OpenICLInferTask(BaseTask):
         for model_cfg, dataset_cfgs in zip(self.model_cfgs, self.dataset_cfgs):
             self.max_out_len = model_cfg.get('max_out_len', None)
             self.batch_size = model_cfg.get('batch_size', None)
+            self.min_out_len = model_cfg.get('min_out_len', None)
             self.model = build_model_from_cfg(model_cfg)

             for dataset_cfg in dataset_cfgs:
@@ -102,6 +103,8 @@ class OpenICLInferTask(BaseTask):
         inferencer_cfg['model'] = self.model
         self._set_default_value(inferencer_cfg, 'max_out_len',
                                 self.max_out_len)
+        self._set_default_value(inferencer_cfg, 'min_out_len',
+                                self.min_out_len)
         self._set_default_value(inferencer_cfg, 'batch_size', self.batch_size)
         inferencer_cfg['max_seq_len'] = self.model_cfg.get('max_seq_len')
         inferencer = ICL_INFERENCERS.build(inferencer_cfg)
...
@@ -21,4 +21,5 @@ def build_model_from_cfg(model_cfg: ConfigDict):
     model_cfg.pop('abbr', None)
     model_cfg.pop('summarizer_abbr', None)
     model_cfg.pop('pred_postprocessor', None)
+    model_cfg.pop('min_out_len', None)
     return MODELS.build(model_cfg)
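Put together, the `min_out_len` plumbing lets a model config request a minimum generation length: `OpenICLInferTask` reads it from the model config, `GenInferencer` forwards it when the model's `generate` accepts it, `HuggingFace.generate` turns it into `min_new_tokens`, and `build_model_from_cfg` pops it so `MODELS.build` never sees it. A minimal config sketch; the abbr/path values are placeholders, not part of this commit:

    from opencompass.models import HuggingFace

    models = [
        dict(
            type=HuggingFace,
            abbr='internlm2-7b-hf',        # placeholder
            path='internlm/internlm2-7b',  # placeholder
            max_seq_len=2048,
            max_out_len=100,
            min_out_len=1,                 # new: becomes min_new_tokens at generation time
            batch_size=8,
            run_cfg=dict(num_gpus=1),
        )
    ]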
@@ -5,7 +5,8 @@ from typing import Dict
 from mmengine.config import Config, ConfigDict

 from opencompass.openicl.icl_inferencer import (CLPInferencer, GenInferencer,
-                                                PPLInferencer)
+                                                PPLInferencer,
+                                                PPLOnlyInferencer)
 from opencompass.registry import ICL_PROMPT_TEMPLATES, ICL_RETRIEVERS
 from opencompass.utils import (Menu, build_dataset_from_cfg,
                                build_model_from_cfg, dataset_abbr_from_cfg,
@@ -77,7 +78,8 @@ def print_prompts(model_cfg, dataset_cfg, count=1):
     ice_idx_list = retriever.retrieve()

-    assert infer_cfg.inferencer.type in [PPLInferencer, GenInferencer], \
+    assert infer_cfg.inferencer.type in [
+        PPLInferencer, GenInferencer, CLPInferencer, PPLOnlyInferencer], \
         'Only PPLInferencer and GenInferencer are supported'

     for idx in range(min(count, len(ice_idx_list))):
@@ -127,7 +129,9 @@ def print_prompts(model_cfg, dataset_cfg, count=1):
             print('-' * 100)
             print(prompt)
             print('-' * 100)
-        elif infer_cfg.inferencer.type in [GenInferencer, CLPInferencer]:
+        elif infer_cfg.inferencer.type in [
+                GenInferencer, CLPInferencer, PPLOnlyInferencer
+        ]:
             ice_idx = ice_idx_list[idx]
             ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
             prompt = retriever.generate_prompt_for_generate_task(
...