Commit 708b160d authored by Baber's avatar Baber
Browse files

strip thinking

parent 35be7100
...@@ -301,12 +301,12 @@ def setup_parser() -> argparse.ArgumentParser: ...@@ -301,12 +301,12 @@ def setup_parser() -> argparse.ArgumentParser:
help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""", help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""",
) )
parser.add_argument( parser.add_argument(
"--strip_reasoning", "--strip_thinking",
type=str, type=str,
nargs="?", nargs="?",
const="</think>", const="</think>",
default=False, default=False,
help="Strip reasoning blocks ending with specified token. Usage: --strip_reasoning (uses default '</think>') or --strip_reasoning '</reasoning>' (uses custom token)", help="Strip thinking blocks ending with specified token. Usage: --strip_thinking (uses default '</think>') or --strip_thinking '</thinking>' (uses custom token)",
) )
return parser return parser
...@@ -480,7 +480,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None: ...@@ -480,7 +480,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
torch_random_seed=args.seed[2], torch_random_seed=args.seed[2],
fewshot_random_seed=args.seed[3], fewshot_random_seed=args.seed[3],
confirm_run_unsafe_code=args.confirm_run_unsafe_code, confirm_run_unsafe_code=args.confirm_run_unsafe_code,
strip_reasoning=args.strip_reasoning, strip_thinking=args.strip_thinking,
metadata=metadata, metadata=metadata,
**request_caching_args, **request_caching_args,
) )
......
...@@ -686,18 +686,16 @@ class Task(abc.ABC): ...@@ -686,18 +686,16 @@ class Task(abc.ABC):
""" """
from lm_eval.api.registry import get_filter from lm_eval.api.registry import get_filter
if filter_name == "strip_reasoning": if filter_name == "strip_thinking":
if not self._filters: if not self._filters:
self._filters = [ self._filters = [
build_filter_ensemble( build_filter_ensemble(
"strip_reasoning", [["strip_reasoning", kwargs]] "strip_thinking", [["strip_thinking", kwargs]]
) )
] ]
else: else:
for f in self._filters: for f in self._filters:
f.filters.insert( f.filters.insert(0, partial(get_filter("strip_thinking"), **kwargs))
0, partial(get_filter("strip_reasoning"), **kwargs)
)
def set_fewshot_seed(self, seed: Optional[int] = None) -> None: def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
self.fewshot_rnd = random.Random(seed) self.fewshot_rnd = random.Random(seed)
......
...@@ -75,7 +75,7 @@ def simple_evaluate( ...@@ -75,7 +75,7 @@ def simple_evaluate(
torch_random_seed: int = 1234, torch_random_seed: int = 1234,
fewshot_random_seed: int = 1234, fewshot_random_seed: int = 1234,
confirm_run_unsafe_code: bool = False, confirm_run_unsafe_code: bool = False,
strip_reasoning: Union[bool, str] = False, strip_thinking: Union[bool, str] = False,
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
): ):
"""Instantiate and evaluate a model on a list of tasks. """Instantiate and evaluate a model on a list of tasks.
...@@ -147,9 +147,9 @@ def simple_evaluate( ...@@ -147,9 +147,9 @@ def simple_evaluate(
Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None. Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.
:param confirm_run_unsafe_code: bool :param confirm_run_unsafe_code: bool
Whether to confirm running tasks marked as unsafe (code). If set to False, an error will be raised if an unsafe task is encountered. Whether to confirm running tasks marked as unsafe (code). If set to False, an error will be raised if an unsafe task is encountered.
:param strip_reasoning: bool or str :param strip_thinking: bool or str
If set, will strip reasoning from task outputs. This is useful for tasks that have reasoning in the output. If set, will strip reasoning from task outputs. This is useful for tasks that have reasoning in the output.
The value of this argument will be passed to the `suffix` argument of the `strip_reasoning` filter. The value of this argument will be passed to the `suffix` argument of the `strip_thinking` filter which is applied to the generation outputs.
:param metadata: dict :param metadata: dict
Additional metadata to be added to the task manager. Will get passed to the download function of the task. Additional metadata to be added to the task manager. Will get passed to the download function of the task.
...@@ -331,13 +331,13 @@ def simple_evaluate( ...@@ -331,13 +331,13 @@ def simple_evaluate(
# fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file) # fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
task_obj.set_fewshot_seed(seed=fewshot_random_seed) task_obj.set_fewshot_seed(seed=fewshot_random_seed)
if strip_reasoning and task_obj.OUTPUT_TYPE == "generate_until": if strip_thinking and task_obj.OUTPUT_TYPE == "generate_until":
eval_logger.info( eval_logger.info(
f"Stripping reasoning from {task_name} task outputs using {strip_reasoning}." f"Stripping thinking blocks from {task_name} task outputs using {strip_thinking}."
) )
task_obj.overide_filter( task_obj.overide_filter(
"strip_reasoning", "strip_thinking",
**({"suffix": strip_reasoning} if strip_reasoning else {}), **({"suffix": strip_thinking} if strip_thinking else {}),
) )
adjusted_task_dict[task_name] = task_obj adjusted_task_dict[task_name] = task_obj
......
...@@ -233,7 +233,7 @@ class MultiChoiceRegexFilter(RegexFilter): ...@@ -233,7 +233,7 @@ class MultiChoiceRegexFilter(RegexFilter):
return filtered_resps return filtered_resps
@register_filter("strip_reasoning") @register_filter("strip_thinking")
class StripReasoningFilter(Filter): class StripReasoningFilter(Filter):
"""A filter that strips reasoning block from model responses and returns the last part of the response.""" """A filter that strips reasoning block from model responses and returns the last part of the response."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment