"examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py" did not exist on "6ba0efb9a188b08f5b46565a87c0b3da7ff46af4"
Unverified Commit 025547c9 authored by Stella Biderman, committed by GitHub

Merge pull request #6 from cjlovering/master

Update with new PR
parents 54999199 e5bc4354
@@ -654,11 +654,21 @@ class PromptSourceTask(Task):
*and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`.
"""
CONFIGURED_PS_METRICS = set(["Accuracy", "BLEU", "ROUGE"])
def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None):
CONFIGURED_RANKED_CHOICE_PS_METRICS = set(["Accuracy"])
CONFIGURED_GENERATION_PS_METRICS = set(["BLEU", "ROUGE"])
SPLIT = None
def __init__(
self,
data_dir=None,
cache_dir=None,
download_mode=None,
prompt=None,
save_examples=True,
):
super().__init__(data_dir, cache_dir, download_mode)
self.prompt = prompt
self.save_examples = save_examples
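A minimal sketch (not part of the diff) of constructing a task with the widened signature; the subclass name and prompt object below are placeholders for a real PromptSourceTask subclass and a promptsource prompt:

# Hypothetical usage of the new __init__; `SomePromptTask` and `my_prompt`
# are illustrative stand-ins, not names from this commit.
task = SomePromptTask(
    prompt=my_prompt,
    save_examples=True,  # also return a per-doc example dict from process_results
)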
def stopping_criteria(self) -> Optional[str]:
"""Denote where the generation should end.
@@ -752,12 +762,11 @@ class PromptSourceTask(Task):
for metric in self.prompt.metadata.metrics:
assert (
metric in self.CONFIGURED_PS_METRICS
metric in self.CONFIGURED_RANKED_CHOICE_PS_METRICS
), "Unexpected metric. Add it, or use a task-specific solution."
if metric == "Accuracy":
out["acc"] = pred == target
# TODO: Add metrics here.
return out
else:
# If not, then this is a generation prompt.
# NOTE: In the future, target will be a list of strings.
@@ -765,11 +774,11 @@ class PromptSourceTask(Task):
out = {}
for metric in self.prompt.metadata.metrics:
assert (
metric in self.CONFIGURED_PS_METRICS
metric in self.CONFIGURED_GENERATION_PS_METRICS
), "Unexpected metric. Add it, or use a task-specific solution."
if metric == "BLEU":
out["bleu"] = (target, pred)
if metric == "ROUGE":
elif metric == "ROUGE":
# TODO: This computes all rouge sub-metrics. Find a generic
# way to handle user specified rouge sub-metrics to avoid extra
# compute.
@@ -778,15 +787,21 @@ class PromptSourceTask(Task):
rouge_scores = utils.flatten(rouge_scores)
# Merge all the rouge-type scores into the `out` dict.
out = {**out, **rouge_scores}
print(out)
return out
# TODO: Wrap process results s.t. override impl do not
# override the save examples.
if self.save_examples:
example = {
"pred": pred,
"target": target,
"answer_choices_list": answer_choices_list,
}
return out, example
return out
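From the caller's side, process_results now has two shapes, mirroring the evaluator change further below. A sketch, assuming `task`, `doc`, and `results` are already in hand:

output = task.process_results(doc, results)
if task.save_examples:
    # (metrics, example): e.g. ({"bleu": (target, pred), ...},
    #  {"pred": ..., "target": ..., "answer_choices_list": ...})
    metrics, example = output
else:
    # plain metrics dict, e.g. {"acc": True} for a ranked-choice prompt
    metrics = output
# ROUGE prompts additionally carry flattened sub-metric keys such as
# "rougeLsum_fmeasure", which is what aggregation() expects below.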
def higher_is_better(self):
out = {}
for metric in self.prompt.metadata.metrics:
assert (
metric in self.CONFIGURED_PS_METRICS
), "Unexpected metric. Add it, or use a task-specific solution."
if metric == "Accuracy":
out["acc"] = True
if metric == "BLEU":
@@ -813,9 +828,6 @@ class PromptSourceTask(Task):
def aggregation(self):
out = {}
for metric in self.prompt.metadata.metrics:
assert (
metric in self.CONFIGURED_PS_METRICS
), "Unexpected metric. Add it, or use a task-specific solution."
if metric == "Accuracy":
out["acc"] = mean
if metric == "BLEU":
@@ -839,6 +851,122 @@ class PromptSourceTask(Task):
out["rougeLsum_fmeasure"] = mean
return out
def fewshot_examples(self, k, rnd):
if self._training_docs is None:
self._training_docs = list(self.training_docs())
return self._get_fewshot_examples(self._training_docs, k, rnd)
def _get_fewshot_examples(self, docs, k, rnd):
fewshot_idx = rnd.sample(list(np.arange(len(docs))), k)
return [docs[idx] for idx in fewshot_idx], [int(idx) for idx in fewshot_idx]
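A small sketch of the helper above: it returns the sampled docs together with their integer positions in `docs`, which later get logged as `fewshot_idx`. The toy docs, seed, and `task` instance are assumptions for illustration:

import random

docs = ["doc_a", "doc_b", "doc_c", "doc_d"]
rnd = random.Random(1234)
shots, shot_idx = task._get_fewshot_examples(docs, k=2, rnd=rnd)
# `shots` holds two sampled docs and `shot_idx` their positions in `docs`,
# e.g. (["doc_d", "doc_b"], [3, 1]) depending on the seed.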
@utils.positional_deprecated
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
:param doc: str
The document as returned from training_docs, validation_docs, or test_docs.
:param num_fewshot: int
The number of fewshot examples to provide in the returned context string.
:param provide_description: bool
Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
:param rnd: random.Random
The pseudo-random number generator used to randomly sample examples.
WARNING: This is currently a required arg although it's optionalized with a default `None`.
:param description: str
The task's description that will be prepended to the fewshot examples.
:returns: str
The fewshot context.
"""
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
"`description` arg."
)
if provide_description is not None:
# nudge people to not specify it at all
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
description = description + "\n\n" if description else ""
if num_fewshot == 0:
labeled_examples = ""
fewshotex, fewshotidx, fewshotsource = [], [], None
else:
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs():
fewshotex, fewshotidx = self.fewshot_examples(k=num_fewshot, rnd=rnd)
fewshotsource = "train"
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
self.validation_docs()
if self.has_validation_docs()
else self.test_docs()
)
if self.has_validation_docs():
fewshotsource = "val"
elif self.test_docs():
fewshotsource = "test"
fewshotex, fewshotidx = self._get_fewshot_examples(
self._fewshot_docs, k=num_fewshot + 1, rnd=rnd
)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex, fewshotidx = zip(
*[
(shot, idx)
for shot, idx in zip(fewshotex, fewshotidx)
if shot != doc
]
)
# trim back down to exactly `num_fewshot` examples
fewshotex, fewshotidx = (
fewshotex[:num_fewshot],
fewshotidx[:num_fewshot],
)
labeled_examples = (
"\n\n".join(
[
self.doc_to_text(doc) + self.doc_to_target(doc)
for doc in fewshotex
]
)
+ "\n\n"
)
example = self.doc_to_text(doc)
ctx = description + labeled_examples + example
return (
ctx,
{
"fewshot_idx": fewshotidx,
"fewshot_source": fewshotsource,
"fewshot_num": num_fewshot,
"ctx": ctx,
},
)
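Because fewshot_context now returns a pair, callers unpack both the prompt string and the logging dict, exactly as the evaluator hunk below does. A brief sketch with placeholder arguments (`task` and `doc` assumed):

import random

ctx, fewshot_info = task.fewshot_context(
    doc=doc, num_fewshot=2, rnd=random.Random(42), description=None
)
# `ctx` is the string handed to the model; `fewshot_info` carries
# fewshot_idx, fewshot_source, fewshot_num, and a copy of ctx.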
def get_logging_info(self):
return {
"fixed_answer_choice_list": self.prompt.get_fixed_answer_choices_list(),
"dataset_path": self.DATASET_PATH,
"dataset_name": self.DATASET_NAME,
"subset": self.SPLIT,
"prompt_name": self.prompt.get_name(),
"prompt_id": self.prompt.get_id(),
"prompt_jinja": self.prompt.jinja,
"prompt_original_task": self.prompt.metadata.original_task,
# Placeholder for comment in post-processing.
"comment": "",
}
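The evaluator below merges this dict into every saved example via example.update(task.get_logging_info()). A sketch of the resulting record for one doc, with placeholder values:

example = {"pred": pred, "target": target}   # from process_results
example.update(fewshot_info)                 # fewshot_idx, fewshot_source, fewshot_num, ctx
example.update(task.get_logging_info())      # dataset/prompt metadata, empty "comment"
# The merged record is enough to trace a prediction back to its dataset,
# prompt template, and the exact fewshot shots it saw.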
class MultipleChoiceTask(Task):
def doc_to_target(self, doc):
......
@@ -173,10 +173,6 @@ def evaluate(
# get lists of each type of request
for task_prompt_name, task in task_dict_items:
# if task.is_generation_task():
# print(f"WARNING: Skipping generation prompt {task.prompt.name}.")
# continue
versions[task_prompt_name] = task.VERSION
# default to test doc, fall back to val doc if validation unavailable
# TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
@@ -188,7 +184,7 @@ def evaluate(
raise RuntimeError("Task has neither test_docs nor validation_docs")
# deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
task_docs = list(task_doc_func())
task_docs = list(enumerate(list(task_doc_func())))
rnd = random.Random()
rnd.seed(42)
rnd.shuffle(task_docs)
@@ -199,14 +195,17 @@ def evaluate(
else ""
)
for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
for doc_id, (original_doc_id, doc) in enumerate(
itertools.islice(task_docs, 0, limit)
):
if task.invalid_doc_for_prompt(doc):
continue
docs[(task_prompt_name, doc_id)] = doc
ctx = task.fewshot_context(
ctx, fewshotex_logging_info = task.fewshot_context(
doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
)
fewshotex_logging_info["doc_id"] = original_doc_id
reqs = task.construct_requests(doc, ctx)
if not isinstance(reqs, (list, tuple)):
reqs = [reqs]
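The enumerate-before-shuffle change keeps a stable pointer into the unshuffled dataset: doc_id indexes the shuffled, possibly limit-truncated list, while original_doc_id is the position in the original order. A toy illustration of the same pattern:

import itertools
import random

task_docs = list(enumerate(["d0", "d1", "d2", "d3"]))  # [(0, 'd0'), (1, 'd1'), ...]
rnd = random.Random(42)
rnd.shuffle(task_docs)
for doc_id, (original_doc_id, doc) in enumerate(itertools.islice(task_docs, 0, 2)):
    # doc_id is the post-shuffle position; original_doc_id survives the
    # shuffle and is what gets attached to the example's logging info.
    print(doc_id, original_doc_id, doc)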
@@ -215,7 +214,7 @@ def evaluate(
# i: index in requests for a single task instance
# doc_id: unique id that we can get back to a doc using `docs`
requests_origin[req.request_type].append(
(i, task_prompt_name, doc, doc_id)
(i, task_prompt_name, doc, doc_id, fewshotex_logging_info)
)
# all responses for each (task, doc)
@@ -234,33 +233,57 @@ def evaluate(
x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
]
for resp, (i, task_prompt_name, doc, doc_id) in zip(
for resp, (i, task_prompt_name, doc, doc_id, fewshotex_logging_info) in zip(
resps, requests_origin[reqtype]
):
process_res_queue[(task_prompt_name, doc_id)].append((i, resp))
process_res_queue[(task_prompt_name, doc_id)].append(
(i, resp, fewshotex_logging_info)
)
vals = collections.defaultdict(list)
# unpack results and sort back in order and return control to Task
for (task_prompt_name, doc_id), requests in process_res_queue.items():
requests.sort(key=lambda x: x[0])
requests = [x[1] for x in requests]
examples = []
for (task_prompt_name, doc_id), per_doc_requests in process_res_queue.items():
per_doc_requests.sort(key=lambda x: x[0])
per_doc_results = [x[1] for x in per_doc_requests]
fewshot_logging_info = [x[2] for x in per_doc_requests][0]
task = task_dict[task_prompt_name]
doc = docs[(task_prompt_name, doc_id)]
metrics = task.process_results(doc, requests)
output = task.process_results(doc, per_doc_results)
if task.save_examples:
metrics, example = output
example.update(fewshot_logging_info)
example.update(task.get_logging_info())
examples.append(example)
else:
metrics = output
example = fewshot_logging_info
example.update(task.get_logging_info())
examples.append(example)
for metric, value in metrics.items():
vals[(task_prompt_name, metric)].append(value)
# aggregate results
metric_results = []
for (task_prompt_name, metric), items in vals.items():
task_name, prompt_name = task_prompt_name.split("+")
results[task_prompt_name]["task_name"] = task_name
results[task_prompt_name]["prompt_name"] = prompt_name
task = task_dict[task_prompt_name]
results[task_prompt_name][metric] = task.aggregation()[metric](items)
_metric_results = {
"task_name": task_name,
"prompt_name": prompt_name,
metric: task.aggregation()[metric](items),
**task.get_logging_info(),
}
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this
stderr = lm_eval.metrics.stderr_for_metric(
@@ -271,8 +294,18 @@ def evaluate(
)
if stderr is not None:
results[task_prompt_name][metric + "_stderr"] = stderr(items)
return {"results": dict(results), "versions": dict(versions)}
_metric_results[metric + "_stderr"] = stderr(items)
metric_results.append(_metric_results)
return {
# List of results that tracks the averages per model and prompt.
"results": metric_results,
"versions": dict(versions),
# List of all prompt x doc examples with additional information in it.
"examples": examples,
# Original results used for generating the table when running this file.
"table_results": dict(results),
}
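A sketch of consuming the expanded return value, assuming `output` is the dict returned by evaluate above: flat per-metric rows under "results", per-doc records under "examples", and the original nested dict under "table_results" for make_table.

for row in output["results"]:
    # one row per (task, prompt, metric): task_name, prompt_name, the metric
    # value, an optional "<metric>_stderr", plus get_logging_info() fields
    print(row["task_name"], row["prompt_name"])

for record in output["examples"]:
    # one record per prompt x doc: doc_id, fewshot info, pred/target, etc.
    pass

table = make_table(output)  # reads output["table_results"] internally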
def make_table(result_dict):
@@ -293,7 +326,7 @@ def make_table(result_dict):
]
values = []
for k, dic in result_dict["results"].items():
for k, dic in result_dict["table_results"].items():
version = result_dict["versions"][k]
for m, v in dic.items():
if m.endswith("_stderr"):
......
@@ -118,25 +118,18 @@ class CoQA(PromptSourceTask):
"""
target = self.doc_to_target(doc).strip()
pred = results[0].strip().split("\n")[0]
print("*" * 80)
print(f"DOC: {doc}")
# print(f"PS: {self.prompt.apply(doc)}")
print(f"TEXT: {self.doc_to_text(doc)}")
print(f"TARGET: {target} END TARGET")
print(f"PRED: {pred} END PRED")
print("*" * 80)
# turn_id = len(doc["questions"]["input_text"])
# gold_list = self.get_answers(doc, turn_id)
# TODO: Add HF metrics mapped from promptsource metadata.
scores = self.compute_scores([target], pred)
return {
out = {
"f1": scores["f1"],
"em": scores["em"],
}
if self.save_examples:
example = {"target": target, "pred": pred}
return out, example
return out
def higher_is_better(self):
return {
"f1": True,
......
@@ -9,27 +9,29 @@ logging.getLogger("openai").setLevel(logging.WARNING)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True)
parser.add_argument('--model_args', default="")
parser.add_argument('--tasks', default="all_tasks")
parser.add_argument('--provide_description', action="store_true")
parser.add_argument('--num_fewshot', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=None)
parser.add_argument('--device', type=str, default=None)
parser.add_argument('--output_path', default=None)
parser.add_argument('--limit', type=int, default=None)
parser.add_argument('--no_cache', action="store_true")
parser.add_argument('--description_dict_path', default=None)
parser.add_argument('--check_integrity', action="store_true")
parser.add_argument("--model", required=True)
parser.add_argument("--model_args", default="")
parser.add_argument("--tasks", default="all_tasks")
parser.add_argument("--provide_description", action="store_true")
parser.add_argument("--num_fewshot", type=int, default=0)
parser.add_argument("--batch_size", type=int, default=None)
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--output_path", default=None)
parser.add_argument("--limit", type=int, default=None)
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--description_dict_path", default=None)
parser.add_argument("--check_integrity", action="store_true")
return parser.parse_args()
def main():
args = parse_args()
assert not args.provide_description # not implemented
if args.limit:
print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
print(
"WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
if args.tasks == "all_tasks":
task_names = tasks.ALL_TASKS
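The argparse block above only changed quoting style, so the flags themselves are unchanged. A quick sketch of running a sample command line through parse_args; the script name and the model/task values are placeholders, not taken from this diff:

import sys

sys.argv = [
    "main.py",                      # assumed entry-point name
    "--model", "gpt2",              # placeholder model identifier
    "--tasks", "all_tasks",
    "--num_fewshot", "0",
    "--limit", "8",
    "--output_path", "results.json",
]
args = parse_args()
assert args.num_fewshot == 0 and args.limit == 8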
@@ -38,7 +40,7 @@ def main():
description_dict = {}
if args.description_dict_path:
with open(args.description_dict_path, 'r') as f:
with open(args.description_dict_path, "r") as f:
description_dict = json.load(f)
results = evaluator.simple_evaluate(
@@ -51,11 +53,12 @@ def main():
no_cache=args.no_cache,
limit=args.limit,
description_dict=description_dict,
check_integrity=args.check_integrity
check_integrity=args.check_integrity,
)
print(results)
dumped = json.dumps(results, indent=2)
print(dumped)
if args.output_path:
......