Commit 0bff23b4 authored by lintangsutawika's avatar lintangsutawika
Browse files

fixed sampler issue with new default num_fewshot value

parent e6d4ec39
...@@ -81,7 +81,7 @@ class TaskConfig(dict): ...@@ -81,7 +81,7 @@ class TaskConfig(dict):
fewshot_delimiter: str = "\n\n" fewshot_delimiter: str = "\n\n"
fewshot_config: dict = None fewshot_config: dict = None
# runtime configuration options # runtime configuration options
num_fewshot: int = -1 num_fewshot: int = None
# scoring options # scoring options
metric_list: list = None metric_list: list = None
output_type: str = "generate_until" output_type: str = "generate_until"
...@@ -359,7 +359,7 @@ class Task(abc.ABC): ...@@ -359,7 +359,7 @@ class Task(abc.ABC):
# sample fewshot context #TODO: need to offset doc_id by rank now! # sample fewshot context #TODO: need to offset doc_id by rank now!
fewshot_ctx = self.fewshot_context( fewshot_ctx = self.fewshot_context(
doc, doc,
self.config.num_fewshot, 0 if self.config.num_fewshot is None else self.config.num_fewshot,
) )
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
...@@ -775,7 +775,7 @@ class ConfigurableTask(Task): ...@@ -775,7 +775,7 @@ class ConfigurableTask(Task):
if self.config.fewshot_split is not None: if self.config.fewshot_split is not None:
return self.dataset[self.config.fewshot_split] return self.dataset[self.config.fewshot_split]
else: else:
if self.config.num_fewshot > 0: if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
eval_logger.warning( eval_logger.warning(
f"Task '{self.config.task}': " f"Task '{self.config.task}': "
"num_fewshot > 0 but fewshot_split is None. " "num_fewshot > 0 but fewshot_split is None. "
......
...@@ -625,6 +625,7 @@ def evaluate( ...@@ -625,6 +625,7 @@ def evaluate(
groups_agg[group]["alias"] = tab_string + group groups_agg[group]["alias"] = tab_string + group
for group_name, task_list in task_hierarchy.items(): for group_name, task_list in task_hierarchy.items():
if task_list != []:
num_fewshot[group_name] = num_fewshot[task_list[0]] num_fewshot[group_name] = num_fewshot[task_list[0]]
results_dict = { results_dict = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment