Commit 9c647fc1 authored by Baber's avatar Baber
Browse files

add FewshotConfig

parent 28c78d30
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Iterable, List, Union from typing import Iterable, List, Union
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
...@@ -40,7 +40,7 @@ class FilterEnsemble: ...@@ -40,7 +40,7 @@ class FilterEnsemble:
""" """
name: str name: str
filters: List[Callable[[], Filter]] filters: List[type[Filter]]
def apply(self, instances: List[Instance]) -> None: def apply(self, instances: List[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances)) resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
......
...@@ -90,6 +90,12 @@ class FilterConfig: ...@@ -90,6 +90,12 @@ class FilterConfig:
kwargs: Optional[dict] = None kwargs: Optional[dict] = None
@dataclass
class FewshotConfig:
    """Configuration for few-shot example selection for a task.

    Groups the sampler choice together with the pool of example
    documents it draws from.
    """

    # Name of the few-shot sampler to use — presumably a key into a
    # sampler registry elsewhere in the package; TODO confirm.
    sampler: str
    # Pool of few-shot example docs to draw from — assumed to be raw
    # dataset rows (one dict per example); verify against callers.
    samples: list[dict]
@dataclass @dataclass
class TaskConfig(dict): class TaskConfig(dict):
# task naming/registry # task naming/registry
...@@ -185,6 +191,9 @@ class TaskConfig(dict): ...@@ -185,6 +191,9 @@ class TaskConfig(dict):
metrics = [] metrics = []
if self.metric_list is None: if self.metric_list is None:
_metric_list = DEFAULT_METRIC_REGISTRY[self.output_type] _metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
eval_logger.info(
f"No metrics defined in config, using default metrics for {self.output_type}={_metric_list}"
)
metrics.extend( metrics.extend(
MetricConfig( MetricConfig(
name=metric_name, name=metric_name,
...@@ -261,6 +270,35 @@ class TaskConfig(dict): ...@@ -261,6 +270,35 @@ class TaskConfig(dict):
) )
return metrics return metrics
def get_filters(self):
if self.filter_list is not None:
_filter_list = []
if isinstance(self.filter_list, dict):
for filter_config in self.filter_list:
_filter_list.append(
build_filter_ensemble(
filter_name=filter_config["name"],
components=[
[
{
key: function[key]
for key in function
if key != "function"
}
]
for function in filter_config["filter"]
],
)
)
else:
# TODO: handle repeats in a more general way rather than just discarding
eval_logger.debug(
"No custom filters defined. Using default 'take_first' filter for handling repeats."
)
_filter_list = [build_filter_ensemble("none", [["take_first", None]])]
return _filter_list
def __getitem__(self, item): def __getitem__(self, item):
return getattr(self, item) return getattr(self, item)
...@@ -908,31 +946,33 @@ class ConfigurableTask(Task): ...@@ -908,31 +946,33 @@ class ConfigurableTask(Task):
self._training_docs = None self._training_docs = None
self._fewshot_docs = None self._fewshot_docs = None
if self.config.filter_list is not None: self._filters = self.config.get_filters()
self._filters = []
if isinstance(self.config.filter_list, dict): # if self.config.filter_list is not None:
for filter_config in self.config.filter_list: # self._filters = []
self._filters.append( # if isinstance(self.config.filter_list, dict):
build_filter_ensemble( # for filter_config in self.config.filter_list:
filter_config["name"], # self._filters.append(
[ # build_filter_ensemble(
[ # filter_config["name"],
{ # [
key: function[key] # [
for key in function # {
if key != "function" # key: function[key]
} # for key in function
] # if key != "function"
for function in filter_config["filter"] # }
], # ]
) # for function in filter_config["filter"]
) # ],
else: # )
# TODO: handle repeats in a more general way rather than just discarding # )
eval_logger.debug( # else:
"No custom filters defined. Using default 'take_first' filter for handling repeats." # # TODO: handle repeats in a more general way rather than just discarding
) # eval_logger.debug(
self._filters = [build_filter_ensemble("none", [["take_first", None]])] # "No custom filters defined. Using default 'take_first' filter for handling repeats."
# )
# self._filters = [build_filter_ensemble("none", [["take_first", None]])]
if self.config.use_prompt is not None: if self.config.use_prompt is not None:
eval_logger.info(f"loading prompt {self.config.use_prompt}") eval_logger.info(f"loading prompt {self.config.use_prompt}")
......
...@@ -405,7 +405,8 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False) ...@@ -405,7 +405,8 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
dic = result_dict[column][k] dic = result_dict[column][k]
version = result_dict["versions"].get(k, " N/A") version = result_dict["versions"].get(k, " N/A")
n = str(result_dict.get("n-shot", " ").get(k, " ")) n = str(result_dict.get("n-shot", " ").get(k, " "))
higher_is_better = result_dict.get("higher_is_better", {}).get(k, {}) # TODO: fix this
# higher_is_better = result_dict.get("higher_is_better", {}).get(k, {})
if "alias" in dic: if "alias" in dic:
k = dic.pop("alias") k = dic.pop("alias")
...@@ -418,7 +419,9 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False) ...@@ -418,7 +419,9 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
if m.endswith("_stderr"): if m.endswith("_stderr"):
continue continue
hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "") # hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "")
# TODO: fix
hib = "↑"
v = "%.4f" % v if isinstance(v, float) else v v = "%.4f" % v if isinstance(v, float) else v
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment