"profiler/vscode:/vscode.git/clone" did not exist on "1ceb3b0927a79e74b3fd0231c64292d42a709801"
Unverified commit e5bc4354, authored by Jonathan Tow, committed by GitHub

Merge pull request #4 from cjlovering/cjlovering/output_examples

Enable saving all the examples.
parents 18af502b 408e7802
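For orientation: this merge changes `PromptSourceTask.process_results` so that, when `save_examples=True` (the new default), it returns a `(metrics, example)` pair instead of a bare metrics dict. The sketch below shows how a caller can normalize the two cases; the helper name `unpack_process_results` and its arguments are hypothetical, only the `save_examples` attribute and the return convention come from the diff that follows.

# Hypothetical helper, a minimal sketch of the new return convention.
def unpack_process_results(task, doc, per_doc_results):
    output = task.process_results(doc, per_doc_results)
    if task.save_examples:
        # New behavior: metrics plus a per-example record (pred, target, ...).
        metrics, example = output
    else:
        # Old behavior: metrics only.
        metrics, example = output, {}
    return metrics, example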
@@ -654,11 +654,21 @@ class PromptSourceTask(Task):
     *and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`.
     """
 
-    CONFIGURED_PS_METRICS = set(["Accuracy", "BLEU", "ROUGE"])
+    CONFIGURED_RANKED_CHOICE_PS_METRICS = set(["Accuracy"])
+    CONFIGURED_GENERATION_PS_METRICS = set(["BLEU", "ROUGE"])
+    SPLIT = None
 
-    def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None):
+    def __init__(
+        self,
+        data_dir=None,
+        cache_dir=None,
+        download_mode=None,
+        prompt=None,
+        save_examples=True,
+    ):
         super().__init__(data_dir, cache_dir, download_mode)
         self.prompt = prompt
+        self.save_examples = save_examples
 
     def stopping_criteria(self) -> Optional[str]:
         """Denote where the generation should end.
@@ -752,12 +762,11 @@ class PromptSourceTask(Task):
             for metric in self.prompt.metadata.metrics:
                 assert (
-                    metric in self.CONFIGURED_PS_METRICS
+                    metric in self.CONFIGURED_RANKED_CHOICE_PS_METRICS
                 ), "Unexpected metric. Add it, or use a task-specific solution."
                 if metric == "Accuracy":
                     out["acc"] = pred == target
                 # TODO: Add metrics here.
-            return out
         else:
             # If not, then this is a generation prompt.
             # NOTE: In the future, target will be a list of strings.
@@ -765,11 +774,11 @@ class PromptSourceTask(Task):
             out = {}
             for metric in self.prompt.metadata.metrics:
                 assert (
-                    metric in self.CONFIGURED_PS_METRICS
+                    metric in self.CONFIGURED_GENERATION_PS_METRICS
                 ), "Unexpected metric. Add it, or use a task-specific solution."
                 if metric == "BLEU":
                     out["bleu"] = (target, pred)
-                if metric == "ROUGE":
+                elif metric == "ROUGE":
                     # TODO: This computes all rouge sub-metrics. Find a generic
                     # way to handle user specified rouge sub-metrics to avoid extra
                     # compute.
@@ -778,15 +787,21 @@ class PromptSourceTask(Task):
                     rouge_scores = utils.flatten(rouge_scores)
                     # Merge all the rouge-type scores into the `out` dict.
                     out = {**out, **rouge_scores}
-            print(out)
-            return out
+
+        # TODO: Wrap process results s.t. override impl do not
+        # override the save examples.
+        if self.save_examples:
+            example = {
+                "pred": pred,
+                "target": target,
+                "answer_choices_list": answer_choices_list,
+            }
+            return out, example
+        return out
 
     def higher_is_better(self):
         out = {}
         for metric in self.prompt.metadata.metrics:
-            assert (
-                metric in self.CONFIGURED_PS_METRICS
-            ), "Unexpected metric. Add it, or use a task-specific solution."
             if metric == "Accuracy":
                 out["acc"] = True
             if metric == "BLEU":
@@ -813,9 +828,6 @@ class PromptSourceTask(Task):
     def aggregation(self):
         out = {}
         for metric in self.prompt.metadata.metrics:
-            assert (
-                metric in self.CONFIGURED_PS_METRICS
-            ), "Unexpected metric. Add it, or use a task-specific solution."
             if metric == "Accuracy":
                 out["acc"] = mean
             if metric == "BLEU":
@@ -839,6 +851,122 @@ class PromptSourceTask(Task):
                 out["rougeLsum_fmeasure"] = mean
         return out
 
+    def fewshot_examples(self, k, rnd):
+        if self._training_docs is None:
+            self._training_docs = list(self.training_docs())
+        return self._get_fewshot_examples(self._training_docs, k, rnd)
+
+    def _get_fewshot_examples(self, docs, k, rnd):
+        fewshot_idx = rnd.sample(list(np.arange(len(docs))), k)
+        return [docs[idx] for idx in fewshot_idx], [int(idx) for idx in fewshot_idx]
+
+    @utils.positional_deprecated
+    def fewshot_context(
+        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
+    ):
+        """Returns a fewshot context string that is made up of a prepended description
+        (if provided), the `num_fewshot` number of examples, and an appended prompt example.
+
+        :param doc: str
+            The document as returned from training_docs, validation_docs, or test_docs.
+        :param num_fewshot: int
+            The number of fewshot examples to provide in the returned context string.
+        :param provide_description: bool
+            Not implemented. This option is deprecated and will be removed in a future
+            version in favor of a different description-providing method.
+        :param rnd: random.Random
+            The pseudo-random number generator used to randomly sample examples.
+            WARNING: This is currently a required arg although it's optionalized with a default `None`.
+        :param description: str
+            The task's description that will be prepended to the fewshot examples.
+        :returns: str
+            The fewshot context.
+        """
+        assert (
+            rnd is not None
+        ), "A `random.Random` generator argument must be provided to `rnd`"
+        assert not provide_description, (
+            "The `provide_description` arg will be removed in future versions. To prepend "
+            "a custom description to the context, supply the corresponding string via the "
+            "`description` arg."
+        )
+        if provide_description is not None:
+            # nudge people to not specify it at all
+            print(
+                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
+            )
+
+        description = description + "\n\n" if description else ""
+
+        if num_fewshot == 0:
+            labeled_examples = ""
+            fewshotex, fewshotidx, fewshotsource = [], [], None
+        else:
+            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
+            if self.has_training_docs():
+                fewshotex, fewshotidx = self.fewshot_examples(k=num_fewshot, rnd=rnd)
+                fewshotsource = "train"
+            else:
+                if self._fewshot_docs is None:
+                    self._fewshot_docs = list(
+                        self.validation_docs()
+                        if self.has_validation_docs()
+                        else self.test_docs()
+                    )
+                if self.has_validation_docs():
+                    fewshotsource = "val"
+                elif self.test_docs():
+                    fewshotsource = "test"
+                fewshotex, fewshotidx = zip(
+                    *[
+                        (shot, idx)
+                        for shot, idx in zip(fewshotex, fewshotidx)
+                        if shot != doc
+                    ]
+                )
+                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
+                fewshotex, fewshotidx = (
+                    list(fewshotex)[:num_fewshot],
+                    list(fewshotidx)[:num_fewshot],
+                )
+
+            labeled_examples = (
+                "\n\n".join(
+                    [
+                        self.doc_to_text(doc) + self.doc_to_target(doc)
+                        for doc in fewshotex
+                    ]
+                )
+                + "\n\n"
+            )
+
+        example = self.doc_to_text(doc)
+        ctx = description + labeled_examples + example
+        return (
+            ctx,
+            {
+                "fewshot_idx": fewshotidx,
+                "fewshot_source": fewshotsource,
+                "fewshot_num": num_fewshot,
+                "ctx": ctx,
+            },
+        )
+
+    def get_logging_info(self):
+        return {
+            "fixed_answer_choice_list": self.prompt.get_fixed_answer_choices_list(),
+            "dataset_path": self.DATASET_PATH,
+            "dataset_name": self.DATASET_NAME,
+            "subset": self.SPLIT,
+            "prompt_name": self.prompt.get_name(),
+            "prompt_id": self.prompt.get_id(),
+            "prompt_jinja": self.prompt.jinja,
+            "prompt_original_task": self.prompt.metadata.original_task,
+            # Placeholder for comment in post-processing.
+            "comment": "",
+        }
 
 class MultipleChoiceTask(Task):
     def doc_to_target(self, doc):
...
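The `fewshot_context` method added above now returns a `(ctx, logging_info)` pair rather than just the context string. Below is a minimal sketch of calling it, assuming `task` is a `PromptSourceTask` subclass instance and `doc` one of its documents; both names and the helper `build_context` are stand-ins, not code from this commit.

import random

# Minimal sketch, assuming `task`/`doc` come from an existing PromptSourceTask setup.
def build_context(task, doc, num_fewshot=2, seed=42):
    rnd = random.Random(seed)
    ctx, fewshot_logging_info = task.fewshot_context(
        doc=doc, num_fewshot=num_fewshot, rnd=rnd
    )
    # ctx: the prompt string fed to the model.
    # fewshot_logging_info: sampled example indices, source split, shot count, and ctx.
    return ctx, fewshot_logging_info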
@@ -173,10 +173,6 @@ def evaluate(
     # get lists of each type of request
     for task_prompt_name, task in task_dict_items:
-        # if task.is_generation_task():
-        #     print(f"WARNING: Skipping generation prompt {task.prompt.name}.")
-        #     continue
-
         versions[task_prompt_name] = task.VERSION
         # default to test doc, fall back to val doc if validation unavailable
         # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
@@ -188,7 +184,7 @@ def evaluate(
             raise RuntimeError("Task has neither test_docs nor validation_docs")
 
         # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
-        task_docs = list(task_doc_func())
+        task_docs = list(enumerate(list(task_doc_func())))
         rnd = random.Random()
         rnd.seed(42)
         rnd.shuffle(task_docs)
@@ -199,14 +195,17 @@ def evaluate(
             else ""
         )
 
-        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
+        for doc_id, (original_doc_id, doc) in enumerate(
+            itertools.islice(task_docs, 0, limit)
+        ):
             if task.invalid_doc_for_prompt(doc):
                 continue
 
             docs[(task_prompt_name, doc_id)] = doc
-            ctx = task.fewshot_context(
+            ctx, fewshotex_logging_info = task.fewshot_context(
                 doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
             )
+            fewshotex_logging_info["doc_id"] = original_doc_id
             reqs = task.construct_requests(doc, ctx)
             if not isinstance(reqs, (list, tuple)):
                 reqs = [reqs]
@@ -215,7 +214,7 @@ def evaluate(
                 # i: index in requests for a single task instance
                 # doc_id: unique id that we can get back to a doc using `docs`
                 requests_origin[req.request_type].append(
-                    (i, task_prompt_name, doc, doc_id)
+                    (i, task_prompt_name, doc, doc_id, fewshotex_logging_info)
                 )
 
     # all responses for each (task, doc)
@@ -234,33 +233,57 @@ def evaluate(
             x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
         ]
 
-        for resp, (i, task_prompt_name, doc, doc_id) in zip(
+        for resp, (i, task_prompt_name, doc, doc_id, fewshotex_logging_info) in zip(
             resps, requests_origin[reqtype]
         ):
-            process_res_queue[(task_prompt_name, doc_id)].append((i, resp))
+            process_res_queue[(task_prompt_name, doc_id)].append(
+                (i, resp, fewshotex_logging_info)
+            )
 
     vals = collections.defaultdict(list)
 
     # unpack results and sort back in order and return control to Task
-    for (task_prompt_name, doc_id), requests in process_res_queue.items():
-        requests.sort(key=lambda x: x[0])
-        requests = [x[1] for x in requests]
+    examples = []
+    for (task_prompt_name, doc_id), per_doc_requests in process_res_queue.items():
+        per_doc_requests.sort(key=lambda x: x[0])
+        per_doc_results = [x[1] for x in per_doc_requests]
+        fewshot_logging_info = [x[2] for x in per_doc_requests][0]
 
         task = task_dict[task_prompt_name]
         doc = docs[(task_prompt_name, doc_id)]
 
-        metrics = task.process_results(doc, requests)
+        output = task.process_results(doc, per_doc_results)
+        if task.save_examples:
+            metrics, example = output
+            example.update(fewshot_logging_info)
+            example.update(task.get_logging_info())
+            examples.append(example)
+        else:
+            metrics = output
+            example = fewshot_logging_info
+            example.update(task.get_logging_info())
+            examples.append(example)
+
         for metric, value in metrics.items():
             vals[(task_prompt_name, metric)].append(value)
 
     # aggregate results
+    metric_results = []
     for (task_prompt_name, metric), items in vals.items():
         task_name, prompt_name = task_prompt_name.split("+")
         results[task_prompt_name]["task_name"] = task_name
         results[task_prompt_name]["prompt_name"] = prompt_name
         task = task_dict[task_prompt_name]
         results[task_prompt_name][metric] = task.aggregation()[metric](items)
+        _metric_results = {
+            "task_name": task_name,
+            "prompt_name": prompt_name,
+            metric: task.aggregation()[metric](items),
+            **task.get_logging_info(),
+        }
         # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
         # so we run them less iterations. still looking for a cleaner way to do this
         stderr = lm_eval.metrics.stderr_for_metric(
@@ -271,8 +294,18 @@ def evaluate(
         )
 
         if stderr is not None:
             results[task_prompt_name][metric + "_stderr"] = stderr(items)
+            _metric_results[metric + "_stderr"] = stderr(items)
+
+        metric_results.append(_metric_results)
 
-    return {"results": dict(results), "versions": dict(versions)}
+    return {
+        # List of results that tracks the averages per model and prompt.
+        "results": metric_results,
+        "versions": dict(versions),
+        # List of all prompt x doc examples with additional information in it.
+        "examples": examples,
+        # Original results used for generating the table when running this file.
+        "table_results": dict(results),
+    }
 def make_table(result_dict):
 
@@ -293,7 +326,7 @@ def make_table(result_dict):
         ]
     values = []
 
-    for k, dic in result_dict["results"].items():
+    for k, dic in result_dict["table_results"].items():
         version = result_dict["versions"][k]
         for m, v in dic.items():
             if m.endswith("_stderr"):
...
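To make the evaluator changes above concrete, here is a hedged sketch of the dictionary that the patched `evaluate()` now returns and that `make_table` reads through `table_results`. Only the top-level keys come from the diff; the task/prompt names and numeric values are invented placeholders.

# Placeholder values; only the top-level keys mirror the diff above.
result_dict = {
    # Flat per-(task, prompt, metric) aggregates plus prompt metadata.
    "results": [
        {"task_name": "coqa", "prompt_name": "example_prompt", "f1": 0.5, "f1_stderr": 0.01},
    ],
    "versions": {"coqa+example_prompt": 0},
    # One entry per prompt x doc: the task's example merged with few-shot
    # logging info and get_logging_info() fields.
    "examples": [
        {"pred": "...", "target": "...", "doc_id": 0, "fewshot_num": 0},
    ],
    # Original nested results, still used by make_table().
    "table_results": {"coqa+example_prompt": {"task_name": "coqa", "f1": 0.5}},
}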
@@ -118,25 +118,18 @@ class CoQA(PromptSourceTask):
         """
         target = self.doc_to_target(doc).strip()
         pred = results[0].strip().split("\n")[0]
 
-        print("*" * 80)
-        print(f"DOC: {doc}")
-        # print(f"PS: {self.prompt.apply(doc)}")
-        print(f"TEXT: {self.doc_to_text(doc)}")
-        print(f"TARGET: {target} END TARGET")
-        print(f"PRED: {pred} END PRED")
-        print("*" * 80)
-
-        # turn_id = len(doc["questions"]["input_text"])
-        # gold_list = self.get_answers(doc, turn_id)
-        # TODO: Add HF metrics mapped from promptsource metadata.
         scores = self.compute_scores([target], pred)
 
-        return {
+        out = {
             "f1": scores["f1"],
             "em": scores["em"],
         }
+
+        if self.save_examples:
+            example = {"target": target, "pred": pred}
+            return out, example
+        return out
 
     def higher_is_better(self):
         return {
             "f1": True,
...
@@ -9,27 +9,29 @@ logging.getLogger("openai").setLevel(logging.WARNING)
 
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model', required=True)
-    parser.add_argument('--model_args', default="")
-    parser.add_argument('--tasks', default="all_tasks")
-    parser.add_argument('--provide_description', action="store_true")
-    parser.add_argument('--num_fewshot', type=int, default=0)
-    parser.add_argument('--batch_size', type=int, default=None)
-    parser.add_argument('--device', type=str, default=None)
-    parser.add_argument('--output_path', default=None)
-    parser.add_argument('--limit', type=int, default=None)
-    parser.add_argument('--no_cache', action="store_true")
-    parser.add_argument('--description_dict_path', default=None)
-    parser.add_argument('--check_integrity', action="store_true")
+    parser.add_argument("--model", required=True)
+    parser.add_argument("--model_args", default="")
+    parser.add_argument("--tasks", default="all_tasks")
+    parser.add_argument("--provide_description", action="store_true")
+    parser.add_argument("--num_fewshot", type=int, default=0)
+    parser.add_argument("--batch_size", type=int, default=None)
+    parser.add_argument("--device", type=str, default=None)
+    parser.add_argument("--output_path", default=None)
+    parser.add_argument("--limit", type=int, default=None)
+    parser.add_argument("--no_cache", action="store_true")
+    parser.add_argument("--description_dict_path", default=None)
+    parser.add_argument("--check_integrity", action="store_true")
     return parser.parse_args()
 
 
 def main():
     args = parse_args()
     assert not args.provide_description  # not implemented
 
     if args.limit:
-        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
+        print(
+            "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
+        )
 
     if args.tasks == "all_tasks":
         task_names = tasks.ALL_TASKS
@@ -38,7 +40,7 @@ def main():
     description_dict = {}
     if args.description_dict_path:
-        with open(args.description_dict_path, 'r') as f:
+        with open(args.description_dict_path, "r") as f:
             description_dict = json.load(f)
 
     results = evaluator.simple_evaluate(
@@ -51,11 +53,12 @@ def main():
         no_cache=args.no_cache,
         limit=args.limit,
         description_dict=description_dict,
-        check_integrity=args.check_integrity
+        check_integrity=args.check_integrity,
     )
+    print(results)
 
     dumped = json.dumps(results, indent=2)
     print(dumped)
 
     if args.output_path:
...