Commit e72ec96c authored by Baber's avatar Baber
Browse files

fix

parent d762e2aa
......@@ -29,11 +29,11 @@ repos:
- id: mixed-line-ending
args: [--fix=lf]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.2
rev: v0.12.5
hooks:
# Run the linter.
- id: ruff-check
args: [ --fix]
args: [--fix]
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
rev: v2.4.1
......
......@@ -8,7 +8,6 @@ import re
from collections.abc import Callable, Iterable, Iterator, Mapping
from copy import deepcopy
from functools import cached_property
from types import MethodType
from typing import TYPE_CHECKING, Any, Literal, overload
import datasets
......@@ -656,9 +655,16 @@ class ConfigurableTask(Task):
)
self.task_docs = self.eval_docs
for _method, fn in self.config._fn.items():
if hasattr(self, _method):
setattr(self, _method, MethodType(fn, self))
# for name, fn in self.config._fn.items():
# if hasattr(self, name):
# setattr(
# self,
# name,
# types.MethodType(
# lambda self, *args, _fn=fn, **kwargs: _fn(*args, **kwargs),
# self,
# ),
# )
self.runtime_checks(self.task_docs[0])
......@@ -974,6 +980,8 @@ class ConfigurableTask(Task):
# if self.prompt is not None:
# doc_to_text = self.prompt
doc_to_text = doc_to_text or self.config.doc_to_text
if callable(doc_to_text):
return doc_to_text(doc)
if doc_to_text in doc:
return doc[doc_to_text]
elif isinstance(doc_to_text, str):
......@@ -1019,6 +1027,8 @@ class ConfigurableTask(Task):
# if self.prompt is not None:
# doc_to_target = self.prompt
doc_to_target = doc_to_target or self.config.doc_to_target
if callable(doc_to_target):
doc_to_target(doc)
if doc_to_target in doc:
return doc[doc_to_target]
elif isinstance(doc_to_target, str):
......@@ -1280,6 +1290,8 @@ class ConfigurableTask(Task):
)
def process_results(self, doc: dict, results: list) -> dict[str, Any]:
if callable(self.config.process_results):
return self.config.process_results(doc, results)
result_dict = {}
use_metric = list(m.metric_name for m in self.config._metric_list)
if self.OUTPUT_TYPE == "loglikelihood":
......
......@@ -10,7 +10,7 @@ import datasets
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType
from lm_eval.config.metric import MetricConfig
from lm_eval.config.utils import doc_to_closure, maybe_serialize
from lm_eval.config.utils import maybe_serialize
if TYPE_CHECKING:
......@@ -364,7 +364,7 @@ class TaskConfig:
@classmethod
def from_yaml(cls, data: dict[str, Any]) -> TaskConfig:
"""Create a TaskConfig instance from a YAML-like dictionary."""
fn = {k: doc_to_closure(v) for k, v in data.items() if callable(v)}
fn = {k: v for k, v in data.items() if callable(v)}
return cls(**data, _fn=fn)
@classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment