"examples/textual_inversion/textual_inversion_sdxl.py" did not exist on "e9aa0925a8e5783814cd1e0da6f601fd3eb88571"
Commit b4ad893c authored by ken

Merge master

parents 8c83a821 20820c3c
......@@ -32,7 +32,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install flake8 pytest pytest-cov
-        pip install -e .
+        pip install -e .[dev]
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with flake8
run: |
......
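The CI change above swaps the bare editable install for one that pulls in the `dev` extras. A minimal sketch of the kind of `extras_require` entry in `setup.py` that `pip install -e .[dev]` resolves against; the package name and dependency lists are illustrative, not necessarily the repository's actual ones:

```python
# Hypothetical setup.py fragment, for illustration only.
from setuptools import setup, find_packages

setup(
    name="lm_eval",
    packages=find_packages(),
    extras_require={
        # `pip install -e .[dev]` installs these on top of the base deps
        "dev": ["flake8", "pytest", "pytest-cov"],
    },
)
```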
......@@ -11,7 +11,7 @@ If you haven't already, go ahead and fork the main repo, clone it, create a bran
git clone https://github.com/<YOUR-USERNAME>/lm-evaluation-harness.git
cd lm-evaluation-harness
git checkout -b <task-name>
-pip install -r requirements.txt
+pip install -e ".[dev]"
```
## Creating Your Task File
......
......@@ -121,6 +121,11 @@ class LM(abc.ABC):
class BaseLM(LM):
+    @property
+    @abstractmethod
+    def eot_token(self):
+        pass
@property
@abstractmethod
def eot_token_id(self):
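The diff adds an abstract `eot_token` property (the literal end-of-text string) to pair with the existing `eot_token_id` (its integer id). A minimal sketch of how a concrete model class might satisfy both, using GPT-2's values for illustration; `BaseLM`'s other abstract members are elided here:

```python
# Sketch only: a real implementation would subclass BaseLM and also
# implement its remaining abstract methods.
class SketchGPT2LM:
    @property
    def eot_token(self):
        # the literal end-of-text string the tokenizer recognizes
        return "<|endoftext|>"

    @property
    def eot_token_id(self):
        # the integer id GPT-2's tokenizer assigns to that string
        return 50256

lm = SketchGPT2LM()
print(lm.eot_token, lm.eot_token_id)  # <|endoftext|> 50256
```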
......@@ -354,8 +359,15 @@ class BaseLM(LM):
isinstance(max_generation_length, int) or max_generation_length is None
)
-            until = [stopping_criteria]
+            if stopping_criteria is None:
+                until = [self.eot_token]
+            else:
+                until = [stopping_criteria]
primary_until = self.tok_encode(until[0])
+            if len(primary_until) == 0:
+                primary_until = torch.tensor([self.eot_token_id])
context_enc = torch.tensor(
[self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
).to(self.device)
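Pulled out of context, the new fallback logic reads like this: when no stopping criteria is given, stop at the model's end-of-text token, and when the tokenizer encodes the stop string to nothing, fall back to the raw token id. A standalone sketch, where `tok_encode` and the token values are assumed inputs:

```python
import torch

def resolve_primary_until(stopping_criteria, eot_token, eot_token_id, tok_encode):
    # default the stop sequence to the end-of-text token
    until = [eot_token] if stopping_criteria is None else [stopping_criteria]
    primary_until = tok_encode(until[0])
    if len(primary_until) == 0:
        # some tokenizers encode special strings to an empty list;
        # use the known id directly in that case
        primary_until = torch.tensor([eot_token_id])
    return primary_until
```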
......@@ -633,14 +645,18 @@ class Task(abc.ABC):
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
+            # See Webson & Pavlick (2022) https://arxiv.org/pdf/2109.01247.pdf
+            # for justification of this separator.
+            example_separator = "\n###\n"
             labeled_examples = (
-                "\n\n".join(
+                example_separator.join(
[
self.doc_to_text(doc) + self.doc_to_target(doc)
for doc in fewshotex
]
)
+ "\n\n"
+ example_separator
)
example = self.doc_to_text(doc)
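A tiny worked example of the new separator: each labeled shot is joined (and followed) by `"\n###\n"` before the unlabeled query is appended. The toy shots below are illustrative:

```python
example_separator = "\n###\n"
shots = ["Q: 2+2? A: 4", "Q: 3+3? A: 6"]  # toy fewshot examples
labeled_examples = example_separator.join(shots) + example_separator
print(labeled_examples + "Q: 4+4? A:")
# Q: 2+2? A: 4
# ###
# Q: 3+3? A: 6
# ###
# Q: 4+4? A:
```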
......@@ -654,11 +670,21 @@ class PromptSourceTask(Task):
*and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`.
"""
-    CONFIGURED_PS_METRICS = set(["Accuracy", "BLEU", "ROUGE"])
-    def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None):
+    CONFIGURED_RANKED_CHOICE_PS_METRICS = set(["Accuracy"])
+    CONFIGURED_GENERATION_PS_METRICS = set(["BLEU", "ROUGE"])
+    SPLIT = None
+    def __init__(
+        self,
+        data_dir=None,
+        cache_dir=None,
+        download_mode=None,
+        prompt=None,
+        save_examples=True,
+    ):
         super().__init__(data_dir, cache_dir, download_mode)
         self.prompt = prompt
+        self.save_examples = save_examples
def stopping_criteria(self) -> Optional[str]:
"""Denote where the generation should end.
......@@ -752,12 +778,11 @@ class PromptSourceTask(Task):
for metric in self.prompt.metadata.metrics:
assert (
-                metric in self.CONFIGURED_PS_METRICS
+                metric in self.CONFIGURED_RANKED_CHOICE_PS_METRICS
), "Unexpected metric. Add it, or use a task-specific solution."
if metric == "Accuracy":
out["acc"] = pred == target
# TODO: Add metrics here.
return out
else:
# If not, then this is a generation prompt.
# NOTE: In the future, target will be a list of strings.
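Isolated from the class, the ranked-choice branch amounts to a whitelist check followed by exact-match accuracy. A runnable sketch:

```python
CONFIGURED_RANKED_CHOICE_PS_METRICS = {"Accuracy"}

def ranked_choice_results(metrics, pred, target):
    out = {}
    for metric in metrics:
        assert (
            metric in CONFIGURED_RANKED_CHOICE_PS_METRICS
        ), "Unexpected metric. Add it, or use a task-specific solution."
        if metric == "Accuracy":
            out["acc"] = pred == target  # exact match
    return out

print(ranked_choice_results(["Accuracy"], "yes", "yes"))  # {'acc': True}
```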
......@@ -765,11 +790,11 @@ class PromptSourceTask(Task):
out = {}
for metric in self.prompt.metadata.metrics:
assert (
-                metric in self.CONFIGURED_PS_METRICS
+                metric in self.CONFIGURED_GENERATION_PS_METRICS
), "Unexpected metric. Add it, or use a task-specific solution."
if metric == "BLEU":
out["bleu"] = (target, pred)
if metric == "ROUGE":
elif metric == "ROUGE":
# TODO: This computes all rouge sub-metrics. Find a generic
# way to handle user specified rouge sub-metrics to avoid extra
# compute.
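The generation branch defers BLEU as a `(target, pred)` pair for corpus-level aggregation and merges precomputed ROUGE sub-scores into the output dict. A standalone sketch, where `rouge_scores` stands in for the flattened scores computed in the diff:

```python
CONFIGURED_GENERATION_PS_METRICS = {"BLEU", "ROUGE"}

def generation_results(metrics, pred, target, rouge_scores=None):
    out = {}
    for metric in metrics:
        assert (
            metric in CONFIGURED_GENERATION_PS_METRICS
        ), "Unexpected metric. Add it, or use a task-specific solution."
        if metric == "BLEU":
            # deferred: aggregated at the corpus level later
            out["bleu"] = (target, pred)
        elif metric == "ROUGE":
            # merge the flattened rouge sub-metric scores
            out = {**out, **(rouge_scores or {})}
    return out
```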
......@@ -778,15 +803,21 @@ class PromptSourceTask(Task):
rouge_scores = utils.flatten(rouge_scores)
# Merge all the rouge-type scores into the `out` dict.
out = {**out, **rouge_scores}
-            print(out)
-            return out
+        # TODO: Wrap process results s.t. override impl do not
+        # override the save examples.
+        if self.save_examples:
+            example = {
+                "pred": pred,
+                "target": target,
+                "answer_choices_list": answer_choices_list,
+            }
+            return out, example
+        return out
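With `save_examples` enabled, `process_results` now returns a `(metrics, example)` tuple instead of a bare dict, so callers need to handle both shapes. A sketch of one way to normalize the return; the helper name is ours, not the harness's:

```python
def unpack_results(result, save_examples):
    if save_examples:
        out, example = result  # example holds pred/target/answer choices
        return out, example
    return result, None

# toy call with a fabricated result tuple, for shape only
out, example = unpack_results(({"acc": True}, {"pred": "yes"}), save_examples=True)
print(out, example)  # {'acc': True} {'pred': 'yes'}
```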
def higher_is_better(self):
out = {}
for metric in self.prompt.metadata.metrics:
assert (
metric in self.CONFIGURED_PS_METRICS
), "Unexpected metric. Add it, or use a task-specific solution."
if metric == "Accuracy":
out["acc"] = True
if metric == "BLEU":
......@@ -813,9 +844,6 @@ class PromptSourceTask(Task):
def aggregation(self):
out = {}
for metric in self.prompt.metadata.metrics:
-            assert (
-                metric in self.CONFIGURED_PS_METRICS
-            ), "Unexpected metric. Add it, or use a task-specific solution."
if metric == "Accuracy":
out["acc"] = mean
if metric == "BLEU":
......@@ -839,6 +867,125 @@ class PromptSourceTask(Task):
out["rougeLsum_fmeasure"] = mean
return out
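`aggregation` maps each per-example metric key to a reducer; `mean` is the harness's simple arithmetic mean. A minimal sketch of how such a mapping is consumed:

```python
def mean(values):
    # arithmetic mean, as used for accuracy-style metrics
    return sum(values) / len(values)

aggregation = {"acc": mean}
per_example_acc = [1, 0, 1, 1]
print(aggregation["acc"](per_example_acc))  # 0.75
```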
def fewshot_examples(self, k, rnd):
if self._training_docs is None:
self._training_docs = list(self.training_docs())
return self._get_fewshot_examples(self._training_docs, k, rnd)
def _get_fewshot_examples(self, docs, k, rnd):
fewshot_idx = rnd.sample(list(np.arange(len(docs))), k)
return [docs[idx] for idx in fewshot_idx], [int(idx) for idx in fewshot_idx]
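`_get_fewshot_examples` now returns the sampled indices alongside the shots so they can be logged. A self-contained illustration with a seeded generator; plain `range` stands in for `np.arange`:

```python
import random

docs = ["doc-a", "doc-b", "doc-c", "doc-d"]
rnd = random.Random(1234)  # a fixed seed makes the draw reproducible
idx = rnd.sample(range(len(docs)), 2)
shots = [docs[i] for i in idx]
print(shots, idx)  # the same two docs and indices on every run
```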
@utils.positional_deprecated
def fewshot_context(
self, doc, num_fewshot, provide_description=None, rnd=None, description=None
):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
:param doc: str
The document as returned from training_docs, validation_docs, or test_docs.
:param num_fewshot: int
The number of fewshot examples to provide in the returned context string.
:param provide_description: bool
Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
:param rnd: random.Random
The pseudo-random number generator used to randomly sample examples.
WARNING: This is currently a required arg although it's optionalized with a default `None`.
:param description: str
The task's description that will be prepended to the fewshot examples.
:returns: str
The fewshot context.
"""
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
"`description` arg."
)
if provide_description is not None:
# nudge people to not specify it at all
print(
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
)
description = description + "\n\n" if description else ""
if num_fewshot == 0:
labeled_examples = ""
fewshotex, fewshotidx, fewshotsource = [], [], None
else:
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs():
fewshotex, fewshotidx = self.fewshot_examples(k=num_fewshot, rnd=rnd)
fewshotsource = "train"
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
self.validation_docs()
if self.has_validation_docs()
else self.test_docs()
)
if self.has_validation_docs():
fewshotsource = "val"
                elif self.has_test_docs():
fewshotsource = "test"
fewshotex, fewshotidx = self._get_fewshot_examples(
self._fewshot_docs, k=num_fewshot + 1, rnd=rnd
)
                fewshotex, fewshotidx = zip(
                    *[
                        (shot, idx)
                        for shot, idx in zip(fewshotex, fewshotidx)
                        if shot != doc
                    ]
                )
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex, fewshotidx = (
fewshotex[:num_fewshot],
fewshotidx[:num_fewshot],
)
# See Webson & Pavlick (2022) https://arxiv.org/pdf/2109.01247.pdf
# for justification of this separator.
example_separator = "\n###\n"
labeled_examples = (
example_separator.join(
[
self.doc_to_text(doc) + self.doc_to_target(doc)
for doc in fewshotex
]
)
+ example_separator
)
example = self.doc_to_text(doc)
ctx = description + labeled_examples + example
return (
ctx,
{
"fewshot_idx": fewshotidx,
"fewshot_source": fewshotsource,
"fewshot_num": num_fewshot,
"ctx": ctx,
},
)
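`fewshot_context` now returns the context plus a logging dict recording which shots were drawn and from which split. A usage sketch; `task`, `doc`, and `rnd` are assumed to already exist:

```python
# assumes: task (a PromptSourceTask), doc (one evaluation document),
# and rnd (a seeded random.Random) are already constructed
ctx, fewshot_log = task.fewshot_context(doc=doc, num_fewshot=2, rnd=rnd)
print(fewshot_log["fewshot_source"])  # "train", "val", or "test"
print(fewshot_log["fewshot_idx"])     # indices of the sampled shots
```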
def get_logging_info(self):
return {
"fixed_answer_choice_list": self.prompt.get_fixed_answer_choices_list(),
"dataset_path": self.DATASET_PATH,
"dataset_name": self.DATASET_NAME,
"subset": self.SPLIT,
"prompt_name": self.prompt.get_name(),
"prompt_id": self.prompt.get_id(),
"prompt_jinja": self.prompt.jinja,
"prompt_original_task": self.prompt.metadata.original_task,
# Placeholder for comment in post-processing.
"comment": "",
}
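One plausible way to combine the per-task prompt metadata with the per-request fewshot log into a single record for post-processing; this merge is our sketch, not shown in the diff itself:

```python
# assumes task and fewshot_log from the previous sketches
record = {**task.get_logging_info(), **fewshot_log}
# record now carries prompt name/id/jinja plus the fewshot provenance
```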
class MultipleChoiceTask(Task):
def doc_to_target(self, doc):
......