"examples/vscode:/vscode.git/clone" did not exist on "7f31142c2eaaaec11f7e7c461f300caf9c1c5661"
Commit 336cc455 authored by baberabb's avatar baberabb
Browse files

Merge remote-tracking branch 'origin/big-refactor' into big-refactor-mps

parents 7bb147b5 42f486ee
...@@ -81,7 +81,7 @@ class TaskConfig(dict): ...@@ -81,7 +81,7 @@ class TaskConfig(dict):
fewshot_delimiter: str = "\n\n" fewshot_delimiter: str = "\n\n"
fewshot_config: dict = None fewshot_config: dict = None
# runtime configuration options # runtime configuration options
num_fewshot: int = -1 num_fewshot: int = None
# scoring options # scoring options
metric_list: list = None metric_list: list = None
output_type: str = "generate_until" output_type: str = "generate_until"
...@@ -361,7 +361,7 @@ class Task(abc.ABC): ...@@ -361,7 +361,7 @@ class Task(abc.ABC):
# sample fewshot context #TODO: need to offset doc_id by rank now! # sample fewshot context #TODO: need to offset doc_id by rank now!
fewshot_ctx = self.fewshot_context( fewshot_ctx = self.fewshot_context(
doc, doc,
self.config.num_fewshot, 0 if self.config.num_fewshot is None else self.config.num_fewshot,
) )
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
...@@ -777,7 +777,7 @@ class ConfigurableTask(Task): ...@@ -777,7 +777,7 @@ class ConfigurableTask(Task):
if self.config.fewshot_split is not None: if self.config.fewshot_split is not None:
return self.dataset[self.config.fewshot_split] return self.dataset[self.config.fewshot_split]
else: else:
if self.config.num_fewshot > 0: if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
eval_logger.warning( eval_logger.warning(
f"Task '{self.config.task}': " f"Task '{self.config.task}': "
"num_fewshot > 0 but fewshot_split is None. " "num_fewshot > 0 but fewshot_split is None. "
......
...@@ -260,7 +260,7 @@ def evaluate( ...@@ -260,7 +260,7 @@ def evaluate(
if "num_fewshot" in configs[task_name]: if "num_fewshot" in configs[task_name]:
n_shot = configs[task_name]["num_fewshot"] n_shot = configs[task_name]["num_fewshot"]
else: else:
n_shot = -1 n_shot = 0
num_fewshot[task_name] = n_shot num_fewshot[task_name] = n_shot
if "task_alias" in configs[task_name]: if "task_alias" in configs[task_name]:
...@@ -440,7 +440,6 @@ def evaluate( ...@@ -440,7 +440,6 @@ def evaluate(
vals = vals_torch vals = vals_torch
if lm.rank == 0: if lm.rank == 0:
### Get task ordering for correct sample-wide aggregation ### Get task ordering for correct sample-wide aggregation
group_to_task = {} group_to_task = {}
for group in task_hierarchy.keys(): for group in task_hierarchy.keys():
...@@ -451,7 +450,6 @@ def evaluate( ...@@ -451,7 +450,6 @@ def evaluate(
group_to_task[group] = task_hierarchy[group].copy() group_to_task[group] = task_hierarchy[group].copy()
for task in task_hierarchy[group]: for task in task_hierarchy[group]:
if task in task_order: if task in task_order:
task_order[task] += 1 task_order[task] += 1
else: else:
...@@ -498,9 +496,7 @@ def evaluate( ...@@ -498,9 +496,7 @@ def evaluate(
results[task_name][metric + "_stderr" + "," + key] = stderr(items) results[task_name][metric + "_stderr" + "," + key] = stderr(items)
if bool(results): if bool(results):
for group, task_list in reversed(task_hierarchy.items()): for group, task_list in reversed(task_hierarchy.items()):
if task_list == []: if task_list == []:
total_size = results[group]["samples"] total_size = results[group]["samples"]
else: else:
...@@ -520,7 +516,6 @@ def evaluate( ...@@ -520,7 +516,6 @@ def evaluate(
for metric in [ for metric in [
key for key in metrics.keys() if "_stderr" not in key key for key in metrics.keys() if "_stderr" not in key
]: ]:
stderr = "_stderr,".join(metric.split(",")) stderr = "_stderr,".join(metric.split(","))
stderr_score = results[task][stderr] stderr_score = results[task][stderr]
var_score = stderr_score**2 var_score = stderr_score**2
...@@ -557,11 +552,9 @@ def evaluate( ...@@ -557,11 +552,9 @@ def evaluate(
results[group]["samples"] = total_size results[group]["samples"] = total_size
def print_tasks(task_hierarchy, task_order, task_version, task_group_alias): def print_tasks(task_hierarchy, task_order, task_version, task_group_alias):
results_agg = collections.defaultdict(dict) results_agg = collections.defaultdict(dict)
groups_agg = collections.defaultdict(dict) groups_agg = collections.defaultdict(dict)
for group_name, task_list in task_hierarchy.items(): for group_name, task_list in task_hierarchy.items():
order = task_order[group_name] order = task_order[group_name]
results_agg[group_name] = results[group_name].copy() results_agg[group_name] = results[group_name].copy()
results_agg[group_name]["tab"] = order results_agg[group_name]["tab"] = order
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment