Commit cac85281 authored by lintangsutawika's avatar lintangsutawika
Browse files

alternate step so that we can change num_fewshots

parent 53a4817e
...@@ -13,15 +13,16 @@ import numpy as np ...@@ -13,15 +13,16 @@ import numpy as np
from typing import List, Union from typing import List, Union
from lm_eval.api.metrics import METRIC_REGISTRY, AGGREGATION_REGISTRY from lm_eval import utils
from lm_eval.api import HIGHER_IS_BETTER_REGISTRY
from lm_eval.api.metrics import METRIC_REGISTRY, AGGREGATION_REGISTRY, HIGHER_IS_BETTER_REGISTRY
from lm_eval.api import samplers
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
from lm_eval.api.metrics import get_metric, get_aggregation, mean, weighted_perplexity, bits_per_byte from lm_eval.api.metrics import get_metric, get_aggregation, mean, weighted_perplexity, bits_per_byte
from lm_eval import utils
from lm_eval.prompts import get_prompt
from lm_eval.tasks import TASK_REGISTRY
from lm_eval.prompts import get_prompt
from lm_eval.filters import build_filter_ensemble from lm_eval.filters import build_filter_ensemble
from lm_eval.api import samplers
@dataclass @dataclass
...@@ -422,6 +423,9 @@ class ConfigurableTask(Task): ...@@ -422,6 +423,9 @@ class ConfigurableTask(Task):
# else, if a config was passed as kwarg: use it # else, if a config was passed as kwarg: use it
if (self._config is None) and config: if (self._config is None) and config:
self._config = TaskConfig(**config) self._config = TaskConfig(**config)
elif config["num_fewshot"] != 0:
self._config.num_fewshot = config["num_fewshot"]
if self._config is None: if self._config is None:
raise ValueError("Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg") raise ValueError("Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg")
...@@ -880,7 +884,7 @@ def get_task_name_from_config(task_config): ...@@ -880,7 +884,7 @@ def get_task_name_from_config(task_config):
def get_task_dict(task_name_list: List[Union[str, dict, Task]], num_fewshot=None): # TODO: pass num_fewshot and other cmdline overrides in a better way def get_task_dict(task_name_list: List[Union[str, dict, Task]], num_fewshot=None): # TODO: pass num_fewshot and other cmdline overrides in a better way
task_name_dict = { task_name_dict = {
task_name: get_task(task_name)(config={"num_fewshot": num_fewshot if num_fewshot else 0, "task_name": task_name}) task_name: get_task(task_name)(config={"num_fewshot": num_fewshot if num_fewshot else 0})
for task_name in task_name_list for task_name in task_name_list
if isinstance(task_name, str) if isinstance(task_name, str)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment