Commit 3e3a0d8f authored by Baber

Merge branch 'rm_multiple_target' into metrics

# Conflicts:
#	lm_eval/api/filter.py
#	lm_eval/api/metrics.py
#	lm_eval/api/task.py
#	lm_eval/filters/extraction.py
Parents: 2b4cdd41, 00a77ebd
lm_eval/api/filter.py

 from abc import ABC, abstractmethod
+from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Iterable, List, Union

 from lm_eval.api.instance import Instance
@@ -20,7 +20,9 @@ class Filter(ABC):
     """

     @abstractmethod
-    def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
+    def apply(
+        self, resps: Iterable[list[str]], docs: Iterable[dict]
+    ) -> Iterable[list[str]]:
         """
         Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
         Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
@@ -40,9 +42,9 @@ class FilterEnsemble:
     """

     name: str
-    filters: List[type[Filter]]
+    filters: list[type[Filter]]

-    def apply(self, instances: List[Instance]) -> None:
+    def apply(self, instances: list[Instance]) -> None:
         resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
         resps, docs = list(resps), list(docs)
...
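For orientation, here is a minimal sketch (not part of the diff) of a custom filter written against the updated `apply` signature; the class name `TakeFirstFilter` and its registration tag are hypothetical:

```python
from collections.abc import Iterable

from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter


@register_filter("take_first_example")  # hypothetical tag, for illustration only
class TakeFirstFilter(Filter):
    """Keep only the first model response for each document."""

    def apply(
        self, resps: Iterable[list[str]], docs: Iterable[dict]
    ) -> Iterable[list[str]]:
        # `resps` holds one list of responses per document; the filtered output
        # must preserve the per-document ordering.
        return [resp_list[:1] for resp_list in resps]
```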
lm_eval/api/metrics.py
@@ -207,13 +207,48 @@ def acc_mutual_info_fn(items):  # This is a passthrough function
 # See the License for the specific language governing permissions and
 # limitations under the License.
 def exact_match_hf_evaluate(
-    predictions: Iterable[str],
-    references: Iterable[str],
-    regexes_to_ignore=None,
-    ignore_case=False,
-    ignore_punctuation=False,
-    ignore_numbers=False,
+    predictions: Iterable[str] | str,
+    references: Iterable[str] | str,
+    regexes_to_ignore: list[str] | None = None,
+    ignore_case: bool = False,
+    ignore_punctuation: bool = False,
+    ignore_numbers: bool = False,
+    multi_target: bool = False,
 ):
+    """
+    Compute exact match scores between predictions and references.
+
+    This function computes the exact match score by comparing predictions
+    and references. It supports optional preprocessing steps such as ignoring
+    case, punctuation, numbers, and specific regex patterns.
+
+    Note:
+        predictions and references may have different lengths when `multi_target`
+        is True; NumPy broadcasting rules apply.
+
+    Args:
+        predictions (Iterable[str] | str): The predicted strings to evaluate.
+        references (Iterable[str] | str): The reference strings to compare against.
+        regexes_to_ignore (list[str], optional): A list of regex patterns to remove
+            from both predictions and references before comparison. Defaults to None.
+        ignore_case (bool, optional): If True, ignores case differences during
+            comparison. Defaults to False.
+        ignore_punctuation (bool, optional): If True, removes punctuation from strings
+            before comparison. Defaults to False.
+        ignore_numbers (bool, optional): If True, removes numeric characters from
+            strings before comparison. Defaults to False.
+        multi_target (bool, optional): If True, returns 1.0 if any prediction matches
+            any reference, otherwise 0.0. Defaults to False.
+
+    Returns:
+        dict: A dictionary containing the exact match score:
+            - "exact_match" (float): The mean exact match score, or 1.0/0.0 if
+              `multi_target` is True.
+    """
+    predictions, references = list(predictions), list(references)
+    assert len(predictions) == len(references) if not multi_target else True, (
+        "predictions and references must have the same length unless `multi_target` is True"
+    )
     if regexes_to_ignore is not None:
         for s in regexes_to_ignore:
             predictions = np.array([re.sub(s, "", x) for x in predictions])
@@ -238,7 +273,11 @@ def exact_match_hf_evaluate(
     score_list = predictions == references

-    return {"exact_match": np.mean(score_list)}
+    return {
+        "exact_match": np.mean(score_list)
+        if not multi_target
+        else float(np.any(score_list))
+    }


 ###
@@ -250,8 +289,8 @@ def exact_match_hf_evaluate(
     output_type="generate_until",
     aggregation="mean",
 )
-def exact_match_fn(**kwargs):
-    return exact_match_hf_evaluate(**kwargs)
+def exact_match_fn(references: list[str], predictions: list[str], **kwargs):
+    return exact_match_hf_evaluate(predictions, references, **kwargs)


 @register_metric(
...
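To see what the new `multi_target` flag does in practice, a small usage sketch (illustrative only; it assumes `exact_match_hf_evaluate` is importable from `lm_eval.api.metrics` and that the elided middle of the function behaves like the HF Evaluate exact_match implementation):

```python
from lm_eval.api.metrics import exact_match_hf_evaluate

preds = ["Paris", "Paris", "Paris"]
refs = ["paris", "City of Light", "Paname"]

# Default behaviour: element-wise comparison, then the mean over pairs.
print(exact_match_hf_evaluate(preds, refs, ignore_case=True))
# -> {'exact_match': ~0.33}, only the first pair matches

# multi_target=True: a single match against any reference counts as a full hit.
print(exact_match_hf_evaluate(preds, refs, ignore_case=True, multi_target=True))
# -> {'exact_match': 1.0}
```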
lm_eval/api/task.py
@@ -3,17 +3,14 @@ import ast
 import logging
 import random
 import re
-from collections.abc import Callable
+from collections.abc import Callable, Iterable, Iterator, Mapping
 from copy import deepcopy
 from typing import (
     TYPE_CHECKING,
     Any,
     Dict,
-    Iterable,
-    Iterator,
     List,
     Literal,
-    Mapping,
     Optional,
     Tuple,
     Union,
@@ -530,8 +527,8 @@ class Task(abc.ABC):
         # self.aggregation = lambda: {
         #     metric_name: get_metric_aggregation(metric_name)
         # }
-        setattr(self._config, "metric_list", [MetricConfig(name=metric_name)])
-        setattr(self._config, "process_results", lambda *args: {"bypass": 0})
+        self._config.metric_list = [MetricConfig(name=metric_name)]
+        self._config.process_results = lambda *args: {"bypass": 0}

     def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
         self.fewshot_rnd = random.Random(seed)
@@ -788,7 +785,7 @@ class ConfigurableTask(Task):
             return docs

         # Fallback to parent implementation
-        if _num_fewshot := getattr(self.config, "num_fewshot"):
+        if _num_fewshot := self.config.num_fewshot:
             if isinstance(_num_fewshot, int) and _num_fewshot > 0:
                 eval_logger.warning(
                     f"[Task: {self.config.task}] "
@@ -1409,63 +1406,15 @@ class ConfigurableTask(Task):
                 # it assumes that doc_to_target returns a number.
                 choices = self.doc_to_choice(doc)
                 gold = choices[gold]
-            # we expect multiple_targets to be a list.
-            elif self.multiple_target:
-                gold = list(gold)
-            # TODO: handle this better
-            elif type(gold) is not type(result) and not (
-                "bypass" in use_metric or isinstance(result, list)
-            ):
-                # cast gold to the same type as result
-                gold = type(result)(gold)

-            for metric in self.config._metric_list:
-                if self.multiple_target:
-                    # in the case where we have multiple targets,
-                    # return true if any are true
-                    # TODO: this may break for multipLe_target, non zero-or-1 metrics
-                    scores = []
-                    if not isinstance(gold, list):
-                        # sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
-                        # print(gold)
-                        gold = [gold]
-                    if metric.name == "exact_match":
-                        result = [result for _ in range(len(gold))]
-                        scores = metric.fn(
-                            references=gold,
-                            predictions=result,
-                            **metric.kwargs,
-                        )[metric]
-                        result_score = 1.0 if scores > 0.0 else 0.0
-                    else:
-                        for gold_option in gold:
-                            try:
-                                result_score = metric.fn(
-                                    references=[gold_option],
-                                    predictions=[result],
-                                    **metric.kwargs,
-                                )
-                            except (
-                                TypeError
-                            ):  # TODO: this is hacky and I don't want to do it
-                                result_score = metric.fn([gold_option, result])
-                            if isinstance(result_score, dict):
-                                # TODO: this handles the case where HF evaluate returns a dict.
-                                result_score = result_score[metric]
-                            scores.append(result_score)
-                        if any(scores):
-                            result_score = 1.0
-                        else:
-                            result_score = 0.0
-                else:
-                    try:
-                        result_score = metric.fn(
-                            references=[gold],
-                            predictions=[result],
-                            **metric.kwargs,
-                        )
-                    except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
-                        result_score = metric.fn([gold, result])
+            for metric in self._metric_fn_list.keys():
+                try:
+                    result_score = self._metric_fn_list[metric](
+                        references=[gold] if not isinstance(gold, list) else gold,
+                        predictions=[result],
+                        **self._metric_fn_kwargs[metric],
+                    )
+                except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
+                    result_score = self._metric_fn_list[metric]([gold, result])
                 if isinstance(result_score, dict):
                     # TODO: this handles the case where HF evaluate returns a dict.
                     # This allows for multiple metrics to be returned from the same function
@@ -1515,7 +1464,7 @@ class MultipleChoiceTask(Task):
             Instance(
                 request_type="loglikelihood",
                 doc=doc,
-                arguments=(ctx, " {}".format(choice)),
+                arguments=(ctx, f" {choice}"),
                 idx=i,
                 **kwargs,
             )
...
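The net effect of the large task.py hunk is that per-target looping leaves `process_results`: the task now hands the whole gold list to the metric and lets `exact_match_hf_evaluate(multi_target=True)` decide whether any reference matched. A rough sketch of the new call pattern (illustrative only; `metric_fn` and `metric_kwargs` stand in for one entry of `_metric_fn_list` / `_metric_fn_kwargs`, and it assumes the metric was registered with `multi_target=True` for a multi-target task):

```python
from lm_eval.api.metrics import exact_match_hf_evaluate

gold = ["42", "forty-two"]   # a multi-target document's references
result = "forty-two"         # the model's (filtered) response

metric_fn = exact_match_hf_evaluate
metric_kwargs = {"ignore_case": True, "multi_target": True}

# What the new loop in process_results effectively does per metric:
result_score = metric_fn(
    references=[gold] if not isinstance(gold, list) else gold,
    predictions=[result],
    **metric_kwargs,
)
print(result_score)  # {'exact_match': 1.0}, because one of the references matches
```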
lm_eval/filters/extraction.py

 import re
 import sys
 import unicodedata
+from collections.abc import Iterable

 from lm_eval.api.filter import Filter
 from lm_eval.api.registry import register_filter
@@ -32,7 +33,9 @@ class RegexFilter(Filter):
         self.group_select = group_select
         self.fallback = fallback

-    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
+    def apply(
+        self, resps: Iterable[list[str]], docs: Iterable[dict]
+    ) -> Iterable[list[str]]:
         # here, we assume we have a list, in which each element is
         # a list of model responses for some particular input/target pair.
         # so we process each of these (same input/target response sets)
@@ -59,59 +62,13 @@ class RegexFilter(Filter):
         return filtered_resps


-@register_filter("regex_pos")
-class POSFilter(Filter):
-    """ """
-
-    def __init__(
-        self,
-        regex_pattern: str = r"\['(.*?)'\]",
-        group_select=0,
-        fallback=None,
-        **kwargs,
-    ) -> None:
-        """
-        pass a string `regex` to run `re.compile(r"regex")` on.
-        `fallback` defines the output returned if no matches for the regex are located.
-        """
-        super().__init__(**kwargs)
-        if fallback is None:
-            fallback = ["invalid"]
-        self.regex_pattern = regex_pattern
-        self.regex = re.compile(regex_pattern)
-        self.group_select = group_select
-        self.fallback = fallback
-
-    def apply(self, resps, docs):
-        def extract_tagged_tokens(text):
-            # Extract tagged tokens list from text input using regex
-            tokens = re.findall(r"\('([^']*)', '([^']*)'\)", text)
-            return [(token, pos) for token, pos in tokens]
-
-        def extract_pos_tags(result):
-            pos_tags = []
-            if isinstance(result, str):
-                result = extract_tagged_tokens(result)
-            pos_tags.extend(pos for _, pos in result)
-            return pos_tags if pos_tags else self.fallback
-
-        def filter_set(inst):
-            filtered = []
-            for resp in inst:
-                match = extract_pos_tags(resp)
-                filtered.append(match)
-            return filtered
-
-        filtered_resps = map(lambda x: filter_set(x), resps)
-
-        return filtered_resps
-
-
 @register_filter("remove_whitespace")
 class WhitespaceFilter(Filter):
     """Filters out leading whitespace from responses."""

-    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
+    def apply(
+        self, resps: Iterable[list[str]], docs: Iterable[dict]
+    ) -> Iterable[list[str]]:
         def filter_set(inst):
             filtered_resp = []
             for resp in inst:
@@ -156,7 +113,9 @@ class MultiChoiceRegexFilter(RegexFilter):
         self.ignore_punctuation = ignore_punctuation
         self.regexes_to_ignore = regexes_to_ignore

-    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
+    def apply(
+        self, resps: Iterable[list[str]], docs: Iterable[dict]
+    ) -> Iterable[list[str]]:
         # here, we assume we have a list, in which each element is
         # a list of model responses for some particular input/target pair.
         # so we process each of these (same input/target response sets)
...
@@ -73,3 +73,5 @@ HomePage: https://github.com/masakhane-io/masakhane-pos
 abstract = "In this paper, we present AfricaPOS, the largest part-of-speech (POS) dataset for 20 typologically diverse African languages. We discuss the challenges in annotating POS for these languages using the universal dependencies (UD) guidelines. We conducted extensive POS baseline experiments using both conditional random field and several multilingual pre-trained language models. We applied various cross-lingual transfer models trained with data available in the UD. Evaluating on the AfricaPOS dataset, we show that choosing the best transfer language(s) in both single-source and multi-source setups greatly improves the POS tagging performance of the target languages, in particular when combined with parameter-fine-tuning methods. Crucially, transferring knowledge from a language that matches the language family and morphosyntactic properties seems to be more effective for POS tagging in unseen languages."
 }
 ```
+## Changelog
+- 2025-07-21: Refactored. Scores should not be affected.
@@ -14,19 +14,18 @@ validation_split: validation
 test_split: test
 fewshot_split: train
 doc_to_target: !function utils.doc_to_target
+process_results: !function utils.process_results
 should_decontaminate: true
 doc_to_decontamination_query: "Sentence: {{token}}\nOutput:"
 filter_list:
   - filter:
-      - function: regex_pos
+      - function: "custom"
+        filter_fn: !function utils.extract_pos
+      - function: "take_first"
     name: flexible-extract
 metric_list:
   - metric: acc
-    aggregation: !function utils.acc_score
+    aggregation: mean
     higher_is_better: true
-    ignore_case: true
-    ignore_punctuation: true
-    regexes_to_ignore:
-      - ","
 metadata:
   version: 1.0
-from itertools import chain
+import re
+from collections.abc import Iterable
+from typing import Any

 from sklearn.metrics import accuracy_score

-from lm_eval.utils import weighted_f1_score


 def doc_to_target(doc):
     pos_tag_map = {
@@ -29,27 +29,40 @@ def doc_to_target(doc):
     return [pos_tag_map[tag] for tag in doc["upos"]]


-def acc_score(items):
-    unzipped_list = list(zip(*items))
-
-    golds, preds = unzipped_list[0], unzipped_list[1]
-    # Flatten preds' inner lists
-    flattened_preds = [list(chain.from_iterable(p)) for p in preds]
-    # Calculate the accuracy for each gold-pred pair
-    accuracy_scores = []
-    for gold, pred in zip(golds, flattened_preds):
-        # Ensure both lists are of the same length, otherwise truncate to match
-        min_length = min(len(gold), len(pred))
-        gold = gold[:min_length]
-        pred = pred[:min_length]
-        # Calculate accuracy for the current pair and add to the list
-        accuracy = accuracy_score(gold, pred)
-        accuracy_scores.append(accuracy)
-
-    mean_accuracy = (
-        sum(accuracy_scores) / len(accuracy_scores) if accuracy_scores else 0
-    )
-    return mean_accuracy
+def extract_pos(resps: Iterable[list[str]], *args) -> Iterable[list[str]]:
+    def extract_tagged_tokens(text: str) -> list[tuple[str, str]]:
+        # Extract tagged tokens list from text input using regex
+        tokens = re.findall(
+            r"\('([^']*)', '([^']*)'\)",
+            text,
+        )
+        return [(token, pos) for token, pos in tokens]
+
+    def extract_pos_tags(result: str):
+        pos_tags = []
+        if isinstance(result, str):
+            result_ = extract_tagged_tokens(result)
+            pos_tags.extend(pos for _, pos in result_)
+        return pos_tags if pos_tags else ["invalid"]
+
+    def filter_set(inst: list[str]) -> list[str]:
+        filtered = []
+        for resp in inst:
+            match = extract_pos_tags(resp)
+            filtered.append(match)
+        return filtered
+
+    filtered_resps = map(lambda x: filter_set(x), resps)
+
+    return filtered_resps
+
+
+def process_results(doc: dict[str, Any], results: list[list[str]]):
+    golds, preds = doc_to_target(doc), results[0]
+    # Ensure both lists are of the same length, otherwise truncate to match
+    min_length = min(len(golds), len(preds))
+    gold = golds[:min_length]
+    pred = preds[:min_length]
+    accuracy = accuracy_score(gold, pred)
+
+    return {"acc": accuracy}
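To make the new per-document pipeline concrete, here is a small illustrative run-through (not part of the commit; the sample response and gold tags are made up, and the `utils` import is hypothetical) showing a raw generation being reduced to POS tags by `extract_pos` and then scored the way `process_results` scores it:

```python
from sklearn.metrics import accuracy_score

from utils import extract_pos  # hypothetical import of the task utils.py shown above

# One document, one model response: the generation lists (token, tag) tuples.
raw_resps = [["[('The', 'DET'), ('cat', 'NOUN'), ('sleeps', 'VERB')]"]]

# The "custom" filter reduces each response to its list of POS tags.
filtered = list(extract_pos(raw_resps))
print(filtered)  # [[['DET', 'NOUN', 'VERB']]]

# take_first keeps the first filtered response per doc; process_results then
# compares it token-by-token against the gold tags (truncating to equal length):
gold = ["DET", "NOUN", "VERB"]
pred = filtered[0][0]
n = min(len(gold), len(pred))
print({"acc": accuracy_score(gold[:n], pred[:n])})  # {'acc': 1.0}
```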
@@ -16,17 +16,16 @@ fewshot_split: train
 doc_to_target: !function utils.doc_to_target
 should_decontaminate: true
 doc_to_decontamination_query: "Sentence: {{token}}\nOutput:"
+process_results: !function utils.process_results
 filter_list:
   - filter:
-      - function: regex_pos
+      - function: "custom"
+        filter_fn: !function utils.extract_pos
+      - function: "take_first"
     name: flexible-extract
 metric_list:
   - metric: acc
-    aggregation: !function utils.acc_score
+    aggregation: mean
     higher_is_better: true
-    ignore_case: true
-    ignore_punctuation: true
-    regexes_to_ignore:
-      - ","
 metadata:
   version: 1.0

-from itertools import chain
+import re
+from collections.abc import Iterable
+from typing import Any

 from sklearn.metrics import accuracy_score

-from lm_eval.utils import weighted_f1_score


 def doc_to_target(doc):
     pos_tag_map = {
@@ -29,27 +29,40 @@ def doc_to_target(doc):
     return [pos_tag_map[tag] for tag in doc["upos"]]


-def acc_score(items):
-    unzipped_list = list(zip(*items))
-
-    golds, preds = unzipped_list[0], unzipped_list[1]
-    # Flatten preds' inner lists
-    flattened_preds = [list(chain.from_iterable(p)) for p in preds]
-    # Calculate the accuracy for each gold-pred pair
-    accuracy_scores = []
-    for gold, pred in zip(golds, flattened_preds):
-        # Ensure both lists are of the same length, otherwise truncate to match
-        min_length = min(len(gold), len(pred))
-        gold = gold[:min_length]
-        pred = pred[:min_length]
-        # Calculate accuracy for the current pair and add to the list
-        accuracy = accuracy_score(gold, pred)
-        accuracy_scores.append(accuracy)
-
-    mean_accuracy = (
-        sum(accuracy_scores) / len(accuracy_scores) if accuracy_scores else 0
-    )
-    return mean_accuracy
+def extract_pos(resps: Iterable[list[str]], *args) -> Iterable[list[str]]:
+    def extract_tagged_tokens(text: str) -> list[tuple[str, str]]:
+        # Extract tagged tokens list from text input using regex
+        tokens = re.findall(
+            r"\('([^']*)', '([^']*)'\)",
+            text,
+        )
+        return [(token, pos) for token, pos in tokens]
+
+    def extract_pos_tags(result: str):
+        pos_tags = []
+        if isinstance(result, str):
+            result_ = extract_tagged_tokens(result)
+            pos_tags.extend(pos for _, pos in result_)
+        return pos_tags if pos_tags else ["invalid"]
+
+    def filter_set(inst: list[str]) -> list[str]:
+        filtered = []
+        for resp in inst:
+            match = extract_pos_tags(resp)
+            filtered.append(match)
+        return filtered
+
+    filtered_resps = map(lambda x: filter_set(x), resps)
+
+    return filtered_resps
+
+
+def process_results(doc: dict[str, Any], results: list[list[str]]):
+    golds, preds = doc_to_target(doc), results[0]
+    # Ensure both lists are of the same length, otherwise truncate to match
+    min_length = min(len(golds), len(preds))
+    gold = golds[:min_length]
+    pred = preds[:min_length]
+    accuracy = accuracy_score(gold, pred)
+
+    return {"acc": accuracy}
-from lm_eval.utils import weighted_f1_score


 def doc_to_text(doc):
     output = """Please provide the POS tags for each word in the input sentence. The input will be a list of words in
 the sentence. The output format should be a list of tuples, where each tuple consists of a word from the input text
...
 tag: glue
 task: cola
-dataset_path: glue
+dataset_path: nyu-mll/glue
 dataset_name: cola
 output_type: multiple_choice
 training_split: train
...

 tag: glue
 task: mnli
-dataset_path: glue
+dataset_path: nyu-mll/glue
 dataset_name: mnli
 output_type: multiple_choice
 training_split: train
...

 tag: glue
 task: mrpc
-dataset_path: glue
+dataset_path: nyu-mll/glue
 dataset_name: mrpc
 output_type: multiple_choice
 training_split: train
...
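The `dataset_path` change points the GLUE tasks at the dataset's current Hub namespace. A quick way to confirm the new path resolves (a sketch using the `datasets` library, not part of the commit):

```python
from datasets import load_dataset

# The GLUE data now lives under the nyu-mll organization on the Hugging Face Hub.
cola = load_dataset("nyu-mll/glue", "cola", split="train")
print(cola[0])  # {'sentence': ..., 'label': ..., 'idx': ...}
```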