Commit caac0843 authored by lintangsutawika

reformat

parent adbbfb44
@@ -4,6 +4,7 @@ from typing import List
from lm_eval.api.instance import Instance
from datasets import Dataset
class Filter:
"""
Filter classes operate on a per-task level.
......
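For orientation, a minimal sketch of what a concrete Filter subclass can look like. The apply(resps, docs) hook and the lowercasing behavior are assumptions for illustration, patterned after the WSCPostprocess filter further down in this commit, not the base class's documented contract.

# Illustrative sketch only: a per-task filter that lowercases model responses.
# The apply(resps, docs) signature is an assumption based on how
# WSCPostprocess (below) consumes and returns lists of responses.
class LowercaseFilter(Filter):
    def __init__(self, **kwargs):
        pass

    def apply(self, resps, docs):
        filtered_resps = []
        for resp in resps:
            # resp may itself hold several generations for one document
            filtered_resps.append([r.lower() for r in resp])
        return filtered_resps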
@@ -78,7 +78,7 @@ class HFLM(LM):
low_cpu_mem_usage: Optional[bool] = True,
trust_remote_code: Optional[bool] = False,
use_fast_tokenizer: Optional[bool] = True,
cache_dir: Optional[Union[str,os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
# arguments used for splitting a model across GPUs naively.
# only used if `parallelize=True`.
parallelize: Optional[bool] = False,
@@ -425,7 +425,11 @@ class HFLM(LM):
return encoding
def tok_batch_encode(
self, strings: List[str], padding_side="left", left_truncate_len=None, truncation=False
self,
strings: List[str],
padding_side="left",
left_truncate_len=None,
truncation=False,
):
# encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
old_padding_side = self.tokenizer.padding_side
@@ -863,7 +867,9 @@ class HFLM(LM):
# encode, pad, and truncate contexts for this batch
context_enc, attn_masks = self.tok_batch_encode(
contexts, left_truncate_len=max_ctx_len, truncation=self.truncation,
contexts,
left_truncate_len=max_ctx_len,
truncation=self.truncation,
)
context_enc = context_enc.to(self.device)
attn_masks = attn_masks.to(self.device)
......
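For context on the tok_batch_encode signature reformatted above, a minimal sketch of left-padded batch encoding with a Hugging Face tokenizer; the model name and the standalone function are illustrative assumptions, not the harness's exact implementation.

# Sketch only: approximates what a left-padded batch encode does.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed model, for illustration
tokenizer.pad_token = tokenizer.eos_token

def batch_encode_left(strings, left_truncate_len=None, truncation=False):
    old_side = tokenizer.padding_side
    tokenizer.padding_side = "left"  # pad on the left so generation continues from the prompt
    enc = tokenizer(strings, padding="longest", truncation=truncation, return_tensors="pt")
    tokenizer.padding_side = old_side
    input_ids, attn_masks = enc["input_ids"], enc["attention_mask"]
    if left_truncate_len:
        # keep only the rightmost tokens when a context exceeds the budget
        input_ids = input_ids[:, -left_truncate_len:]
        attn_masks = attn_masks[:, -left_truncate_len:]
    return input_ids, attn_masks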
@@ -15,4 +15,3 @@ metric_list:
higher_is_better: true
ignore_case: true
ignore_punctuation: true
......
import sklearn.metrics
def mean_3class_f1(predictions, references): # This is a passthrough function
string_label = ['entailment', 'contradiction', 'neutral']
string_label = ["entailment", "contradiction", "neutral"]
predictions = string_label.index(predictions[0])
references = string_label.index(references[0])
return (predictions, references)
def agg_mean_3class_f1(items):
predictions, references = zip(*items)
......
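The body of agg_mean_3class_f1 is truncated above; a plausible completion, assuming the aggregation is sklearn's macro-averaged F1 (the metric's name suggests it, but this is a sketch, not the file's actual code).

import sklearn.metrics

def agg_mean_3class_f1(items):
    # items are (prediction_index, reference_index) pairs emitted by the
    # passthrough metric above; macro F1 over the three labels is assumed here
    predictions, references = zip(*items)
    return sklearn.metrics.f1_score(references, predictions, average="macro")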
@@ -3,16 +3,22 @@ import collections
import numpy as np
import sklearn.metrics
def f1(predictions, references): # This is a passthrough function
_prediction = predictions[0]
_reference = references[0].split("_")[-1]
string_label = ['False', 'True']
string_label = ["False", "True"]
reference = string_label.index(_reference)
prediction = string_label.index(_prediction) if _prediction in string_label else not bool(reference)
prediction = (
string_label.index(_prediction)
if _prediction in string_label
else not bool(reference)
)
return (prediction, reference)
def agg_f1(items):
predictions, references = zip(*items)
@@ -25,9 +31,13 @@ def em(predictions, references):  # This is a passthrough function
_prediction = predictions[0]
_group, _reference = references[0].split("_")
string_label = ['False', 'True']
string_label = ["False", "True"]
reference = string_label.index(_reference)
prediction = string_label.index(_prediction) if _prediction in string_label else not bool(reference)
prediction = (
string_label.index(_prediction)
if _prediction in string_label
else not bool(reference)
)
return (_group, prediction, reference)
@@ -44,4 +54,4 @@ def agg_em(items):
score = float(np.array_equal(targets, predictions))
group_scores.append(score)
return np.mean(group_scores)
\ No newline at end of file
return np.mean(group_scores)
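To make the grouped exact-match aggregation above concrete, a small self-contained sketch with made-up items; the grouping step mirrors what the truncated hunk implies but is not the file's verbatim code.

import collections
import numpy as np

def grouped_em(items):
    # items are (group_id, prediction, reference) triples; a group scores 1.0
    # only when every answer within it matches its reference exactly
    grouped = collections.defaultdict(lambda: ([], []))
    for group, prediction, reference in items:
        grouped[group][0].append(prediction)
        grouped[group][1].append(reference)
    group_scores = []
    for predictions, targets in grouped.values():
        group_scores.append(float(np.array_equal(targets, predictions)))
    return np.mean(group_scores)

# hypothetical example: the first question is fully correct, the second is not
print(grouped_em([("q1", 1, 1), ("q1", 0, 0), ("q2", 1, 0)]))  # 0.5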
@@ -11,22 +11,23 @@ from lm_eval.api.metrics import metric_max_over_ground_truths
def doc_to_text(doc):
passage = doc['passage']
passage = re.sub(r'(\.|\?|\!|\"|\')\n@highlight\n', r'\1 ', passage)
passage = re.sub(r'\n@highlight\n', '. ', passage)
return " ".join([
"record query:",
doc['query'],
"entities:",
", ".join(doc['entities']),
"passage:",
passage
])
passage = doc["passage"]
passage = re.sub(r"(\.|\?|\!|\"|\')\n@highlight\n", r"\1 ", passage)
passage = re.sub(r"\n@highlight\n", ". ", passage)
return " ".join(
[
"record query:",
doc["query"],
"entities:",
", ".join(doc["entities"]),
"passage:",
passage,
]
)
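A quick illustration of the doc_to_text transformation just above, on a made-up ReCoRD-style document (not a real dataset row):

# hypothetical document, for illustration only
doc = {
    "passage": "Some opening sentence.\n@highlight\nFirst highlight\n@highlight\nSecond highlight",
    "query": "Who did @placeholder meet?",
    "entities": ["Alice", "Bob"],
}
print(doc_to_text(doc))
# record query: Who did @placeholder meet? entities: Alice, Bob passage: Some opening sentence. First highlight. Second highlight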
def process_docs(dataset):
def split_answers(doc):
split_doc = {
**{k: [] for k in doc.keys()},
@@ -37,54 +38,21 @@ def process_docs(dataset):
for key in split_doc.keys():
if key in doc:
split_doc[key].append(doc[key])
split_doc["answers"].append(answer)
return split_doc
dataset = dataset.map(split_answers)
new_dataset = {}
for key in dataset.features.keys():
new_dataset[key] = [x for row in dataset[key] for x in row]
return Dataset.from_dict(new_dataset)
def deduplicate_metric(metric_fn,
group_key: str = "group",
value_key: str = "value"):
"""Returns a metric that only considers one example per group.
Useful for things like ReCoRD where inputs may be replicated during training
to handle multiple labels, but where at eval we only want a single copy of
each example.
Args:
metric_fn: function, the metric to compute on the unique examples.
group_key: the key for the grouping value in the target dictionary.
value_key: the key for the value in the dictionaries.
Returns:
A metric function that deduplicated based on the grouping key before
returning a metric.
"""
def _deduplicated_metric(targets, predictions):
"""Deduplicate targets and predictions and pass that to the metric fn."""
processed_groups = set()
deduplicated_targets = []
deduplicated_predictions = []
for targ, pred in zip(targets, predictions):
group = targ[group_key]
if group in processed_groups:
continue
processed_groups.add(group)
deduplicated_targets.append(targ[value_key])
deduplicated_predictions.append(pred[value_key])
return metric_fn(deduplicated_targets, deduplicated_predictions)
return _deduplicated_metric
return Dataset.from_dict(new_dataset)
def normalize_squad(answer):
"""Normalization used in official SQuAD evaluation script."""
def _normalize_answer(text, punc_chars, punc_repl):
"""Lower text and remove punctuation, articles and extra whitespace."""
@@ -107,39 +75,43 @@ def normalize_squad(answer):
return _normalize_answer(answer, punc_chars=string.punctuation, punc_repl="")
def em(predictions, references): # This is a passthrough function
return (predictions[0], references[0])
def f1(predictions, references): # This is a passthrough function
return (predictions[0], references[0])
def squad_em_agg(items):
def _exact_match_score(target, prediction):
def squad_em_agg(items):
def _exact_match_score(prediction, target):
return target == prediction
grouped_values = collections.defaultdict(lambda: ([], []))
for prediction, reference in items:
group, reference = reference.split("_")
# if group not in grouped_values:
grouped_values[group][0].append(normalize_squad(prediction))
grouped_values[group][1].append(normalize_squad(reference))
print(grouped_values)
import sys; sys.exit()
em = np.mean([
metric_max_over_ground_truths(_exact_match_score, t, p)
for p, t in zip(predictions, targets)
])
return em
em = []
for group in grouped_values.keys():
predictions, targets = grouped_values[group]
for p in predictions:
em.append(metric_max_over_ground_truths(_exact_match_score, p, targets))
def squad_f1_agg(items):
return np.mean(em)
def _f1_score(target, prediction):
def squad_f1_agg(items):
def _f1_score(prediction, target):
"""Computes token f1 score for a single target and prediction."""
prediction_tokens = prediction.split()
target_tokens = target.split()
common = (collections.Counter(prediction_tokens) &
collections.Counter(target_tokens))
common = collections.Counter(prediction_tokens) & collections.Counter(
target_tokens
)
num_same = sum(common.values())
if num_same == 0:
return 0
@@ -148,12 +120,16 @@ def squad_f1_agg(items):
f1 = (2 * precision * recall) / (precision + recall)
return f1
predictions, targets = zip(*items)
targets = [normalize_squad(t) for t in targets]
predictions = [normalize_squad(p) for p in predictions]
grouped_values = collections.defaultdict(lambda: ([], []))
for prediction, reference in items:
group, reference = reference.split("_")
if group not in grouped_values:
grouped_values[group][0].append(normalize_squad(prediction))
grouped_values[group][1].append(normalize_squad(reference))
f1 = []
for group in grouped_values.keys():
p, t = grouped_values[group]
f1.append(metric_max_over_ground_truths(_f1_score, p[0], t))
f1 = np.mean([
metric_max_over_ground_truths(_f1_score, t, p)
for p, t in zip(predictions, targets)
])
return f1
return np.mean(f1)
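Both aggregations above lean on metric_max_over_ground_truths imported from lm_eval.api.metrics; a sketch of the standard SQuAD-style definition it presumably follows (an assumption based on the call sites, not a copy of the harness's implementation):

def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    # score the prediction against every reference answer and keep the best score
    return max(metric_fn(prediction, ground_truth) for ground_truth in ground_truths)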
@@ -15,4 +15,4 @@ metric_list:
filter_list:
- name: "wsc_postprocessor"
filter:
- function: !function t5_utils.WSCPostprocess
\ No newline at end of file
- function: !function t5_utils.WSCPostprocess
......
import re
from lm_eval.api.filter import Filter
def doc_to_text(x):
text = re.sub(r" X ", " *"+x["span2_text"]+"* ", _wsc_inputs(x))
return "wsc: "+text
text = re.sub(r" X ", " *" + x["span2_text"] + "* ", _wsc_inputs(x))
return "wsc: " + text
def _wsc_inputs(x):
words = x['text'].split(" ")
words = x["text"].split(" ")
# We would need some special logic to handle the case where the pronoun is the
# first or last word in the text. None of the examples in WSC seem to have
# this, so we are ignoring these cases.
assert x['span2_index'] > 0
assert x['span2_index'] < len(words)
pronoun_index = x['span2_index']
assert x["span2_index"] > 0
assert x["span2_index"] < len(words)
pronoun_index = x["span2_index"]
def create_input():
assert words[pronoun_index] == x['span2_text']
return " ".join([
" ".join(
words[:pronoun_index]
),
'X',
" ".join(
words[pronoun_index + 1:]
),
])
assert words[pronoun_index] == x["span2_text"]
return " ".join(
[
" ".join(words[:pronoun_index]),
"X",
" ".join(words[pronoun_index + 1 :]),
]
)
# Handle some special cases.
if x['text'] == 'The boy continued to whip the pony , and eventually the pony threw him over. John laughed out quite loud. \"Good for him,\" he said. ':
if (
x["text"]
== 'The boy continued to whip the pony , and eventually the pony threw him over. John laughed out quite loud. "Good for him," he said. '
):
return (
'The boy continued to whip the pony , and eventually the pony threw '
"The boy continued to whip the pony , and eventually the pony threw "
'him over. John laughed out quite loud. "Good for X ," he said.'
)
# Using the span2_index, we get 'use' instead of 'it'.
if x['text'] == 'When they had eventually calmed down a bit , and had gotten home, Mr. Farley put the magic pebble in an iron safe . Some day they might want to use it , but really for now, what more could they wish for?':
if (
x["text"]
== "When they had eventually calmed down a bit , and had gotten home, Mr. Farley put the magic pebble in an iron safe . Some day they might want to use it , but really for now, what more could they wish for?"
):
return (
'When they had eventually calmed down a bit , and had gotten home, '
'Mr. Farley put the magic pebble in an iron safe . Some day they might '
'want to use X , but really for now, what more could they wish for?'
"When they had eventually calmed down a bit , and had gotten home, "
"Mr. Farley put the magic pebble in an iron safe . Some day they might "
"want to use X , but really for now, what more could they wish for?"
)
return create_input()
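As a concrete illustration of the doc_to_text / _wsc_inputs pipeline above, a made-up example (not an actual WSC row):

# hypothetical document; span2_index points at the pronoun "it"
doc = {
    "text": "The trophy would not fit in the suitcase because it was too big .",
    "span2_text": "it",
    "span2_index": 9,
}
print(doc_to_text(doc))
# wsc: The trophy would not fit in the suitcase because *it* was too big .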
class WSCPostprocess(Filter):
def __init__(self, **kwargs):
self.determiners = {
"a", "an", "few", "her", "his", "each", "every", "many", "much", "my",
"our", "some", "that", "the", "their", "these", "this", "those", "which",
"whose", "your"
"a",
"an",
"few",
"her",
"his",
"each",
"every",
"many",
"much",
"my",
"our",
"some",
"that",
"the",
"their",
"these",
"this",
"those",
"which",
"whose",
"your",
}
def clean(self, s):
@@ -81,7 +103,8 @@ class WSCPostprocess(Filter):
# Handle cases where the prediction is "fuzzy bunny" and the referent is
# "bunny".
predicted_referent = prediction_words.issubset(
referent_words) or referent_words.issubset(prediction_words)
referent_words
) or referent_words.issubset(prediction_words)
filtered_resps.append(predicted_referent)
......
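A self-contained illustration of the subset-based referent match used at the end of the filter above; the clean step here (lower-casing, stripping simple punctuation, and dropping determiners) is a simplified stand-in for WSCPostprocess.clean, not its exact code.

determiners = {"a", "an", "the", "his", "her", "their", "this", "that"}  # abridged

def clean(s):
    # lower-case, strip simple punctuation, and drop determiners (simplified)
    words = s.lower().replace(".", "").replace(",", "").split()
    return " ".join(w for w in words if w not in determiners)

prediction_words = set(clean("the fuzzy bunny").split())
referent_words = set(clean("bunny").split())
# "fuzzy bunny" vs "bunny": one side is a subset of the other, so it counts as a match
print(prediction_words.issubset(referent_words) or referent_words.issubset(prediction_words))  # True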