Commit d82bf9d9 authored by JessicaOjo's avatar JessicaOjo
Browse files

add verbalizer for xnli

parent 87232874
......@@ -49,6 +49,43 @@ class RegexFilter(Filter):
return filtered_resps
@register_filter("verbalizer")
class VerbalizerFilter(Filter):
    """Map free-form model responses onto canonical label names.

    Each response is lower-cased and scanned for the substrings configured in
    ``verbalizer_dict``; the first label whose substring occurs in the
    response replaces the response. Responses that match nothing are passed
    through (lower-cased).
    """

    def __init__(
        self,
        verbalizer_dict: dict,
    ) -> None:
        """
        `verbalizer_dict` maps each canonical label to an iterable of
        substrings. A response containing any of those substrings is
        replaced by the label. Labels are tried in the dict's iteration
        order, so the first matching label wins.
        """
        self.verbalizer_dict = verbalizer_dict

    def apply(self, resps, docs):
        """Verbalize every response in every per-document response set.

        `resps` is an iterable in which each element is itself an iterable of
        string responses for one input/target pair. `docs` is accepted for
        interface compatibility and is not used here.

        Returns a list of lists with the same nesting as `resps`, where each
        response is replaced by its matched label, or by its lower-cased text
        when no configured substring occurs in it.
        """
        # Responses are lower-cased before matching, so lower-case the
        # configured substrings too (once per call, not once per response);
        # otherwise an entry containing an uppercase letter could never match.
        lowered = [
            (label, [needle.lower() for needle in needles])
            for label, needles in self.verbalizer_dict.items()
        ]

        def verbalize(value):
            for label, needles in lowered:
                if any(needle in value for needle in needles):
                    return label
            # No substring matched: hand back the (lower-cased) response.
            return value

        # Each response set is processed independently. Materialise the
        # result as real lists rather than a lazy `map`, so it can be
        # indexed, measured and iterated more than once by later stages.
        return [[verbalize(resp.lower()) for resp in inst] for inst in resps]
@register_filter("regex-numbers")
class RegexNumberFilter(Filter):
""" """
......
......@@ -16,6 +16,14 @@ doc_to_choice:
- "contradiction"
should_decontaminate: true
doc_to_decontamination_query: premise
# Post-process generations with the `verbalizer` filter: each key below is a
# target label and each list holds substrings; a response that contains one
# of the substrings is replaced by that label.
# NOTE(review): 'encouragement' and 'entitlement' are presumably near-miss
# generations that should still count as "entailment" — confirm intent.
filter_list:
- name: "verbalizer_extract"
filter:
- function: verbalizer
verbalizer_dict: {
"entailment": ['encouragement', 'entitlement', 'entails', 'entailed', 'entailment'],
"contradiction": ['contradictory', 'contradicts', 'contradiction'],
"neutral": ['neutral']}
metric_list:
- metric: f1
aggregation: !function utils.weighted_f1_score
......
......@@ -15,6 +15,14 @@ doc_to_choice:
- "contradiction"
should_decontaminate: true
doc_to_decontamination_query: premise
# Post-process generations with the `verbalizer` filter: each key below is a
# target label and each list holds substrings; a response that contains one
# of the substrings is replaced by that label.
# NOTE(review): 'encouragement' and 'entitlement' are presumably near-miss
# generations that should still count as "entailment" — confirm intent.
filter_list:
- name: "verbalizer_extract"
filter:
- function: verbalizer
verbalizer_dict: {
"entailment": ['encouragement', 'entitlement', 'entails', 'entailed', 'entailment'],
"contradiction": ['contradictory', 'contradicts', 'contradiction'],
"neutral": ['neutral']}
metric_list:
- metric: f1
aggregation: !function utils.weighted_f1_score
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment