Unverified Commit a0f1cacd authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

`Filter` docs not offset by `doc_id` (#1349)

* get `doc` from instance

* accelerate bugfix: get ground-truth doc from instance

* convert filter to `process_result`

* get docs from instances in `FilterEnsemble`

* rename

* nit

* better looping

* fix typehint
parent 34cded30
from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import List from typing import List
from datasets import Dataset
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
class Filter: class Filter(ABC):
""" """
Filter classes operate on a per-task level. Filter classes operate on a per-task level.
They take all model outputs (`instance.resps` for all `task.instances`) They take all model outputs (`instance.resps` for all `task.instances`)
...@@ -20,6 +19,7 @@ class Filter: ...@@ -20,6 +19,7 @@ class Filter:
Can define custom behavior here, if an individual instantiation of a Filter class should have state. Can define custom behavior here, if an individual instantiation of a Filter class should have state.
""" """
@abstractmethod
def apply(self, resps, docs): def apply(self, resps, docs):
""" """
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects. Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
...@@ -42,10 +42,10 @@ class FilterEnsemble: ...@@ -42,10 +42,10 @@ class FilterEnsemble:
name: str name: str
filters: List[Filter] filters: List[Filter]
def apply(self, instances: List[Instance], docs: List[Dataset]) -> None: def apply(self, instances: List[Instance]) -> None:
resps = [ resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
inst.resps for inst in instances resps, docs = list(resps), list(docs)
] # operate just on the model responses
for f in self.filters: for f in self.filters:
# apply filters in sequence # apply filters in sequence
resps = f.apply(resps, docs) resps = f.apply(resps, docs)
......
...@@ -490,7 +490,7 @@ class Task(abc.ABC): ...@@ -490,7 +490,7 @@ class Task(abc.ABC):
def apply_filters(self):
    """Run each configured filter ensemble over this task's instances.

    The ensembles mutate/annotate the instances they are given; the
    (possibly filtered) instance list is returned either way.  When no
    filters were registered, a warning is logged and the instances pass
    through untouched.
    """
    if not hasattr(self, "_filters"):
        eval_logger.warning("No filter defined, passing through instances")
        return self._instances
    for filter_ensemble in self._filters:
        # each ensemble is applied to the full instance list in turn
        filter_ensemble.apply(self._instances)
    return self._instances
...@@ -626,16 +626,15 @@ class ConfigurableTask(Task): ...@@ -626,16 +626,15 @@ class ConfigurableTask(Task):
if self.config.filter_list is not None: if self.config.filter_list is not None:
self._filters = [] self._filters = []
for filter_config in self.config.filter_list: for filter_config in self.config.filter_list:
for filter_pipeline in filter_config: filter_name = filter_config["name"]
filter_name = filter_config["name"] filter_functions = filter_config["filter"]
filter_functions = filter_config["filter"] components = []
components = [] for function in filter_functions:
for function in filter_functions: kwargs = {
kwargs = { key: function[key] for key in function if key != "function"
key: function[key] for key in function if key != "function" }
} components.append([function["function"], kwargs])
components.append([function["function"], kwargs]) filter_pipeline = build_filter_ensemble(filter_name, components)
filter_pipeline = build_filter_ensemble(filter_name, components)
self._filters.append(filter_pipeline) self._filters.append(filter_pipeline)
else: else:
self._filters = [build_filter_ensemble("none", [["take_first", None]])] self._filters = [build_filter_ensemble("none", [["take_first", None]])]
...@@ -813,7 +812,7 @@ class ConfigurableTask(Task): ...@@ -813,7 +812,7 @@ class ConfigurableTask(Task):
def apply_filters(self):
    """Pass this task's instances through every registered filter pipeline.

    Returns ``self._instances`` in all cases; logs a warning instead of
    filtering when no ``_filters`` attribute has been set up.
    """
    _missing = object()  # sentinel so a present-but-falsy attribute still counts
    pipelines = getattr(self, "_filters", _missing)
    if pipelines is _missing:
        eval_logger.warning("No filter defined, passing through instances")
    else:
        for pipeline in pipelines:
            pipeline.apply(self._instances)
    return self._instances
......
from typing import List
from lm_eval.api.filter import FilterEnsemble from lm_eval.api.filter import FilterEnsemble
from . import selection from . import selection
from . import extraction from . import extraction
...@@ -27,7 +29,9 @@ def get_filter(filter_name): ...@@ -27,7 +29,9 @@ def get_filter(filter_name):
return filter_name return filter_name
def build_filter_ensemble(filter_name, components): def build_filter_ensemble(
filter_name: str, components: List[List[str]]
) -> FilterEnsemble:
""" """
Create a filtering pipeline. Create a filtering pipeline.
""" """
......
...@@ -7,6 +7,7 @@ training_split: train ...@@ -7,6 +7,7 @@ training_split: train
validation_split: validation validation_split: validation
output_type: generate_until output_type: generate_until
doc_to_text: !function "t5_utils.doc_to_text" doc_to_text: !function "t5_utils.doc_to_text"
process_results: !function "t5_utils.process_results"
doc_to_target: label doc_to_target: label
generation_kwargs: generation_kwargs:
until: until:
...@@ -15,9 +16,5 @@ metric_list: ...@@ -15,9 +16,5 @@ metric_list:
- metric: accuracy - metric: accuracy
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
filter_list:
- name: "wsc_postprocessor"
filter:
- function: !function t5_utils.WSCPostprocess
metadata: metadata:
version: 0.0 version: 1.0
import re import re
from lm_eval.api.filter import Filter from typing import List
def doc_to_text(x): def doc_to_text(x):
text = re.sub(r" X ", " *" + x["span2_text"] + "* ", _wsc_inputs(x)) text = re.sub(r" X ", " *" + x["span2_text"] + "* ", _wsc_inputs(x))
...@@ -24,14 +23,14 @@ def _wsc_inputs(x): ...@@ -24,14 +23,14 @@ def _wsc_inputs(x):
[ [
" ".join(words[:pronoun_index]), " ".join(words[:pronoun_index]),
"X", "X",
" ".join(words[pronoun_index + 1 :]), " ".join(words[pronoun_index + 1:]),
] ]
) )
# Handle some special cases. # Handle some special cases.
if ( if (
x["text"] x["text"]
== 'The boy continued to whip the pony , and eventually the pony threw him over. John laughed out quite loud. "Good for him," he said. ' == 'The boy continued to whip the pony , and eventually the pony threw him over. John laughed out quite loud. "Good for him," he said. '
): ):
return ( return (
"The boy continued to whip the pony , and eventually the pony threw " "The boy continued to whip the pony , and eventually the pony threw "
...@@ -40,8 +39,8 @@ def _wsc_inputs(x): ...@@ -40,8 +39,8 @@ def _wsc_inputs(x):
# Using the span2_index, we get 'use' instead of 'it'. # Using the span2_index, we get 'use' instead of 'it'.
if ( if (
x["text"] x["text"]
== "When they had eventually calmed down a bit , and had gotten home, Mr. Farley put the magic pebble in an iron safe . Some day they might want to use it , but really for now, what more could they wish for?" == "When they had eventually calmed down a bit , and had gotten home, Mr. Farley put the magic pebble in an iron safe . Some day they might want to use it , but really for now, what more could they wish for?"
): ):
return ( return (
"When they had eventually calmed down a bit , and had gotten home, " "When they had eventually calmed down a bit , and had gotten home, "
...@@ -52,56 +51,53 @@ def _wsc_inputs(x): ...@@ -52,56 +51,53 @@ def _wsc_inputs(x):
return create_input() return create_input()
# English determiners ignored when comparing a predicted referent against
# the gold span (capitalization is ignored too — see clean()).
DETERMINERS = {
    "a", "an", "each", "every", "few", "her", "his",
    "many", "much", "my", "our", "some", "that", "the",
    "their", "these", "this", "those", "which", "whose", "your",
}


def clean(s: str) -> str:
    """Normalize *s* for fuzzy referent matching.

    Strips surrounding whitespace, lowercases, and removes any
    determiner words, so e.g. "The fuzzy Bunny" -> "fuzzy bunny".
    """
    words = s.strip().lower().split(" ")
    return " ".join(w for w in words if w not in DETERMINERS)
def process_results(docs: dict, resps: List):
    """Score one WSC example: does the generation name the gold referent?

    docs:  the source document; ``span1_text`` is the gold referent span
           and ``label`` says whether the pronoun refers to it.
    resps: model generations for this doc; only the first is scored.

    Returns ``{"accuracy": 1.0}`` when the (normalized) prediction
    matches the label, else ``{"accuracy": 0.0}``.
    """
    prediction = clean(resps[0])
    reference = clean(docs["span1_text"])

    if ("'" in prediction) != ("'" in reference):
        # referent is "Bob's hat" as predicting the referent.
        predicted_referent = False
    else:
        pred_words = set(prediction.split(" "))
        ref_words = set(reference.split(" "))
        # Handle cases where the prediction is "fuzzy bunny" and the referent is
        # "bunny".
        predicted_referent = pred_words <= ref_words or ref_words <= pred_words

    return {"accuracy": 1.0 if predicted_referent == docs["label"] else 0.0}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment