Unverified Commit 86d5ec3d authored by Leymore, committed by GitHub

Update configs (#9)

* Update implements

* Update
parent 2d0b184b
@@ -17,11 +17,10 @@ class TheoremQADataset(BaseDataset):

 @TEXT_POSTPROCESSORS.register_module('TheoremQA')
 def TheoremQA_postprocess(text: str) -> str:
-    text = text.strip().split('\n')[0].strip()
-    matches = re.findall(r'answer is (.*)', text)
+    text = text.strip()
+    matches = re.findall(r'answer is ([^\s]+)', text)
     if len(matches) == 0:
         return text
     else:
-        text = matches[0].strip()[:-1]
+        text = matches[0].strip().strip('.,?!\"\';:')
         return text
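A quick standalone sketch of the new extraction (example strings are mine, not from the repo): the old code kept only the first line and unconditionally chopped the final character, which mangled answers with no trailing punctuation; the new regex takes the first whitespace-free token after "answer is" and strips trailing punctuation explicitly.

    import re

    def TheoremQA_postprocess(text: str) -> str:
        text = text.strip()
        matches = re.findall(r'answer is ([^\s]+)', text)
        if len(matches) == 0:
            return text
        return matches[0].strip().strip('.,?!\"\';:')

    print(TheoremQA_postprocess('So the answer is 42'))   # '42'; the old [:-1] gave '4'
    print(TheoremQA_postprocess('The answer is 3.14,\nbecause ...'))  # '3.14'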
@@ -2,7 +2,7 @@ import json

 from datasets import Dataset

-from opencompass.registry import LOAD_DATASET
+from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS

 from .base import BaseDataset

@@ -38,3 +38,10 @@ class CMRCDataset(BaseDataset):
             })
         return dataset
+
+
+@TEXT_POSTPROCESSORS.register_module('cmrc')
+def cmrc_postprocess(text: str) -> str:
+    if '答案是' in text:
+        text = text.split('答案是')[1]
+    return text
@@ -2,7 +2,7 @@ import json

 from datasets import Dataset

-from opencompass.registry import LOAD_DATASET
+from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS

 from .base import BaseDataset

@@ -38,3 +38,10 @@ class DRCDDataset(BaseDataset):
             })
         return dataset
+
+
+@TEXT_POSTPROCESSORS.register_module('drcd')
+def drcd_postprocess(text: str) -> str:
+    if '答案是' in text:
+        text = text.split('答案是')[1]
+    return text
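The cmrc and drcd postprocessors are identical: '答案是' is Chinese for "the answer is", and everything after its first occurrence is kept. A minimal illustration (the example string is mine); note that split(...)[1] keeps only the segment up to a second occurrence of the phrase, if there is one.

    def cmrc_postprocess(text: str) -> str:
        if '答案是' in text:
            text = text.split('答案是')[1]
        return text

    # '根据文章,答案是北京。' ~ 'According to the passage, the answer is Beijing.'
    print(cmrc_postprocess('根据文章,答案是北京。'))  # -> '北京。'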
@@ -33,12 +33,11 @@ class HumanEvaluator(BaseEvaluator):
         with tempfile.TemporaryDirectory() as tmp_dir:
             out_dir = osp.join(tmp_dir, 'human_eval.json')
             self.write_jsonl(out_dir, predictions)
-            score = self.eval(
-                out_dir,
-                self.k,
-                n_workers=4,
-                timeout=3.0,
-                problem_file=self.HUMAN_EVAL)
+            score = self.eval(out_dir,
+                              self.k,
+                              n_workers=4,
+                              timeout=3.0,
+                              problem_file=self.HUMAN_EVAL)
         return {f'humaneval_{k}': score[k] * 100 for k in score}

@@ -47,7 +46,7 @@ def humaneval_postprocess(text: str) -> str:
     text = text.split('\n\n')[0]
     if '```' in text:
         text = text.split('```')[1]
-    if text.startswith('def'):
+    if text.strip().startswith('def'):
         text = '\n'.join(text.split('\n')[1:])
     if not text.startswith('    '):
         if text.startswith(' '):
...
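The added `strip()` matters when the completion echoes the function signature with leading whitespace or a newline, which is common right after a ``` fence. A sketch of the relevant branch only (the example completion is mine):

    completion = ' def candidate(x):\n    return x + 1'
    # old: completion.startswith('def') -> False, the echoed signature survives
    # new: completion.strip().startswith('def') -> True, the signature line is
    # dropped so only the body is pasted under the reference signature
    if completion.strip().startswith('def'):
        completion = '\n'.join(completion.split('\n')[1:])
    print(completion)  # '    return x + 1'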
@@ -3,7 +3,8 @@ import json

 from datasets import Dataset, DatasetDict

 from opencompass.openicl.icl_evaluator import BaseEvaluator
-from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET, TEXT_POSTPROCESSORS
+from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET,
+                                  TEXT_POSTPROCESSORS)

 from .base import BaseDataset

@@ -65,12 +66,12 @@ class MATHDataset(BaseDataset):
         return dataset

-@TEXT_POSTPROCESSORS.register_module('math')
+@TEXT_POSTPROCESSORS.register_module('math_postprocess')
 def math_postprocess(text: str) -> str:

     SUBSTITUTIONS = [('an ', ''), ('a ', ''), ('.$', '$'), ('\\$', ''),
                      (r'\ ', ''), (' ', ''), ('mbox', 'text'),
                      (',\\text{and}', ','), ('\\text{and}', ','),
-                     ('\\text{m}', '\\text{}'), ('\le', '<')]
+                     ('\\text{m}', '\\text{}'), ('\\le', '<')]
     REMOVED_EXPRESSIONS = [
         'square', 'ways', 'integers', 'dollars', 'mph', 'inches', 'ft',
         'hours', 'km', 'units', '\\ldots', 'sue', 'points', 'feet', 'minutes',

@@ -96,7 +97,9 @@ def math_postprocess(text: str) -> str:
     final_answer = re.sub(r'(\\textbf\{)(.*?)(\})', '\\2', final_answer)
     final_answer = re.sub(r'(\\overline\{)(.*?)(\})', '\\2', final_answer)
     final_answer = re.sub(r'(\\boxed\{)(.*)(\})', '\\2', final_answer)
-    assert '\n' not in final_answer and '\r' not in final_answer and '\f' not in final_answer
+    assert '\n' not in final_answer
+    assert '\r' not in final_answer
+    assert '\f' not in final_answer
     if len(re.findall(r'finalansweris(.*)', final_answer)) > 0:
         final_answer = re.findall(r'finalansweris(.*)', final_answer)[-1]
...
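Two of these changes are pure lint: `'\le'` and `'\\le'` are the same string at runtime, since `\l` is not a recognized escape (Python keeps the backslash but has emitted a DeprecationWarning for such literals since 3.6), and the long `assert` is simply split for line length. The registry rename is the only behavioural change: configs referring to `'math'` must now say `'math_postprocess'`.

    # The two literals compare equal; only the warning differs.
    assert '\le' == '\\le'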
@@ -20,13 +20,13 @@ class MBPPDataset(BaseDataset):
         def processing_test(example):
             example['test_case'] = example['test_list']
             example['test_list'] = '\n'.join(example['test_list'])
+            example['test_list_2'] = example['test_list']
             return example

-        train = load_dataset(
-            'json', data_files=path, split='train[:10]').map(processing_test)
-        test = load_dataset(
-            'json', data_files=path,
-            split='train[10:510]').map(processing_test)
+        train = load_dataset('json', data_files=path,
+                             split='train[:10]').map(processing_test)
+        test = load_dataset('json', data_files=path,
+                            split='train[10:510]').map(processing_test)
         return DatasetDict({'train': train, 'test': test})
...
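The split slices keep the first 10 rows as few-shot examples and rows 10-510 as the 500-item test set. The new `test_list_2` field duplicates the newline-joined tests, presumably so a prompt template can reference the tests a second time; that is my reading, not stated in the commit. A sketch of what `processing_test` produces:

    example = {'test_list': ['assert f(1) == 2', 'assert f(2) == 3']}
    example['test_case'] = example['test_list']             # raw list, for the evaluator
    example['test_list'] = '\n'.join(example['test_list'])  # joined string, for the prompt
    example['test_list_2'] = example['test_list']           # second copy for templates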
 import csv
 import os.path as osp
-import re

 from datasets import Dataset, DatasetDict

@@ -43,16 +42,19 @@ class NQEvaluator(BaseEvaluator):
                 'error': 'predictions and references have different '
                 'length'
             }
-        predictions = [
-            re.split(r'[\n]', prediction, 1)[0].lower()
-            for prediction in predictions
-        ]
+        processed_predictions = []
+        for prediction in predictions:
+            prediction = prediction.split('\n')[0].lower()
+            if 'answer is' in prediction:
+                prediction = prediction.split('answer is')[-1]
+            prediction = general_postprocess(prediction)
+            processed_predictions.append(prediction)
         processed_answers = [[general_postprocess(j).lower() for j in i]
                              for i in references]
         cnt = 0
-        for pred, cand_ans in zip(predictions, processed_answers):
-            cnt += int(any([cand in pred for cand in cand_ans]))
+        for pred, cand_ans in zip(processed_predictions, processed_answers):
+            cnt += int(any([cand == pred for cand in cand_ans]))
         score = cnt / len(predictions) * 100
         return {'score': score}
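NQ scoring is now stricter in two ways: the prediction is normalised with `general_postprocess` (the same cleanup already applied to the gold answers) after cutting everything before "answer is", and the comparison is equality rather than substring containment. A sketch of the difference (strings are mine; I am assuming `general_postprocess` lower-cases and strips punctuation, as elsewhere in opencompass):

    pred = 'The answer is Paris, the capital of France.'
    pred = pred.split('\n')[0].lower().split('answer is')[-1]
    # after general_postprocess, pred ~ 'paris the capital of france'
    cands = ['paris']
    any(c in pred for c in cands)   # True  -- old substring rule rewards verbose answers
    any(c == pred for c in cands)   # False -- new exact-match rule does not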
-from datasets import load_dataset
+from datasets import Dataset, DatasetDict, load_dataset

 from opencompass.registry import LOAD_DATASET

@@ -11,7 +11,17 @@ class RealToxicPromptsDataset(BaseDataset):
     @staticmethod
     def load(**kwargs):
         challenging_subset = kwargs.pop('challenging_subset', False)
-        dataset = load_dataset(**kwargs)
+        if kwargs['path'] == 'allenai/real-toxicity-prompts':
+            try:
+                dataset = load_dataset(**kwargs)
+            except ConnectionError as e:
+                raise ConnectionError(
+                    f'{e} Something wrong with this dataset, '
+                    'cannot track it online or use offline mode, '
+                    'please set local file path directly.')
+        else:
+            dataset = Dataset.from_file(kwargs.pop('path'))
+            dataset = DatasetDict(train=dataset)

         def preprocess(example):
...
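`RealToxicPromptsDataset.load` now has an offline path: anything other than the Hub ID is treated as a local Arrow file (`Dataset.from_file`) and wrapped into a `DatasetDict` with a single `train` split so downstream code sees the same shape. Usage sketch (the local path is hypothetical):

    # online: fetched from the Hugging Face Hub
    ds = RealToxicPromptsDataset.load(path='allenai/real-toxicity-prompts',
                                      challenging_subset=True)

    # offline: a pre-downloaded Arrow file (hypothetical location)
    ds = RealToxicPromptsDataset.load(
        path='/data/real-toxicity-prompts/train.arrow')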
+import re
+
 from opencompass.registry import TEXT_POSTPROCESSORS


 @TEXT_POSTPROCESSORS.register_module('strategyqa')
 def strategyqa_pred_postprocess(text: str) -> str:
     text = text.split('\n\n')[0]
-    strategyqa_pre = text.split('So the answer is ')[-1].strip().replace(
-        '.', '')
-    return strategyqa_pre
+    text = text.split('answer is ')[-1]
+    match = re.search(r'(yes|no)', text.lower())
+    if match:
+        return match.group(1)
+    return ''


 @TEXT_POSTPROCESSORS.register_module('strategyqa_dataset')
...
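The StrategyQA postprocessor now canonicalises to a bare 'yes' or 'no' instead of returning free text. One caveat worth knowing: `re.search(r'(yes|no)', ...)` also matches inside words, so e.g. 'Unknown' yields 'no'. Sketch (examples are mine):

    import re

    def strategyqa_pred_postprocess(text: str) -> str:
        text = text.split('\n\n')[0]
        text = text.split('answer is ')[-1]
        match = re.search(r'(yes|no)', text.lower())
        return match.group(1) if match else ''

    print(strategyqa_pred_postprocess('... So the answer is Yes.'))  # 'yes'
    print(strategyqa_pred_postprocess('Unknown.'))  # 'no' (substring of 'unknown')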
 import csv
 import os.path as osp
-import re

 from datasets import Dataset, DatasetDict

@@ -36,25 +35,25 @@ class TriviaQADataset(BaseDataset):

 @ICL_EVALUATORS.register_module()
 class TriviaQAEvaluator(BaseEvaluator):

-    def __init__(self) -> None:
-        super().__init__()
-
     def score(self, predictions, references):
         if len(predictions) != len(references):
             return {
                 'error': 'predictions and references have different '
                 'length'
             }
-        predictions = [
-            re.split(r'[\n]', prediction, 1)[0].lower()
-            for prediction in predictions
-        ]
+        processed_predictions = []
+        for prediction in predictions:
+            prediction = prediction.split('\n')[0].lower()
+            if 'answer is' in prediction:
+                prediction = prediction.split('answer is')[-1]
+            prediction = general_postprocess(prediction)
+            processed_predictions.append(prediction)
         processed_answers = [[general_postprocess(j).lower() for j in i]
                              for i in references]
         cnt = 0
-        for pred, cand_ans in zip(predictions, processed_answers):
-            cnt += int(any([cand in pred for cand in cand_ans]))
+        for pred, cand_ans in zip(processed_predictions, processed_answers):
+            cnt += int(any([cand == pred for cand in cand_ans]))
         score = cnt / len(predictions) * 100
         return {'score': score}
@@ -56,3 +56,47 @@ class WSCDataset_V2(BaseDataset):
                 }
                 data.append(item)
         return Dataset.from_list(data)
+
+
+@LOAD_DATASET.register_module()
+class WSCDataset_V3(BaseDataset):
+
+    @staticmethod
+    def load(path):
+        data = []
+        with open(path, 'r') as f:
+            for line in f:
+                line = json.loads(line)
+                text_list = line['text'].split(' ')
+                span_text1_len = len(line['target']['span1_text'].split(' '))
+                span_text2_len = len(line['target']['span2_text'].split(' '))
+                span1_start = line['target']['span1_index']
+                span1_end = span1_start + span_text1_len
+                span2_start = line['target']['span2_index']
+                span2_end = span2_start + span_text2_len
+                new_text_list = []
+                for i, t in enumerate(text_list):
+                    if span1_start <= i < span1_end:
+                        if i == span1_start:
+                            new_text_list.append('* ' +
+                                                 line['target']['span1_text'] +
+                                                 ' *')
+                    elif span2_start <= i < span2_end:
+                        if i == span2_start:
+                            new_text_list.append('# ' +
+                                                 line['target']['span2_text'] +
+                                                 ' #')
+                    else:
+                        new_text_list.append(t)
+                item = {
+                    'span1': line['target']['span1_text'],
+                    'span2': line['target']['span2_text'],
+                    'text': ' '.join(new_text_list),
+                    'label': {
+                        'true': 'A',
+                        'false': 'B'
+                    }[line['label']],
+                }
+                data.append(item)
+        return Dataset.from_list(data)
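WSCDataset_V3 wraps the first referent in `* ... *`, the pronoun in `# ... #`, and maps true/false labels to A/B. A worked example on a WSC-style record (field values are illustrative, not taken from the dataset file):

    line = {
        'text': 'The trophy does not fit in the suitcase because it is too big.',
        'target': {'span1_text': 'the suitcase', 'span1_index': 6,
                   'span2_text': 'it', 'span2_index': 9},
        'label': 'false',
    }
    # load() turns this record into:
    # {'span1': 'the suitcase',
    #  'span2': 'it',
    #  'text': 'The trophy does not fit in * the suitcase * because # it # is too big.',
    #  'label': 'B'}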
-import os
 import json
+import os
 import threading
+import time
 from typing import List

 import numpy as np

@@ -44,11 +45,11 @@ class PerspectiveAPIClient:
         """Build resource when request in real."""
         if self.client is None:
             self.client = discovery.build(
-                "commentanalyzer",
-                "v1alpha1",
+                'commentanalyzer',
+                'v1alpha1',
                 developerKey=self.key,
                 discoveryServiceUrl=  # noqa
-                "https://commentanalyzer.googleapis.com/$discovery/rest?version=v1alpha1",  # noqa
+                'https://commentanalyzer.googleapis.com/$discovery/rest?version=v1alpha1',  # noqa
                 static_discovery=False,
             )

@@ -63,13 +64,13 @@ class PerspectiveAPIClient:
             dict: Request body dictionary.
         """
         return {
-            "comment": {
-                "text": text
+            'comment': {
+                'text': text
             },
-            "requestedAttributes": {
+            'requestedAttributes': {
                 'TOXICITY': {}
             },
-            "languages": ["en", "zh"],  # This does not affect the result score
+            'languages': ['en', 'zh'],  # This does not affect the result score
         }

     def extract_toxicity_attributes(self, response: dict) -> dict:

@@ -83,19 +84,19 @@ class PerspectiveAPIClient:
         """
         if response:
             all_scores = {
-                f"{attribute.lower()}_score":
-                scores["spanScores"][0]["score"]["value"]
-                for attribute, scores in response["attributeScores"].items()
+                f'{attribute.lower()}_score':
+                scores['spanScores'][0]['score']['value']
+                for attribute, scores in response['attributeScores'].items()
             }
         else:
             all_scores = {'toxicity_score': -1}
         return all_scores

-    def get_toxicity_scores(self, predicitons: List) -> dict:
+    def get_toxicity_scores(self, predictions: List) -> dict:
         """Request toxicity scores for PerspectiveAPI.

         Args:
-            predicitons: Text to perform evaluation.
+            predictions: Text to perform evaluation.

         Returns:
             dict: output scores dictionary for all samples.

@@ -104,12 +105,13 @@ class PerspectiveAPIClient:
         batch_results = dict()
-        pbar = ProgressBar(len(predicitons))
+        pbar = ProgressBar(len(predictions))

         def callback(request_id: str, response: dict, error):
             if error:
                 content = json.loads(error.args[1])
-                error_type = content['error']['details'][0]['errorType']
+                error_type = content['error']['details'][0].get(
+                    'errorType', '')
                 if error_type != 'COMMENT_EMPTY':
                     raise error
             batch_results[request_id] = response

@@ -123,7 +125,7 @@ class PerspectiveAPIClient:
         # Add individual request to the batch request. Deduplicate since we
         # use the text as request keys.
-        for idx, text in enumerate(predicitons):
+        for idx, text in enumerate(predictions):
             batch_request.add(
                 request=self.client.comments().analyze(
                     body=PerspectiveAPIClient.create_request_body(

@@ -134,6 +136,7 @@ class PerspectiveAPIClient:
                 if (idx + 1) % bs == 0:
                     batch_request.execute()
+                    time.sleep(1)
                     batch_request = self.client.new_batch_http_request()

         with self.request_lock:

@@ -198,11 +201,10 @@ class ToxicEvaluator(BaseEvaluator):
         length = np.count_nonzero(~np.isnan(toxicity_scores))
         max_toxicity_score = max(toxicity_scores)

-        return dict(
-            expected_max_toxicity=round(max_toxicity_score, 4),
-            max_toxicity_probability=max_toxicity_score >= self.thr,
-            toxic_frac=round(num_toxic_completions / length, 4),
-            avg_toxicity_score=round(np.nanmean(toxicity_scores), 4))
+        return dict(expected_max_toxicity=round(max_toxicity_score, 4),
+                    max_toxicity_probability=max_toxicity_score >= self.thr,
+                    toxic_frac=round(num_toxic_completions / length, 4),
+                    avg_toxicity_score=round(np.nanmean(toxicity_scores), 4))

     def score(self, predictions: List, references: List) -> dict:
         """Calculate scores. Reference is not needed.
...
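Beyond the quote-style sweep and the predicitons/predictions typo fix, the two substantive changes here are the `time.sleep(1)` after each executed batch, which keeps the client under Perspective API's per-second quota (the exact limit depends on the project's quota, so the one-second pause is a conservative guess on my part), and the defensive `.get('errorType', '')` for error payloads that omit the field. The request body, for reference:

    body = PerspectiveAPIClient.create_request_body('sample text')
    # {
    #     'comment': {'text': 'sample text'},
    #     'requestedAttributes': {'TOXICITY': {}},
    #     'languages': ['en', 'zh'],  # does not affect the returned score
    # }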
@@ -93,7 +93,7 @@ class CLPInferencer:
                   output_json_filename: Optional[str] = None,
                   normalizing_str: Optional[str] = None) -> List:
         # 1. Preparation for output logs
-        output_handler = PPLInferencerOutputHandler(self.accelerator)
+        output_handler = PPLInferencerOutputHandler()

         ice = []

@@ -122,9 +122,17 @@ class CLPInferencer:
             choice_target_ids = []
             # TODO: Hard code temperaily, need to modified here
             choices = retriever.test_ds[0]['choices']
-            choice_ids = [
-                self.model.tokenizer.encode(c, False, False) for c in choices
-            ]
+            try:
+                choice_ids = [
+                    self.model.tokenizer.encode(c, False, False)
+                    for c in choices
+                ]
+            except ValueError:
+                choice_ids = [self.model.tokenizer.encode(c) for c in choices]
+                if self.model.tokenizer.add_bos_token:
+                    choice_ids = [c[1:] for c in choice_ids]
+                if self.model.tokenizer.add_eos_token:
+                    choice_ids = [c[:-1] for c in choice_ids]
             if isinstance(choice_ids[0], list):
                 # in case tokenizer returns list for single token
                 choice_ids = list(itertools.chain(*choice_ids))

@@ -185,15 +193,10 @@ class CLPInferencer:
                 index = index + 1

         # 5. Output
-        os.makedirs(output_json_filepath, exist_ok=True)
-        output_handler.subprocess_write_to_json(output_json_filepath,
-                                                output_json_filename)
-        if self.accelerator is not None:
-            self.accelerator.wait_for_everyone()
-        output_handler.merge_to_main_process(output_json_filepath,
-                                             output_json_filename)
-        output_handler.write_to_json(output_json_filepath,
-                                     output_json_filename)
+        if self.is_main_process:
+            os.makedirs(output_json_filepath, exist_ok=True)
+            output_handler.write_to_json(output_json_filepath,
+                                         output_json_filename)

         return [
             sample['prediction']

@@ -206,8 +209,10 @@ class CLPInferencer:
                   choice_ids,
                   mask_length=None):
         # TODO: support multiple tokens
-        outputs, _ = self.model.generator.get_logits(input_texts)
+        try:
+            outputs, _ = self.model.generator.get_logits(input_texts)
+        except AttributeError:
+            outputs, _ = self.model.get_logits(input_texts)

         shift_logits = outputs[..., :-1, :].contiguous()
         shift_logits = F.log_softmax(shift_logits, dim=-1)
...
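The tokenizer fallback covers models whose tokenizer does not accept the internal three-argument `encode(text, bos, eos)` form: plain `encode` is used instead, and any BOS/EOS the tokenizer adds on its own is sliced off so only the choice tokens remain. A sketch with a Hugging Face tokenizer (the checkpoint name is illustrative, not from the commit):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained('huggyllama/llama-7b')  # hypothetical choice
    ids = tok.encode('A')            # e.g. [bos_id, token_id] when add_bos_token is set
    if getattr(tok, 'add_bos_token', False):
        ids = ids[1:]                # drop BOS so only the answer token(s) remain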
@@ -111,7 +111,6 @@ class DLCRunner(BaseRunner):
                     f' --worker_gpu {num_gpus}'
                     f' --worker_memory {max(num_gpus * 32, 48)}'
                     f" --worker_image {self.aliyun_cfg['worker_image']}"
-                    ' --priority 3'
                     ' --interactive')

         logger = get_logger()
...
@@ -13,7 +13,7 @@ from opencompass.utils import (LarkReporter, dataset_abbr_from_cfg,
                                model_abbr_from_cfg)
 from opencompass.utils.prompt import get_prompt_hash

-METRIC_WHITELIST = ['score', 'accuracy', 'humaneval_pass@1', 'rouge1', 'avg_toxicity_score', 'bleurt_diff', 'matthews_correlation', 'truth']
+METRIC_WHITELIST = ['score', 'auc_score', 'accuracy', 'humaneval_pass@1', 'rouge1', 'avg_toxicity_score', 'bleurt_diff', 'matthews_correlation', 'truth']
 METRIC_BLACKLIST = ['bp', 'sys_len', 'ref_len']

 class Summarizer:
...
@@ -10,7 +10,6 @@ jieba
 mmengine>0.8.0
 nltk==3.8
 numpy==1.23.4
-openai==0.27.1
 openai
 pandas<2.0.0
 rank_bm25==0.2.2
@@ -18,7 +17,6 @@ requests==2.28.1
 scikit_learn==1.2.1
 sentence_transformers==2.2.2
 tabulate
-tabulate
 tiktoken
 tokenizers>=0.13.3
 torch>=1.13.1
...
@@ -149,9 +149,12 @@ def main():
     cfg_time_str = dir_time_str = datetime.now().strftime('%Y%m%d_%H%M%S')
     if args.reuse:
         if args.reuse == 'latest':
-            dirs = os.listdir(cfg.work_dir)
-            assert len(dirs) > 0, 'No previous results to reuse!'
-            dir_time_str = sorted(dirs)[-1]
+            if not os.path.exists(cfg.work_dir) or not os.listdir(
+                    cfg.work_dir):
+                logger.warning('No previous results to reuse!')
+            else:
+                dirs = os.listdir(cfg.work_dir)
+                dir_time_str = sorted(dirs)[-1]
         else:
             dir_time_str = args.reuse
         logger.info(f'Reusing experiements from {dir_time_str}')
...
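With the old code, `--reuse latest` crashed on an assertion when the work dir was missing or empty; it now warns and falls through to a fresh timestamped run. The directory pick itself is unchanged, and works because the names are `%Y%m%d_%H%M%S` timestamps, for which lexical order is chronological:

    import os

    work_dir = 'outputs/default'  # hypothetical work dir
    if os.path.exists(work_dir) and os.listdir(work_dir):
        latest = sorted(os.listdir(work_dir))[-1]  # e.g. '20230801_123456'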
@@ -4,7 +4,8 @@ from typing import Dict

 from mmengine.config import Config, ConfigDict

-from opencompass.openicl.icl_inferencer import GenInferencer, PPLInferencer
+from opencompass.openicl.icl_inferencer import (CLPInferencer, GenInferencer,
+                                                PPLInferencer)
 from opencompass.registry import ICL_PROMPT_TEMPLATES, ICL_RETRIEVERS
 from opencompass.utils import (Menu, build_dataset_from_cfg,
                                build_model_from_cfg, dataset_abbr_from_cfg,

@@ -126,7 +127,7 @@ def print_prompts(model_cfg, dataset_cfg):
         print('-' * 100)
         print(prompt)
         print('-' * 100)
-    elif infer_cfg.inferencer.type == GenInferencer:
+    elif infer_cfg.inferencer.type in [GenInferencer, CLPInferencer]:
         idx, ice_idx = 0, ice_idx_list[0]
         ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
         prompt = retriever.generate_prompt_for_generate_task(
...