Unverified commit 3c606cb7, authored by bittersweet1999 and committed by GitHub
Browse files

quick fix for postprocess pred extraction (#771)

parent 0c75f0f9
...@@ -22,6 +22,37 @@ from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg, ...@@ -22,6 +22,37 @@ from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
task_abbr_from_cfg) task_abbr_from_cfg)
def extract_role_pred(s: str, begin_str: Optional[str],
                      end_str: Optional[str]) -> str:
    """Cut the role's own text out of a full prediction string.

    The result is the part of ``s`` that follows the first occurrence of
    ``begin_str`` and precedes the next occurrence of ``end_str``. A marker
    that is ``None``, empty, or absent from ``s`` is ignored, so the
    corresponding side of the string is left untouched.

    Args:
        s (str): Full prediction string.
        begin_str (str, optional): Marker that opens the role's text.
        end_str (str, optional): Marker that closes the role's text.

    Returns:
        str: The extracted role prediction.
    """
    remainder = s
    if begin_str:
        # ``matched`` is empty when the marker does not occur; in that
        # case the text is kept from its very beginning.
        _, matched, after = s.partition(begin_str)
        if matched:
            remainder = after
    if end_str:
        # TODO: Support calling tokenizer for the accurate eos token
        # and avoid such hardcode
        # Only an end marker located after the begin marker counts, which
        # is why the search runs on the remainder rather than on ``s``.
        remainder = remainder.partition(end_str)[0]
    return remainder
@TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run @TASKS.register_module(force=(__name__ == '__main__')) # A hack for script run
class OpenICLEvalTask(BaseTask): class OpenICLEvalTask(BaseTask):
"""OpenICL Evaluation Task. """OpenICL Evaluation Task.
...@@ -137,14 +168,14 @@ class OpenICLEvalTask(BaseTask): ...@@ -137,14 +168,14 @@ class OpenICLEvalTask(BaseTask):
'must be list.') 'must be list.')
if pred_list_flag: if pred_list_flag:
pred_strs = [[ pred_strs = [[
self._extract_role_pred(_pred, role.get('begin', None), extract_role_pred(_pred, role.get('begin', None),
role.get('end', None)) role.get('end', None))
for _pred in pred for _pred in pred
] for pred in pred_strs] ] for pred in pred_strs]
else: else:
pred_strs = [ pred_strs = [
self._extract_role_pred(pred, role.get('begin', None), extract_role_pred(pred, role.get('begin', None),
role.get('end', None)) role.get('end', None))
for pred in pred_strs for pred in pred_strs
] ]
...@@ -222,36 +253,6 @@ class OpenICLEvalTask(BaseTask): ...@@ -222,36 +253,6 @@ class OpenICLEvalTask(BaseTask):
mkdir_or_exist(osp.split(out_path)[0]) mkdir_or_exist(osp.split(out_path)[0])
mmengine.dump(result, out_path, ensure_ascii=False, indent=4) mmengine.dump(result, out_path, ensure_ascii=False, indent=4)
def _extract_role_pred(self, s: str, begin_str: Optional[str],
end_str: Optional[str]) -> str:
"""Extract the role prediction from the full prediction string. The
role prediction may be the substring between the begin and end string.
Args:
s (str): Full prediction string.
begin_str (str): The beginning string of the role
end_str (str): The ending string of the role.
Returns:
str: The extracted role prediction.
"""
start = 0
end = len(s)
if begin_str:
begin_idx = s.find(begin_str)
if begin_idx != -1:
start = begin_idx + len(begin_str)
if end_str:
# TODO: Support calling tokenizer for the accurate eos token
# and avoid such hardcode
end_idx = s.find(end_str, start)
if end_idx != -1:
end = end_idx
return s[start:end]
def format_details(self, predictions, references, details, pred_dicts): def format_details(self, predictions, references, details, pred_dicts):
"""This function is responsible for formatting prediction details. """This function is responsible for formatting prediction details.
......
...@@ -4,7 +4,7 @@ import fnmatch ...@@ -4,7 +4,7 @@ import fnmatch
import os.path as osp import os.path as osp
import random import random
import time import time
from typing import List, Optional, Union from typing import List, Union
import mmengine import mmengine
from mmengine.config import Config, ConfigDict from mmengine.config import Config, ConfigDict
...@@ -12,6 +12,7 @@ from mmengine.utils import mkdir_or_exist ...@@ -12,6 +12,7 @@ from mmengine.utils import mkdir_or_exist
from opencompass.registry import ICL_EVALUATORS, MODELS, TEXT_POSTPROCESSORS from opencompass.registry import ICL_EVALUATORS, MODELS, TEXT_POSTPROCESSORS
from opencompass.tasks.base import BaseTask from opencompass.tasks.base import BaseTask
from opencompass.tasks.openicl_eval import extract_role_pred
from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg, from opencompass.utils import (build_dataset_from_cfg, dataset_abbr_from_cfg,
get_infer_output_path, get_logger, get_infer_output_path, get_logger,
model_abbr_from_cfg, task_abbr_from_cfg) model_abbr_from_cfg, task_abbr_from_cfg)
...@@ -111,7 +112,9 @@ class SubjectiveEvalTask(BaseTask): ...@@ -111,7 +112,9 @@ class SubjectiveEvalTask(BaseTask):
filename = get_infer_output_path( filename = get_infer_output_path(
model_cfg, dataset_cfg, osp.join(self.work_dir, 'predictions')) model_cfg, dataset_cfg, osp.join(self.work_dir, 'predictions'))
root, ext = osp.splitext(filename) root, ext = osp.splitext(filename)
filename = root[:-2] + ext last_underscore_index = root.rfind('_')
root = root[:last_underscore_index]
filename = root + ext
# If take SubjectNaivePartition, get filename # If take SubjectNaivePartition, get filename
else: else:
filename = get_infer_output_path( filename = get_infer_output_path(
...@@ -161,9 +164,8 @@ class SubjectiveEvalTask(BaseTask): ...@@ -161,9 +164,8 @@ class SubjectiveEvalTask(BaseTask):
parser = LMTemplateParser(model_cfg['meta_template']) parser = LMTemplateParser(model_cfg['meta_template'])
role = parser.roles[eval_cfg['pred_role']] role = parser.roles[eval_cfg['pred_role']]
pred_strs = [ pred_strs = [
self._extract_role_pred(pred, role.get('begin', None), extract_role_pred(pred, role.get('begin', None),
role.get('end', None)) role.get('end', None)) for pred in pred_strs
for pred in pred_strs
] ]
# Postprocess predictions if necessary # Postprocess predictions if necessary
...@@ -238,36 +240,6 @@ class SubjectiveEvalTask(BaseTask): ...@@ -238,36 +240,6 @@ class SubjectiveEvalTask(BaseTask):
ensure_ascii=False, ensure_ascii=False,
indent=4) indent=4)
def _extract_role_pred(self, s: str, begin_str: Optional[str],
end_str: Optional[str]) -> str:
"""Extract the role prediction from the full prediction string. The
role prediction may be the substring between the begin and end string.
Args:
s (str): Full prediction string.
begin_str (str): The beginning string of the role
end_str (str): The ending string of the role.
Returns:
str: The extracted role prediction.
"""
start = 0
end = len(s)
if begin_str:
begin_idx = s.find(begin_str)
if begin_idx != -1:
start = begin_idx + len(begin_str)
if end_str:
# TODO: Support calling tokenizer for the accurate eos token
# and avoid such hardcode
end_idx = s.find(end_str[:1], start)
if end_idx != -1:
end = end_idx
return s[start:end]
def get_output_paths(self, file_extension: str = 'json') -> List[str]: def get_output_paths(self, file_extension: str = 'json') -> List[str]:
"""Get the paths to the output files. Every file should exist if the """Get the paths to the output files. Every file should exist if the
task succeeds. task succeeds.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment