Commit cac85281 authored by lintangsutawika's avatar lintangsutawika
Browse files

alternate step so that we can change num_fewshots

parent 53a4817e
......@@ -13,15 +13,16 @@ import numpy as np
from typing import List, Union
from lm_eval.api.metrics import METRIC_REGISTRY, AGGREGATION_REGISTRY
from lm_eval.api import HIGHER_IS_BETTER_REGISTRY
from lm_eval import utils
from lm_eval.api.metrics import METRIC_REGISTRY, AGGREGATION_REGISTRY, HIGHER_IS_BETTER_REGISTRY
from lm_eval.api import samplers
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import get_metric, get_aggregation, mean, weighted_perplexity, bits_per_byte
from lm_eval import utils
from lm_eval.prompts import get_prompt
from lm_eval.tasks import TASK_REGISTRY
from lm_eval.prompts import get_prompt
from lm_eval.filters import build_filter_ensemble
from lm_eval.api import samplers
@dataclass
......@@ -422,6 +423,9 @@ class ConfigurableTask(Task):
# else, if a config was passed as kwarg: use it
if (self._config is None) and config:
self._config = TaskConfig(**config)
elif config["num_fewshot"] != 0:
self._config.num_fewshot = config["num_fewshot"]
if self._config is None:
raise ValueError("Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg")
......@@ -880,7 +884,7 @@ def get_task_name_from_config(task_config):
def get_task_dict(task_name_list: List[Union[str, dict, Task]], num_fewshot=None): # TODO: pass num_fewshot and other cmdline overrides in a better way
task_name_dict = {
task_name: get_task(task_name)(config={"num_fewshot": num_fewshot if num_fewshot else 0, "task_name": task_name})
task_name: get_task(task_name)(config={"num_fewshot": num_fewshot if num_fewshot else 0})
for task_name in task_name_list
if isinstance(task_name, str)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment