Unverified Commit fce17ee1 authored by Jonathan Tow's avatar Jonathan Tow Committed by GitHub
Browse files

Merge pull request #12 from tttyuntian/master

Added null prompt support for T5 & Added BLIMP task template
parents 20820c3c 3ee4da8e
......@@ -100,6 +100,14 @@ class T5LM(BaseLM):
inputs, targets = zip(*chunk)
# Fill in empty encoder inputs with eos_token
inputs = (
f"{self.eot_token}"
if len(input_) == 0
else input_
for input_ in inputs
)
inputs_tok = self.tokenizer(
list(inputs),
max_length=self.max_length,
......@@ -123,7 +131,7 @@ class T5LM(BaseLM):
for key in targets_tok:
targets_tok[key] = targets_tok[key][:, -(self.max_length - 1) :]
outputs = self._model_call(inputs_tok, targets_tok)
log_softmaxes = F.log_softmax(outputs.logits, dim=-1)
......
......@@ -10,7 +10,7 @@ grammars.
Homepage: https://github.com/alexwarstadt/blimp
"""
from lm_eval.base import rf, Task
from lm_eval.base import rf, PromptSourceTask
from lm_eval.metrics import mean
......@@ -31,7 +31,7 @@ _CITATION = """
"""
class BlimpTask(Task):
class BlimpTask(PromptSourceTask):
VERSION = 0
DATASET_PATH = "blimp"
......@@ -50,58 +50,6 @@ class BlimpTask(Task):
# trained on this data.
return self.dataset["train"]
def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
assert num_fewshot == 0
assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`"
assert not provide_description, (
"The `provide_description` arg will be removed in future versions. To prepend "
"a custom description to the context, supply the corresponding string via the "
"`description` arg."
)
if provide_description is not None:
# nudge people to not specify it at all
print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")
return ""
def doc_to_text(self, doc):
# this method is invoked by tests only
return ""
def doc_to_target(self, doc):
# this method is invoked by tests only
return ""
def construct_requests(self, doc, ctx):
assert not ctx
# Calculate the loglikelihood for the good and the bad sentence.
# Note that loglikelihood translates the "" prefix to the "<|endoftext|>" token
return [
rf.loglikelihood("", doc["sentence_good"]),
rf.loglikelihood("", doc["sentence_bad"]),
]
def process_results(self, doc, results):
likelihood1, likelihood2 = results
# the model got this case right iff the good sentence scored higher than the bad sentence
acc = 1.0 if likelihood1 > likelihood2 else 0.0
return {
"acc": acc,
}
def higher_is_better(self):
return {
"acc": True,
}
def aggregation(self):
return {
"acc": mean,
}
class BlimpAdjunctIsland(BlimpTask):
DATASET_NAME = "adjunct_island"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment