Commit c4f0bf75 authored by lintangsutawika's avatar lintangsutawika
Browse files

pre-commit reformat

parent 4ccd2ec6
......@@ -221,8 +221,6 @@ def evaluate(
task_hierarchy = collections.defaultdict(list)
# store the ordering of tasks and groups
task_order = collections.defaultdict(int)
# store the aggregation for aggregating across tasks in the same group
sample_agg_fn = collections.defaultdict(dict)
# get lists of each type of request
for task_name, task in task_dict.items():
......@@ -469,7 +467,6 @@ def evaluate(
for group, task_list in reversed(task_hierarchy.items()):
versions[group] = "N/A"
task_score_dict = {}
total_size = 0
for task in task_list:
metrics = results[task]
......@@ -479,7 +476,9 @@ def evaluate(
# current_size = 1
all_stderr = []
for metric in [key for key in metrics.keys() if "_stderr" not in key]:
for metric in [
key for key in metrics.keys() if "_stderr" not in key
]:
stderr = "_stderr,".join(metric.split(","))
stderr_score = results[task][stderr]
......@@ -488,10 +487,22 @@ def evaluate(
all_stderr.append(stderr)
if metric in results[group]:
results[group][metric] = (results[group][metric]*total_size + metric_score*current_size)/(total_size+current_size)
results[group][metric] = (
results[group][metric] * total_size
+ metric_score * current_size
) / (total_size + current_size)
# $$s_z^2 = \frac{(n-1) s_x^2 + (m-1) s_y^2}{n+m-1} + \frac{nm(\bar x - \bar y)^2}{(n+m)(n+m-1)}.$$
results[group][stderr] = ((total_size-1)*results[group][stderr]+(current_size-1)*stderr_score)/(total_size + current_size - 1) \
+ total_size*current_size/((total_size+current_size)*(total_size+current_size-1))*(results[group][metric] - metric_score)**2
results[group][stderr] = (
(total_size - 1) * results[group][stderr]
+ (current_size - 1) * stderr_score
) / (
total_size + current_size - 1
) + total_size * current_size / (
(total_size + current_size)
* (total_size + current_size - 1)
) * (
results[group][metric] - metric_score
) ** 2
else:
results[group][metric] = metric_score
results[group][stderr] = stderr_score
......@@ -500,7 +511,7 @@ def evaluate(
for stderr in all_stderr:
results[group][stderr] = np.sqrt(results[group][stderr])
results[group]["samples"] = total_size
for task_name, task in task_dict.items():
......
......@@ -137,8 +137,8 @@ if __name__ == "__main__":
yaml.dump(
{
"group": f"mmlu_{args.group_prefix}",
"task": [f"mmlu_{category}" for category in ALL_CATEGORIES]
},
"task": [f"mmlu_{category}" for category in ALL_CATEGORIES],
},
yaml_file,
default_flow_style=False
default_flow_style=False,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment