Commit 0bff23b4 authored by lintangsutawika's avatar lintangsutawika
Browse files

fixed sampler issue with new default num_fewshot value

parent e6d4ec39
......@@ -81,7 +81,7 @@ class TaskConfig(dict):
fewshot_delimiter: str = "\n\n"
fewshot_config: dict = None
# runtime configuration options
num_fewshot: int = -1
num_fewshot: int = None
# scoring options
metric_list: list = None
output_type: str = "generate_until"
......@@ -359,7 +359,7 @@ class Task(abc.ABC):
# sample fewshot context #TODO: need to offset doc_id by rank now!
fewshot_ctx = self.fewshot_context(
doc,
self.config.num_fewshot,
0 if self.config.num_fewshot is None else self.config.num_fewshot,
)
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
......@@ -775,7 +775,7 @@ class ConfigurableTask(Task):
if self.config.fewshot_split is not None:
return self.dataset[self.config.fewshot_split]
else:
if self.config.num_fewshot > 0:
if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
eval_logger.warning(
f"Task '{self.config.task}': "
"num_fewshot > 0 but fewshot_split is None. "
......
......@@ -625,7 +625,8 @@ def evaluate(
groups_agg[group]["alias"] = tab_string + group
for group_name, task_list in task_hierarchy.items():
num_fewshot[group_name] = num_fewshot[task_list[0]]
if task_list != []:
num_fewshot[group_name] = num_fewshot[task_list[0]]
results_dict = {
"results": dict(results_agg.items()),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment