Unverified Commit a0f1cacd authored by Baber Abbasi, committed by GitHub

`Filter` docs not offset by `doc_id` (#1349)

* get `doc` from instance

* accelerate bugfix: get ground-truth doc from instance

* convert filter to `process_results`

* get docs from instances in `FilterEnsemble`

* rename

* nit

* better looping

* fix typehint
parent 34cded30
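For context on the title: before this change, `FilterEnsemble.apply` received the task's doc list as a separate argument, and the filters paired responses with docs by position, so whenever the evaluated instances did not start at `doc_id` 0 (e.g. a subset run), every response was scored against the wrong document. A minimal sketch of the mismatch, using a simplified stand-in for `lm_eval.api.instance.Instance`:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Instance:  # simplified stand-in, not the real class
    doc: dict
    doc_id: int
    resps: List[str] = field(default_factory=list)


docs = [{"span1_text": "the pony"}, {"span1_text": "John"}]

# Only the second doc is being evaluated, so list position != doc_id.
instances = [Instance(doc=docs[1], doc_id=1, resps=["John"])]

# Old pairing: by position into the full doc list -> off by one here.
old_pairs = list(zip([i.resps for i in instances], docs))
assert old_pairs[0][1]["span1_text"] == "the pony"  # wrong doc

# New pairing: each instance carries its own doc.
new_pairs = [(i.resps, i.doc) for i in instances]
assert new_pairs[0][1]["span1_text"] == "John"  # correct doc
```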
+from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List
-from datasets import Dataset
from lm_eval.api.instance import Instance


-class Filter:
+class Filter(ABC):
    """
    Filter classes operate on a per-task level.
    They take all model outputs (`instance.resps` for all `task.instances`)
@@ -20,6 +19,7 @@ class Filter:
        Can define custom behavior here, if an individual instantiation of a Filter class should have state.
        """

+    @abstractmethod
    def apply(self, resps, docs):
        """
        Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
@@ -42,10 +42,10 @@ class FilterEnsemble:
    name: str
    filters: List[Filter]

-    def apply(self, instances: List[Instance], docs: List[Dataset]) -> None:
-        resps = [
-            inst.resps for inst in instances
-        ]  # operate just on the model responses
+    def apply(self, instances: List[Instance]) -> None:
+        resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
+        resps, docs = list(resps), list(docs)
        for f in self.filters:
            # apply filters in sequence
            resps = f.apply(resps, docs)
......
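The `zip(*...)` in the new `apply` transposes per-instance `(resps, doc)` pairs into two parallel lists, keeping each response aligned with its own document through the whole filter chain. A standalone illustration of the idiom (dummy data, no lm_eval imports):

```python
pairs = [(["resp A"], {"doc_id": 7}), (["resp B"], {"doc_id": 9})]

resps, docs = zip(*pairs)              # transpose pairs -> two tuples
resps, docs = list(resps), list(docs)  # filters expect mutable lists

assert resps == [["resp A"], ["resp B"]]
assert docs == [{"doc_id": 7}, {"doc_id": 9}]
```

One caveat worth knowing: `zip(*...)` over an empty instance list raises `ValueError` at the unpacking step, whereas the old list comprehension yielded `[]`.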
@@ -490,7 +490,7 @@ class Task(abc.ABC):
    def apply_filters(self):
        if hasattr(self, "_filters"):
            for f in self._filters:
-                f.apply(self._instances, None)
+                f.apply(self._instances)
        else:
            eval_logger.warning("No filter defined, passing through instances")
            return self._instances
@@ -626,16 +626,15 @@ class ConfigurableTask(Task):
        if self.config.filter_list is not None:
            self._filters = []
            for filter_config in self.config.filter_list:
-                for filter_pipeline in filter_config:
-                    filter_name = filter_config["name"]
-                    filter_functions = filter_config["filter"]
-                    components = []
-                    for function in filter_functions:
-                        kwargs = {
-                            key: function[key] for key in function if key != "function"
-                        }
-                        components.append([function["function"], kwargs])
-                    filter_pipeline = build_filter_ensemble(filter_name, components)
+                filter_name = filter_config["name"]
+                filter_functions = filter_config["filter"]
+                components = []
+                for function in filter_functions:
+                    kwargs = {
+                        key: function[key] for key in function if key != "function"
+                    }
+                    components.append([function["function"], kwargs])
+                filter_pipeline = build_filter_ensemble(filter_name, components)
                self._filters.append(filter_pipeline)
        else:
            self._filters = [build_filter_ensemble("none", [["take_first", None]])]
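The corrected loop visits each `filter_list` entry exactly once; the removed inner `for filter_pipeline in filter_config:` iterated over the config dict's keys (`"name"`, `"filter"`), needlessly rebuilding the ensemble once per key. A standalone sketch of what one entry turns into (the config values here are illustrative, not from a real task):

```python
filter_config = {
    "name": "my_pipeline",                   # hypothetical pipeline name
    "filter": [{"function": "take_first"}],  # one filter step, no kwargs
}

filter_name = filter_config["name"]
components = []
for function in filter_config["filter"]:
    # everything except the "function" key is passed through as kwargs
    kwargs = {key: function[key] for key in function if key != "function"}
    components.append([function["function"], kwargs])

assert components == [["take_first", {}]]  # fed to build_filter_ensemble
```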
@@ -813,7 +812,7 @@ class ConfigurableTask(Task):
    def apply_filters(self):
        if hasattr(self, "_filters"):
            for f in self._filters:
-                f.apply(self._instances, self.task_docs)
+                f.apply(self._instances)
        else:
            eval_logger.warning("No filter defined, passing through instances")
            return self._instances
......
+from typing import List
from lm_eval.api.filter import FilterEnsemble

from . import selection
from . import extraction
@@ -27,7 +29,9 @@ def get_filter(filter_name):
        return filter_name


-def build_filter_ensemble(filter_name, components):
+def build_filter_ensemble(
+    filter_name: str, components: List[List[str]]
+) -> FilterEnsemble:
    """
    Create a filtering pipeline.
    """
......
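A usage sketch of the newly annotated builder, matching the default pipeline constructed in the task code above (assuming, per the imports in this file, that it is exposed from `lm_eval.filters`):

```python
from lm_eval.filters import build_filter_ensemble

# "take_first" keeps only the first response per instance; None -> no kwargs.
ensemble = build_filter_ensemble("none", [["take_first", None]])

assert ensemble.name == "none"
assert len(ensemble.filters) == 1
```

Note that `List[List[str]]` is a loose annotation: as the `["take_first", None]` default shows, each component is really a `[function_name, kwargs_or_None]` pair.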
@@ -7,6 +7,7 @@ training_split: train
validation_split: validation
output_type: generate_until
doc_to_text: !function "t5_utils.doc_to_text"
+process_results: !function "t5_utils.process_results"
doc_to_target: label
generation_kwargs:
  until:
@@ -15,9 +16,5 @@ metric_list:
  - metric: accuracy
    aggregation: mean
    higher_is_better: true
-filter_list:
-  - name: "wsc_postprocessor"
-    filter:
-      - function: !function t5_utils.WSCPostprocess
metadata:
-  version: 0.0
+  version: 1.0
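With the `filter_list` block gone, per-doc scoring moves into `t5_utils.process_results` (diffed below), and `metric_list` simply aggregates the dict that function returns; the returned key must match the declared metric name. A toy illustration of the `mean` aggregation step (hypothetical per-doc results, not lm_eval internals):

```python
# Three hypothetical per-doc dicts as returned by process_results.
per_doc = [{"accuracy": 1.0}, {"accuracy": 0.0}, {"accuracy": 1.0}]

# metric_list: accuracy / aggregation: mean
accuracy = sum(r["accuracy"] for r in per_doc) / len(per_doc)
print(f"accuracy: {accuracy:.3f}")  # accuracy: 0.667
```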
import re

-from lm_eval.api.filter import Filter
+from typing import List


def doc_to_text(x):
    text = re.sub(r" X ", " *" + x["span2_text"] + "* ", _wsc_inputs(x))
@@ -24,14 +23,14 @@ def _wsc_inputs(x):
            [
                " ".join(words[:pronoun_index]),
                "X",
-                " ".join(words[pronoun_index + 1 :]),
+                " ".join(words[pronoun_index + 1:]),
            ]
        )

    # Handle some special cases.
    if (
-        x["text"]
-        == 'The boy continued to whip the pony , and eventually the pony threw him over. John laughed out quite loud. "Good for him," he said. '
+        x["text"]
+        == 'The boy continued to whip the pony , and eventually the pony threw him over. John laughed out quite loud. "Good for him," he said. '
    ):
        return (
            "The boy continued to whip the pony , and eventually the pony threw "
@@ -40,8 +39,8 @@ def _wsc_inputs(x):
    # Using the span2_index, we get 'use' instead of 'it'.
    if (
-        x["text"]
-        == "When they had eventually calmed down a bit , and had gotten home, Mr. Farley put the magic pebble in an iron safe . Some day they might want to use it , but really for now, what more could they wish for?"
+        x["text"]
+        == "When they had eventually calmed down a bit , and had gotten home, Mr. Farley put the magic pebble in an iron safe . Some day they might want to use it , but really for now, what more could they wish for?"
    ):
        return (
            "When they had eventually calmed down a bit , and had gotten home, "
@@ -52,56 +51,53 @@ def _wsc_inputs(x):
    return create_input()


-class WSCPostprocess(Filter):
-    def __init__(self, **kwargs):
-        self.determiners = {
-            "a",
-            "an",
-            "few",
-            "her",
-            "his",
-            "each",
-            "every",
-            "many",
-            "much",
-            "my",
-            "our",
-            "some",
-            "that",
-            "the",
-            "their",
-            "these",
-            "this",
-            "those",
-            "which",
-            "whose",
-            "your",
-        }
-
-    def clean(self, s):
-        """Ignore capitalization and determiners."""
-        s = s.strip().lower()
-        return " ".join([w for w in s.split(" ") if w not in self.determiners])
-
-    def apply(self, resps, docs):
-        filtered_resps = []
-        for prediction, reference in zip(*(resps, docs["span1_text"])):
-            prediction = self.clean(prediction[0])
-            reference = self.clean(reference)
-            if ("'" in prediction) != ("'" in reference):
-                # Make sure we don't mark cases where the prediction is "Bob"
-                # and the referent is "Bob's hat" as predicting the referent.
-                predicted_referent = False
-            else:
-                prediction_words = set(prediction.split(" "))
-                referent_words = set(reference.split(" "))
-                # Handle cases where the prediction is "fuzzy bunny" and the referent is
-                # "bunny".
-                predicted_referent = prediction_words.issubset(
-                    referent_words
-                ) or referent_words.issubset(prediction_words)
-            filtered_resps.append(predicted_referent)
-        return filtered_resps
+DETERMINERS = {
+    "a",
+    "an",
+    "few",
+    "her",
+    "his",
+    "each",
+    "every",
+    "many",
+    "much",
+    "my",
+    "our",
+    "some",
+    "that",
+    "the",
+    "their",
+    "these",
+    "this",
+    "those",
+    "which",
+    "whose",
+    "your",
+}
+
+
+def clean(s: str) -> str:
+    """Ignore capitalization and determiners."""
+    s = s.strip().lower()
+    return " ".join([w for w in s.split(" ") if w not in DETERMINERS])
+
+
+def process_results(docs: dict, resps: List):
+    prediction = clean(resps[0])
+    reference = clean(docs["span1_text"])
+    if ("'" in prediction) != ("'" in reference):
+        # Make sure we don't mark cases where the prediction is "Bob"
+        # and the referent is "Bob's hat" as predicting the referent.
+        predicted_referent = False
+    else:
+        prediction_words = set(prediction.split(" "))
+        referent_words = set(reference.split(" "))
+        # Handle cases where the prediction is "fuzzy bunny" and the referent is
+        # "bunny".
+        predicted_referent = prediction_words.issubset(
+            referent_words
+        ) or referent_words.issubset(prediction_words)
+
+    acc = 1.0 if predicted_referent == docs["label"] else 0.0
+    return {"accuracy": acc}