Commit 108674ed authored by Baber's avatar Baber
Browse files

add FewshotConfig

parent c5aa5cf0
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Iterable, List, Union
from typing import Iterable, List, Union
from lm_eval.api.instance import Instance
......@@ -40,7 +40,7 @@ class FilterEnsemble:
"""
name: str
filters: List[Callable[[], Filter]]
filters: List[type[Filter]]
def apply(self, instances: List[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
......
......@@ -90,6 +90,12 @@ class FilterConfig:
kwargs: Optional[dict] = None
@dataclass
class FewshotConfig:
sampler: str
samples: list[dict]
@dataclass
class TaskConfig(dict):
# task naming/registry
......@@ -185,6 +191,9 @@ class TaskConfig(dict):
metrics = []
if self.metric_list is None:
_metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
eval_logger.info(
f"No metrics defined in config, using default metrics for {self.output_type}={_metric_list}"
)
metrics.extend(
MetricConfig(
name=metric_name,
......@@ -261,6 +270,35 @@ class TaskConfig(dict):
)
return metrics
def get_filters(self):
if self.filter_list is not None:
_filter_list = []
if isinstance(self.filter_list, dict):
for filter_config in self.filter_list:
_filter_list.append(
build_filter_ensemble(
filter_name=filter_config["name"],
components=[
[
{
key: function[key]
for key in function
if key != "function"
}
]
for function in filter_config["filter"]
],
)
)
else:
# TODO: handle repeats in a more general way rather than just discarding
eval_logger.debug(
"No custom filters defined. Using default 'take_first' filter for handling repeats."
)
_filter_list = [build_filter_ensemble("none", [["take_first", None]])]
return _filter_list
def __getitem__(self, item):
return getattr(self, item)
......@@ -908,31 +946,33 @@ class ConfigurableTask(Task):
self._training_docs = None
self._fewshot_docs = None
if self.config.filter_list is not None:
self._filters = []
if isinstance(self.config.filter_list, dict):
for filter_config in self.config.filter_list:
self._filters.append(
build_filter_ensemble(
filter_config["name"],
[
[
{
key: function[key]
for key in function
if key != "function"
}
]
for function in filter_config["filter"]
],
)
)
else:
# TODO: handle repeats in a more general way rather than just discarding
eval_logger.debug(
"No custom filters defined. Using default 'take_first' filter for handling repeats."
)
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
self._filters = self.config.get_filters()
# if self.config.filter_list is not None:
# self._filters = []
# if isinstance(self.config.filter_list, dict):
# for filter_config in self.config.filter_list:
# self._filters.append(
# build_filter_ensemble(
# filter_config["name"],
# [
# [
# {
# key: function[key]
# for key in function
# if key != "function"
# }
# ]
# for function in filter_config["filter"]
# ],
# )
# )
# else:
# # TODO: handle repeats in a more general way rather than just discarding
# eval_logger.debug(
# "No custom filters defined. Using default 'take_first' filter for handling repeats."
# )
# self._filters = [build_filter_ensemble("none", [["take_first", None]])]
if self.config.use_prompt is not None:
eval_logger.info(f"loading prompt {self.config.use_prompt}")
......
......@@ -388,7 +388,8 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
dic = result_dict[column][k]
version = result_dict["versions"].get(k, " N/A")
n = str(result_dict.get("n-shot", " ").get(k, " "))
higher_is_better = result_dict.get("higher_is_better", {}).get(k, {})
# TODO: fix this
# higher_is_better = result_dict.get("higher_is_better", {}).get(k, {})
if "alias" in dic:
k = dic.pop("alias")
......@@ -401,7 +402,9 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
if m.endswith("_stderr"):
continue
hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "")
# hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "")
# TODO: fix
hib = "↑"
v = "%.4f" % v if isinstance(v, float) else v
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment