Unverified Commit dbb20b82 authored by Fengzhe Zhou, committed by GitHub

[Sync] update (#517)

parent 6f07af30
@@ -23,7 +23,7 @@ class SubjectiveNaivePartitioner(NaivePartitioner):
                 mode: str,
                 out_dir: str,
                 model_pairs: Optional[List[Tuple]] = None,
-                keep_keys: List[str] = ['eval.runner.task.judge_cfg']):
+                keep_keys: Optional[List[str]] = None):
        super().__init__(out_dir=out_dir, keep_keys=keep_keys)
        assert mode in ['all', 'one_to_n', 'fixed']
        self.mode = mode
......
@@ -72,6 +72,7 @@ class DefaultSummarizer:
                if not osp.exists(filepath):
                    continue
                result = mmengine.load(filepath)
+               result.pop('details', None)
                raw_results[model_abbr][dataset_abbr] = result
                if 'error' in result:
                    self.logger.debug(f'error in {model_abbr} {dataset_abbr} {result["error"]}')
......
import argparse
+import copy
import fnmatch
+import math
import os.path as osp
+import statistics
import time
from collections import Counter
from inspect import signature
from shutil import which
-from typing import Optional
+from typing import List, Optional
import mmengine
from mmengine.config import Config, ConfigDict
@@ -35,6 +38,8 @@ class OpenICLEvalTask(BaseTask):
        super().__init__(cfg)
        self.num_gpus = 0
        self.logger = get_logger()
+        self.dump_details = cfg.get('eval', {}).get('runner', {}).get(
+            'task', {}).get('dump_details', False)

    def get_command(self, cfg_path, template):
        script_path = __file__
@@ -113,7 +118,7 @@ class OpenICLEvalTask(BaseTask):
                    [sub_preds[str(i)] for i in range(len(sub_preds))])
                filename = root + f'_{i}' + ext
                i += 1
+        pred_dicts = copy.deepcopy(preds)
        preds = {k: [pred.get(k) for pred in preds] for k in preds[0]}
        pred_strs = preds.pop('prediction')
@@ -163,6 +168,7 @@ class OpenICLEvalTask(BaseTask):
            ]
        icl_evaluator = ICL_EVALUATORS.build(self.eval_cfg['evaluator'])
        preds['predictions'] = pred_strs
        preds['references'] = (test_set[self.output_column]
                               if self.output_column else None)
@@ -172,18 +178,42 @@ class OpenICLEvalTask(BaseTask):
        }
        result = icl_evaluator.score(**preds)

+        if self.dump_details:
+            try:
+                details = result.pop('details', None)
+                result['details'] = self.format_details(
+                    pred_strs, test_set[self.output_column], details,
+                    pred_dicts)
+                result['type'] = result['details'].pop('type', None)
+                if 'PPL' in str(
+                        self.dataset_cfg.infer_cfg.inferencer.type):
+                    result['correct_bpb'], result[
+                        'incorrect_bpb'] = self.calculate_bpb(pred_dicts)
+                else:
+                    result['incorrect_bpb'] = result['correct_bpb'] = -1
+            except Exception:
+                result['incorrect_bpb'] = result['correct_bpb'] = -1
+        else:
+            result.pop('details', None)

        if 'error' in result:
            self.logger.error(
                f'Task {task_abbr_from_cfg(self.cfg)}: {result["error"]}')
            return
        else:
-            self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}: {result}')
+            result_wo_details = {
+                i: result[i]
+                for i in result if i != 'details'
+            }
+            self.logger.info(
+                f'Task {task_abbr_from_cfg(self.cfg)}: {result_wo_details}')

        # Save result
        out_path = get_infer_output_path(self.model_cfg, self.dataset_cfg,
                                         osp.join(self.work_dir, 'results'))
        mkdir_or_exist(osp.split(out_path)[0])
-        mmengine.dump(result, out_path)
+        mmengine.dump(result, out_path, ensure_ascii=False, indent=4)
    def _extract_role_pred(self, s: str, begin_str: Optional[str],
                           end_str: Optional[str]) -> str:
@@ -215,6 +245,95 @@ class OpenICLEvalTask(BaseTask):
        return s[start:end]
+    def format_details(self, predictions, references, details, pred_dicts):
+        """Format the prediction details of each sample for dumping.
+
+        Args:
+            predictions (list): The prediction list.
+            references (list): The reference list.
+            details (list): Each entry contains the 'pred', 'answers' and
+                'correct' fields for one sample, e.g.
+                `[{'pred': '光荣和ωforce',
+                   'answers': ['光荣和ω-force', '光荣和ωforce'],
+                   'correct': True}]`.
+            pred_dicts (list): The raw per-sample records, including the
+                original prompts, e.g.
+                `[{'origin_prompt': '根据文章回答问题。你的答案应该尽可能3》…………',
+                   'prediction': ' 光荣和ω-force\n', 'gold': ['光荣和ω-force']}]`.
+
+        Returns:
+            dict: The formatted prediction details, keyed by sample index,
+                plus a 'type' entry ('PPL' or 'GEN').
+        """
+        results = {}
+        for i in range(len(predictions)):
+            ppl_flag = False
+            result = {}
+            origin_prediction = copy.deepcopy(pred_dicts[i])
+            origin_prediction.pop('in-context examples', None)
+            origin_prediction.pop('prediction', None)
+            keys = copy.deepcopy(list(origin_prediction.keys()))
+            for key in keys:
+                if key.startswith('label:'):
+                    ppl_flag = True
+                    origin_prediction[key].pop('testing input', None)
+                    new_key = key.replace('label: ', '')
+                    origin_prediction[new_key] = origin_prediction.pop(key)
+            if ppl_flag:
+                results['type'] = 'PPL'
+                result['origin_prediction'] = origin_prediction
+                result['predictions'] = str(predictions[i])
+                result['references'] = str(references[i])
+                result['correct'] = str(predictions[i]) == str(references[i])
+            else:
+                results['type'] = 'GEN'
+                result['prompt'] = origin_prediction['origin_prompt']
+                result['origin_prediction'] = pred_dicts[i]['prediction']
+                result['predictions'] = details[i]['pred']
+                result['references'] = details[i]['answers']
+                result['correct'] = details[i]['correct']
+            results[str(i)] = result
+        return results
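
For orientation, the sketch below shows the shape of the dict that format_details returns for a GEN-style dataset. The key names mirror the code above; the sample values are invented for illustration and do not come from a real run.

# Illustrative only: shape of the dict returned by format_details for a
# GEN-type dataset (values are made up, keys follow the code above).
formatted = {
    'type': 'GEN',  # popped by the caller into result['type']
    '0': {
        'prompt': 'Answer the question based on the passage ...',
        'origin_prediction': ' raw model output\n',
        'predictions': 'cleaned prediction',
        'references': ['gold answer'],
        'correct': True,
    },
}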
+    def calculate_bpb(self, pred_dicts: List):
+        """Calculate the BPB (Bits Per Byte) scores for the data.
+
+        The correct BPB is obtained directly from the values recorded in the
+        prediction files. The incorrect BPB of a sample is the average of the
+        remaining labels' BPB values after subtracting the correct one. BPB
+        is computed similarly to PPL, except that it measures the additional
+        bits needed on average, in terms of character length, to encode the
+        true sequence given the predictions; this involves a weighting factor
+        based on the ratio of words to characters.
+
+        Args:
+            pred_dicts (list): A list of per-sample records containing each
+                candidate label and its BPB score.
+
+        Returns:
+            tuple: The mean correct BPB and the mean incorrect BPB, both
+                multiplied by 100.
+        """
+        incorrect_bpb_list = []
+        bpb_list = []
+        for pred_dict in pred_dicts:
+            preds = {
+                key: value
+                for key, value in pred_dict.items()
+                if key.startswith('label: ')
+            }
+            values = []
+            for item in preds.items():
+                values.append(item[1])
+            bpbs = [value['BPB'] for value in values]
+            incorrect_bpb_list.append(
+                (sum(bpbs) - min(bpbs)) / (len(bpbs) - 1))
+            bpb_list.append(statistics.mean(bpbs))
+
+        def filters(origins):
+            targets = [target for target in origins if not math.isnan(target)]
+            return targets
+
+        mean_incorrect = statistics.mean(filters(incorrect_bpb_list))
+        mean_correct = statistics.mean(filters(bpb_list))
+        return 100 * mean_correct, 100 * mean_incorrect
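
As a quick sanity check of the arithmetic, here is a small worked sketch that mirrors calculate_bpb on two invented samples with two candidate labels each; the BPB values are made up.

# Invented BPB values: two samples, two candidate labels each.
sample_bpbs = [[0.8, 1.4], [0.6, 1.0]]

# Per sample: the incorrect BPB averages the labels left after removing the
# minimum (assumed correct) value; the correct entry is the mean of all labels.
incorrect = [(sum(b) - min(b)) / (len(b) - 1) for b in sample_bpbs]  # [1.4, 1.0]
per_sample = [sum(b) / len(b) for b in sample_bpbs]                  # [1.1, 0.8]

correct_bpb = 100 * sum(per_sample) / len(per_sample)    # 95.0
incorrect_bpb = 100 * sum(incorrect) / len(incorrect)    # 120.0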
def parse_args():
    parser = argparse.ArgumentParser(description='Score Calculator')
......
@@ -25,6 +25,7 @@ requests==2.31.0
rouge
rouge_chinese
rouge_score
+sacrebleu
scikit_learn==1.2.1
seaborn
sentence_transformers==2.2.2
......
@@ -123,6 +123,12 @@ def parse_args():
        'Will be overrideen by the "retry" argument in the config.',
        type=int,
        default=2)
+    parser.add_argument(
+        '--dump-eval-details',
+        help='Whether to dump the evaluation details, including the '
+        'correctness of each sample, bpb, etc.',
+        action='store_true',
+    )
    # set srun args
    slurm_parser = parser.add_argument_group('slurm_args')
    parse_slurm_args(slurm_parser)
@@ -300,6 +306,8 @@ def main():
        if args.dlc or args.slurm or cfg.get('eval', None) is None:
            fill_eval_cfg(cfg, args)
+        if args.dump_eval_details:
+            cfg.eval.runner.task.dump_details = True

        if args.partition is not None:
            if RUNNERS.get(cfg.eval.runner.type) == SlurmRunner:
......
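
Taken together, a run that wants the per-sample dump would pass the new flag on the command line; the config path below is only a placeholder.

python run.py configs/eval_demo.py --dump-eval-details

The flag sets cfg.eval.runner.task.dump_details, which OpenICLEvalTask reads in its constructor via the nested cfg.get chain added above, and the formatted details and BPB figures are then written alongside the metrics in the results file.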