Commit 408e7802 authored by cjlovering

Save logging information even when save_examples is false

parent 960f3780
@@ -778,7 +778,7 @@ class PromptSourceTask(Task):
         ), "Unexpected metric. Add it, or use a task-specific solution."
         if metric == "BLEU":
             out["bleu"] = (target, pred)
-        if metric == "ROUGE":
+        elif metric == "ROUGE":
             # TODO: This computes all rouge sub-metrics. Find a generic
             # way to handle user specified rouge sub-metrics to avoid extra
             # compute.
...
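For reference, the change above is purely about control flow: once the BLEU branch matches, the ROUGE comparison is no longer evaluated. A minimal, self-contained sketch of the resulting dispatch (simplified signature and hypothetical target/pred values, not the actual PromptSourceTask method):

    def build_metric_outputs(metrics, target, pred):
        """Map each requested metric name to its (target, pred) pair."""
        out = {}
        for metric in metrics:
            assert metric in {"BLEU", "ROUGE"}, (
                "Unexpected metric. Add it, or use a task-specific solution."
            )
            if metric == "BLEU":
                out["bleu"] = (target, pred)
            elif metric == "ROUGE":
                # elif: skipped whenever the BLEU branch already matched.
                out["rouge"] = (target, pred)
        return out

    # Hypothetical usage:
    print(build_metric_outputs(["BLEU", "ROUGE"], "a cat sat", "a cat sat"))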
@@ -247,7 +247,7 @@ def evaluate(
     for (task_prompt_name, doc_id), per_doc_requests in process_res_queue.items():
         per_doc_requests.sort(key=lambda x: x[0])
         per_doc_results = [x[1] for x in per_doc_requests]
-        logging_info = [x[2] for x in per_doc_requests][0]
+        fewshot_logging_info = [x[2] for x in per_doc_requests][0]
         task = task_dict[task_prompt_name]
         doc = docs[(task_prompt_name, doc_id)]
@@ -255,15 +255,15 @@ def evaluate(
         output = task.process_results(doc, per_doc_results)
         if task.save_examples:
             metrics, example = output
-            if logging_info:
-                # This has the fewshot information like fewshot_idx.
-                example.update(logging_info)
+            example.update(fewshot_logging_info)
             example.update(task.get_logging_info())
             examples.append(example)
         else:
             metrics = output
+            example = fewshot_logging_info
+            example.update(task.get_logging_info())
+            examples.append(example)
         for metric, value in metrics.items():
             vals[(task_prompt_name, metric)].append(value)
@@ -296,6 +296,7 @@ def evaluate(
             results[task_prompt_name][metric + "_stderr"] = stderr(items)
             _metric_results[metric + "_stderr"] = stderr(items)
         metric_results.append(_metric_results)
     return {
         # List of results that tracks the averages per model and prompt.
         "results": metric_results,
...
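The substance of the commit is in the hunks above: the per-document fewshot logging info (plus the task-level info from get_logging_info) is now appended to examples even when task.save_examples is False, instead of being dropped. A condensed sketch of that loop body, using a hypothetical ToyTask rather than the real Task API:

    from collections import defaultdict

    class ToyTask:
        """Hypothetical stand-in for a PromptSourceTask-like object."""
        save_examples = False

        def process_results(self, doc, results):
            metrics = {"acc": float(results[0] == doc["gold"])}
            if self.save_examples:
                return metrics, {"pred": results[0], "target": doc["gold"]}
            return metrics

        def get_logging_info(self):
            return {"prompt_name": "toy-prompt"}

    def collect_one_doc(task, task_prompt_name, doc, per_doc_requests):
        """Condensed version of the updated per-document loop in evaluate()."""
        per_doc_requests.sort(key=lambda x: x[0])
        per_doc_results = [x[1] for x in per_doc_requests]
        fewshot_logging_info = [x[2] for x in per_doc_requests][0]

        examples, vals = [], defaultdict(list)
        output = task.process_results(doc, per_doc_results)
        if task.save_examples:
            metrics, example = output
            example.update(fewshot_logging_info)
        else:
            # New behavior: keep the fewshot logging info even without pred/target.
            metrics = output
            example = fewshot_logging_info
        example.update(task.get_logging_info())
        examples.append(example)
        for metric, value in metrics.items():
            vals[(task_prompt_name, metric)].append(value)
        return examples, dict(vals)

    # Hypothetical usage: one request per doc, tagged with fewshot metadata.
    examples, vals = collect_one_doc(
        ToyTask(),
        "toy-task+toy-prompt",
        {"gold": "yes"},
        [(0, "yes", {"fewshot_idx": 3, "fewshot_num": 1})],
    )
    print(examples)  # [{'fewshot_idx': 3, 'fewshot_num': 1, 'prompt_name': 'toy-prompt'}]
    print(vals)      # {('toy-task+toy-prompt', 'acc'): [1.0]}

The shared example.update(task.get_logging_info()) / examples.append(example) calls are hoisted out of the branches here for brevity; the diff keeps them duplicated in each branch.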
@@ -126,10 +126,7 @@ class CoQA(PromptSourceTask):
         }
         if self.save_examples:
-            example = {
-                "f1": scores["f1"],
-                "em": scores["em"],
-            }
+            example = {"target": target, "pred": pred}
             return out, example
         return out
...
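In the CoQA hunk, the saved example now records the raw target and prediction instead of duplicating the f1/em values that are already returned in out. A hedged sketch of that return shape, with the scoring stubbed out (string equality, not CoQA's real F1/EM computation):

    def coqa_process_results_sketch(target, pred, save_examples=True):
        """Illustrative only: metrics are stubbed, not the real CoQA scorer."""
        scores = {
            "f1": float(target == pred),  # stand-in for token-level F1
            "em": float(target == pred),  # stand-in for exact match
        }
        out = {
            "f1": scores["f1"],
            "em": scores["em"],
        }
        if save_examples:
            # Metrics live in `out`; the example keeps the raw strings.
            example = {"target": target, "pred": pred}
            return out, example
        return out

    # Hypothetical usage:
    print(coqa_process_results_sketch("the cat sat", "the cat sat"))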