Commit ea8dfbe8 authored by Baber

add strip_reasoning param

parent 6b3f3f7e
@@ -300,6 +300,14 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""",
)
parser.add_argument(
"--strip_reasoning",
type=str,
nargs="?",
const="</think>",
default=False,
help="Strip reasoning blocks ending with specified token. Usage: --strip_reasoning (uses default '</think>') or --strip_reasoning '</reasoning>' (uses custom token)",
)
return parser
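For clarity, here is how the nargs="?" / const / default combination above resolves in practice — a standalone sketch that mirrors the add_argument call from this hunk, not part of the commit itself:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--strip_reasoning",
    type=str,
    nargs="?",
    const="</think>",  # value used when the flag is passed with no argument
    default=False,     # value used when the flag is absent
)

print(parser.parse_args([]).strip_reasoning)                                     # False
print(parser.parse_args(["--strip_reasoning"]).strip_reasoning)                  # </think>
print(parser.parse_args(["--strip_reasoning", "</reasoning>"]).strip_reasoning)  # </reasoning>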
@@ -472,6 +480,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
torch_random_seed=args.seed[2],
fewshot_random_seed=args.seed[3],
confirm_run_unsafe_code=args.confirm_run_unsafe_code,
strip_reasoning=args.strip_reasoning,
metadata=metadata,
**request_caching_args,
)
......
@@ -6,6 +6,7 @@ import re
from collections.abc import Callable
from copy import deepcopy
from dataclasses import asdict, dataclass
from functools import partial
from inspect import getsource
from typing import (
Any,
@@ -679,6 +680,25 @@ class Task(abc.ABC):
setattr(self._config, "metric_list", [{"metric": metric_name}])
setattr(self._config, "process_results", None)
def override_filter(self, filter_name: str, **kwargs) -> None:
"""
Override or extend the task's filter pipeline with a custom filter.
For "strip_reasoning", a new ensemble is built if the task has no filters; otherwise the filter is prepended to each existing ensemble so it runs before answer extraction.
"""
from lm_eval.api.registry import get_filter
if filter_name == "strip_reasoning":
if not self._filters:
self._filters = [
build_filter_ensemble(
"strip_reasoning", [["strip_reasoning", kwargs]]
)
]
else:
for f in self._filters:
f.filters.insert(
0, partial(get_filter("strip_reasoning"), **kwargs)
)
def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
self.fewshot_rnd = random.Random(seed)
if hasattr(self, "sampler"):
......
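A minimal sketch of why override_filter inserts the partial at index 0 rather than appending: stripping has to happen before any answer-extraction filter sees the response. The strip_reasoning and extract_answer helpers below are toy stand-ins invented for this sketch, not the real Filter classes.

from functools import partial

def strip_reasoning(resp: str, suffix: str = "</think>") -> str:
    # keep only the text after the last occurrence of the suffix token
    return resp.split(suffix)[-1].strip()

def extract_answer(resp: str) -> str:
    # stand-in for a downstream extraction filter (e.g. a regex filter)
    return resp.splitlines()[-1]

pipeline = [extract_answer]                                      # existing task filters
pipeline.insert(0, partial(strip_reasoning, suffix="</think>"))  # analogous to the insert(0, ...) in override_filter

resp = "Let me think step by step...\n2 + 2 = 4</think>\nThe answer is 4."
for step in pipeline:
    resp = step(resp)
print(resp)  # The answer is 4.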
@@ -75,6 +75,7 @@ def simple_evaluate(
torch_random_seed: int = 1234,
fewshot_random_seed: int = 1234,
confirm_run_unsafe_code: bool = False,
strip_reasoning: Union[bool, str] = False,
metadata: Optional[dict] = None,
):
"""Instantiate and evaluate a model on a list of tasks.
@@ -114,6 +115,10 @@ def simple_evaluate(
If True, write out an example document and model input for checking task integrity
:param log_samples: bool
If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
:param evaluation_tracker: EvaluationTracker
An EvaluationTracker instance to track the evaluation process.
If None, no tracking will be done.
If provided, it will log the experiment arguments and results.
:param system_instruction: str
System instruction to be applied to the prompt
:param apply_chat_template: Union[bool, str]
@@ -126,7 +131,9 @@ def simple_evaluate(
:param gen_kwargs: dict or comma-separated string
Arguments for model generation
Ignored for all tasks with loglikelihood output_type
:param task_manager: TaskManager
TaskManager instance to manage tasks. If None, a new TaskManager will be created.
:param verbosity: str (deprecated - use LOGLEVEL environment variable)
Verbosity level for logging
:param predict_only: bool
If true only model outputs will be generated and returned. Metrics will not be evaluated
@@ -138,6 +145,11 @@ def simple_evaluate(
Random seed for torch. If set to None, the seed will not be set.
:param fewshot_random_seed: int
Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.
:param confirm_run_unsafe_code: bool
Whether to confirm running tasks marked as unsafe (code). If set to False, an error will be raised if an unsafe task is encountered.
:param strip_reasoning: bool or str
If set, strip the reasoning block from model outputs before filters and metrics are applied; useful for reasoning models that emit a thinking block ahead of the answer.
A string value is passed as the `suffix` argument of the `strip_reasoning` filter (everything up to and including the last occurrence of that token is removed); `True` falls back to the filter's default suffix.
:param metadata: dict
Additional metadata to be added to the task manager. Will get passed to the download function of the task.
@@ -319,6 +331,15 @@ def simple_evaluate(
# fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
task_obj.set_fewshot_seed(seed=fewshot_random_seed)
if strip_reasoning:
eval_logger.info(
f"Stripping reasoning from {task_name} task outputs."
)
task_obj.override_filter(
"strip_reasoning",
# a string selects a custom suffix; a bare True falls back to the filter's default
**({"suffix": strip_reasoning} if isinstance(strip_reasoning, str) else {}),
)
adjusted_task_dict[task_name] = task_obj
return adjusted_task_dict
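For reference, a hedged example of reaching this code path from Python rather than the CLI, assuming simple_evaluate is called through the package's top-level export; the model and task names below are illustrative, not taken from the commit:

import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=<your-reasoning-model>",  # placeholder model name
    tasks=["gsm8k"],
    # forwarded to the strip_reasoning filter's `suffix` argument
    strip_reasoning="</think>",
)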
......
from functools import partial
from typing import List, Union
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.registry import get_filter
@@ -8,7 +8,7 @@ from . import custom, extraction, selection, transformation
def build_filter_ensemble(
filter_name: str, components: List[List[Union[str, dict, None]]]
) -> FilterEnsemble:
"""
Create a filtering pipeline.
......
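The widened components annotation matches the shape passed by override_filter above: each entry is a [filter_name, kwargs] pair, where kwargs may be a dict or (per the new annotation) None for no kwargs. A small hedged example, assuming build_filter_ensemble is importable from lm_eval.filters as it is in the task module:

from lm_eval.filters import build_filter_ensemble

# the kwargs dict is forwarded to the registered filter's constructor
ensemble = build_filter_ensemble(
    "strip_reasoning",
    [["strip_reasoning", {"suffix": "</think>"}]],
)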
@@ -231,3 +231,19 @@ class MultiChoiceRegexFilter(RegexFilter):
filtered_resps.append(filtered)
return filtered_resps
@register_filter("strip_reasoning")
class StripReasoningFilter(Filter):
"""A filter that strips reasoning block from model responses and returns the last part of the response."""
def __init__(self, suffix: str = "</think>", **kwargs):
super().__init__(**kwargs)
assert suffix, "suffix is required but was falsy"
self.suffix = suffix
def apply(self, resps: list[list[str]], docs: list[dict], **kwargs):
def filter_set(inst: list[str]) -> list[str]:
return [r.split(self.suffix)[-1].strip() for r in inst]
return map(filter_set, resps)
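And a quick standalone check of the new filter's behaviour; the import path assumes the class lives in lm_eval.filters.extraction (the file this hunk modifies), and the sample responses are made up:

from lm_eval.filters.extraction import StripReasoningFilter

f = StripReasoningFilter(suffix="</think>")
resps = [
    ["First I add 2 and 2...</think>The answer is 4.", "no reasoning block here"],
]
docs = [{}]
# apply returns a lazy map of per-document response lists
print(list(f.apply(resps, docs)))
# [['The answer is 4.', 'no reasoning block here']]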