Commit 0cdcc989 authored by jon-tow, committed by cjlovering

Turned off generation tasks for now. Changed process_results to look at the prompt's metrics; only accuracy is implemented so far.
parent b1a3c6e3
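
For context, a minimal sketch (not part of this commit) of how a prompt's requested metrics can be inspected with promptsource; this metadata is what the new is_generation_task() and process_results() logic keys off. It assumes promptsource's DatasetTemplates API, and glue/cola is used purely as an example dataset.

# Inspect which metrics each promptsource template requests for a dataset.
# Assumption: promptsource exposes DatasetTemplates and Template.metadata.metrics.
from promptsource.templates import DatasetTemplates

templates = DatasetTemplates("glue", "cola")
for name in templates.all_template_names:
    metrics = templates[name].metadata.metrics or []  # metrics may be unset
    is_generation = "BLEU" in metrics or "ROUGE" in metrics
    print(name, metrics, "generation" if is_generation else "ranked choice")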
import abc
from typing import Iterable
import promptsource
import numpy as np
import random
import re
@@ -642,6 +642,12 @@ class PromptSourceTask(Task):
     def eos_token(self):
         raise NotImplementedError()

+    def is_generation_task(self):
+        return (
+            "BLEU" in self.prompt.metadata.metrics
+            or "ROUGE" in self.prompt.metadata.metrics
+        )
+
     def doc_to_target(self, doc):
         _, target = self.prompt.apply(doc)
         return f" {target}"
@@ -663,11 +669,19 @@ class PromptSourceTask(Task):
         """
         _requests = []
         answer_choices_list = self.prompt.get_answer_choices_list(doc)
+        # We take a present answer_choices list to mean that we should apply the
+        # supplied metrics (hardcoded to accuracy at the moment) to the ranked
+        # choices. Otherwise, assume generation. `is_generation_task` (above) does
+        # something similar, but relies on the requested metrics (BLEU/ROUGE
+        # indicate generation).
         if answer_choices_list:
+            assert (
+                not self.is_generation_task()
+            ), "We expect this to be a ranked choice task; double-check the prompt's metrics."
             for answer_choice in answer_choices_list:
                 ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}")
                 _requests.append(ll_answer_choice)
         else:
+            assert False
             # TODO(Albert): What is the stop symbol? Is it model specific?
             cont_request = rf.greedy_until(ctx, [self.eos_token()])
             _requests.append(cont_request)
@@ -690,27 +704,35 @@ class PromptSourceTask(Task):
         target = self.doc_to_target(doc).strip()
         answer_choices_list = self.prompt.get_answer_choices_list(doc)
         if answer_choices_list:
+            assert (
+                not self.is_generation_task()
+            ), "We expect this to be a ranked choice task; double-check the prompt's metrics."
             pred = answer_choices_list[np.argmax(results)]
-            return {
-                "acc": pred == target
-            }
+            out = {}
+            if "Accuracy" in self.prompt.metadata.metrics:
+                out["acc"] = pred == target
+            # TODO: Add metrics here.
+            return out
         else:
-            continuation = results
-            raise NotImplementedError()
+            raise NotImplementedError("Generation is not implemented yet.")
             # Map metric name to HF metric.
             # TODO(Albert): What is Other?
-            #metric_names = prompt.metadata.metrics
+            # metric_names = prompt.metadata.metrics

     def higher_is_better(self):
-        return {
-            "acc": True
-        }
+        out = {}
+        if "Accuracy" in self.prompt.metadata.metrics:
+            out["acc"] = True
+        return out

     def aggregation(self):
-        return {
-            "acc": mean,
-        }
+        out = {}
+        if "Accuracy" in self.prompt.metadata.metrics:
+            out["acc"] = mean
+        return out

 class MultipleChoiceTask(Task):
...
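
To make the ranked-choice path above concrete, here is a toy sketch (illustrative values only, not harness code): results arrives as one log-likelihood per answer choice, in the order the requests were built, and only accuracy is reported for now.

import numpy as np

# Stand-in values: the choices come from prompt.get_answer_choices_list(doc),
# the scores from the rf.loglikelihood requests, the target from doc_to_target(doc).
answer_choices_list = ["yes", "no"]
results = [-4.2, -1.3]          # log-likelihood of " yes" vs. " no"
target = "no"

pred = answer_choices_list[np.argmax(results)]
out = {"acc": pred == target}   # only accuracy is implemented for now
print(out)                      # {'acc': True}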
@@ -170,8 +170,10 @@ def evaluate(
     # get lists of each type of request
     for task_prompt_name, task in task_dict_items:
-        print(f"TASK PROMPT NAME: {task_prompt_name}")
+        if task.is_generation_task():
+            print(f"WARNING: Skipping generation prompt {task.prompt.name}.")
+            continue
         versions[task_prompt_name] = task.VERSION

         # default to test doc, fall back to val doc if validation unavailable
         # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
@@ -206,7 +208,9 @@ def evaluate(
             requests[req.request_type].append(req)

             # i: index in requests for a single task instance
             # doc_id: unique id that we can get back to a doc using `docs`
-            requests_origin[req.request_type].append((i, task_prompt_name, doc, doc_id))
+            requests_origin[req.request_type].append(
+                (i, task_prompt_name, doc, doc_id)
+            )

     # all responses for each (task, doc)
     process_res_queue = collections.defaultdict(list)
@@ -224,7 +228,9 @@ def evaluate(
             x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
         ]

-        for resp, (i, task_prompt_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
+        for resp, (i, task_prompt_name, doc, doc_id) in zip(
+            resps, requests_origin[reqtype]
+        ):
             process_res_queue[(task_prompt_name, doc_id)].append((i, resp))

     vals = collections.defaultdict(list)
...
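
A toy illustration of the bookkeeping above (standard library only; the task and doc names are hypothetical): each response is tagged with its index i within the doc's request list so it can be regrouped per (task prompt, doc) and re-sorted before scoring.

import collections

process_res_queue = collections.defaultdict(list)

# (task_prompt_name, doc_id, i, resp) tuples standing in for real harness
# responses, arriving out of order.
fake_responses = [
    ("copa+example_prompt", 7, 1, -0.9),
    ("copa+example_prompt", 7, 0, -3.1),
]
for task_prompt_name, doc_id, i, resp in fake_responses:
    process_res_queue[(task_prompt_name, doc_id)].append((i, resp))

# Sorting by i restores the original answer-choice order before process_results().
for key, vals in process_res_queue.items():
    print(key, sorted(vals))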
@@ -52,6 +52,7 @@ from . import blimp
 from . import asdiv
 from . import gsm8k
 from . import storycloze
+from . import hans

 # from . import e2e_nlg_cleaned
@@ -146,6 +147,7 @@ TASK_REGISTRY = {
     "anli_r1": anli.ANLIRound1,
     "anli_r2": anli.ANLIRound2,
     "anli_r3": anli.ANLIRound3,
+    "hans": hans.HANS,
     "ethics_cm": hendrycks_ethics.EthicsCM,
     "ethics_deontology": hendrycks_ethics.EthicsDeontology,
     "ethics_justice": hendrycks_ethics.EthicsJustice,
...
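
A minimal lookup sketch (assuming the harness is installed and lm_eval.tasks exposes TASK_REGISTRY, as shown in this diff): the new "hans" entry resolves to the HANS class added below.

# Resolve a task class by its registry name.
from lm_eval import tasks

task_class = tasks.TASK_REGISTRY["hans"]
print(task_class.__name__, task_class.DATASET_PATH)  # HANS hans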
@@ -67,17 +67,17 @@ class CoLA(PromptSourceTask):
     def validation_docs(self):
         return self.dataset["validation"]

-    def process_results(self, doc, results):
-        answer_choices_list = self.prompt.get_answer_choices_list(doc)
-        pred = np.argmax(results)
-        target = answer_choices_list.index(self.doc_to_target(doc).strip())
-        return {"mcc": (target, pred)}
+    # def process_results(self, doc, results):
+    #     answer_choices_list = self.prompt.get_answer_choices_list(doc)
+    #     pred = np.argmax(results)
+    #     target = answer_choices_list.index(self.doc_to_target(doc).strip())
+    #     return {"mcc": (target, pred)}

-    def higher_is_better(self):
-        return {"mcc": True}
+    # def higher_is_better(self):
+    #     return {"mcc": True}

-    def aggregation(self):
-        return {"mcc": matthews_corrcoef}
+    # def aggregation(self):
+    #     return {"mcc": matthews_corrcoef}

 class SST(PromptSourceTask):
...
"""
Right for the Wrong Reasons: Diagnosing Syntactic Heuristics in Natural Language Inference
https://arxiv.org/abs/1902.01007
HANS (Heuristic Analysis for NLI Systems) is a controlled evaluation set containing
many examples where common syntactic heuristics adopted by NLI models fail.
Homepage: https://github.com/tommccoy1/hans
"""
from lm_eval.base import PromptSourceTask
_CITATION = """
@inproceedings{mccoy-etal-2019-right,
title = "Right for the Wrong Reasons: Diagnosing Syntactic Heuristics in Natural Language Inference",
author = "McCoy, Tom and
Pavlick, Ellie and
Linzen, Tal",
booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics",
month = jul,
year = "2019",
address = "Florence, Italy",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/P19-1334",
doi = "10.18653/v1/P19-1334",
pages = "3428--3448",
abstract = "A machine learning system can score well on a given test set by relying on heuristics that are effective for frequent example types but break down in more challenging cases. We study this issue within natural language inference (NLI), the task of determining whether one sentence entails another. We hypothesize that statistical NLI models may adopt three fallible syntactic heuristics: the lexical overlap heuristic, the subsequence heuristic, and the constituent heuristic. To determine whether models have adopted these heuristics, we introduce a controlled evaluation set called HANS (Heuristic Analysis for NLI Systems), which contains many examples where the heuristics fail. We find that models trained on MNLI, including BERT, a state-of-the-art model, perform very poorly on HANS, suggesting that they have indeed adopted these heuristics. We conclude that there is substantial room for improvement in NLI systems, and that the HANS dataset can motivate and measure progress in this area.",
}
"""
class HANS(PromptSourceTask):
    VERSION = 0
    DATASET_PATH = "hans"
    DATASET_NAME = None

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self.has_training_docs():
            # We cache training documents in `self._training_docs` for faster
            # few-shot processing. If the data is too large to fit in memory,
            # return the training data as a generator instead of a list.
            if self._training_docs is None:
                self._training_docs = list(self.dataset["train"])
            return self._training_docs

    def validation_docs(self):
        if self.has_validation_docs():
            return self.dataset["validation"]

    def test_docs(self):
        if self.has_test_docs():
            return self.dataset["test"]
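
A quick sanity-check sketch (assuming the Hugging Face hans dataset and hub access; not part of the commit): the hub copy ships only train and validation splits, which is why has_test_docs() returns False above.

from datasets import load_dataset

hans = load_dataset("hans")
print(hans)                    # expect 'train' and 'validation' splits only
print(hans["validation"][0])   # one NLI example: premise, hypothesis, label, ...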