Commit 3e3a0d8f authored by Baber's avatar Baber
Browse files

Merge branch 'rm_multiple_target' into metrics

# Conflicts:
#	lm_eval/api/filter.py
#	lm_eval/api/metrics.py
#	lm_eval/api/task.py
#	lm_eval/filters/extraction.py
parents 2b4cdd41 00a77ebd
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Iterable, List, Union
from lm_eval.api.instance import Instance
......@@ -20,7 +20,9 @@ class Filter(ABC):
"""
@abstractmethod
def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
"""
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
......@@ -40,9 +42,9 @@ class FilterEnsemble:
"""
name: str
filters: List[type[Filter]]
filters: list[type[Filter]]
def apply(self, instances: List[Instance]) -> None:
def apply(self, instances: list[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
resps, docs = list(resps), list(docs)
......
......@@ -207,13 +207,48 @@ def acc_mutual_info_fn(items): # This is a passthrough function
# See the License for the specific language governing permissions and
# limitations under the License.
def exact_match_hf_evaluate(
predictions: Iterable[str],
references: Iterable[str],
regexes_to_ignore=None,
ignore_case=False,
ignore_punctuation=False,
ignore_numbers=False,
predictions: Iterable[str] | str,
references: Iterable[str] | str,
regexes_to_ignore: list[str] | None = None,
ignore_case: bool = False,
ignore_punctuation: bool = False,
ignore_numbers: bool = False,
multi_target: bool = False,
):
"""
Compute exact match scores between predictions and references.
This function computes the exact match score by comparing predictions
and references. It supports optional preprocessing steps such as ignoring
case, punctuation, numbers, and specific regex patterns.
Note:
predictions and references can have different lengths.
numpy broadcasting rule applies
Args:
predictions (Iterable[str] | str): The predicted strings to evaluate.
references (Iterable[str] | str): The reference strings to compare against.
regexes_to_ignore (list[str], optional): A list of regex patterns to remove
from both predictions and references before comparison. Defaults to None.
ignore_case (bool, optional): If True, ignores case differences during comparison.
Defaults to False.
ignore_punctuation (bool, optional): If True, removes punctuation from strings
before comparison. Defaults to False.
ignore_numbers (bool, optional): If True, removes numeric characters from strings
before comparison. Defaults to False.
multi_target (bool, optional): If True, returns 1.0 if any prediction matches any
reference, otherwise 0.0. Defaults to False.
Returns:
dict: A dictionary containing the exact match score:
- "exact_match" (float): The mean exact match score or 1.0/0.0 if `multi_target` is True.
"""
predictions, references = list(predictions), list(references)
assert len(predictions) == len(references) if not multi_target else True, (
"predictions and references must have the same length unless `multi_target` is True"
)
if regexes_to_ignore is not None:
for s in regexes_to_ignore:
predictions = np.array([re.sub(s, "", x) for x in predictions])
......@@ -238,7 +273,11 @@ def exact_match_hf_evaluate(
score_list = predictions == references
return {"exact_match": np.mean(score_list)}
return {
"exact_match": np.mean(score_list)
if not multi_target
else float(np.any(score_list))
}
###
......@@ -250,8 +289,8 @@ def exact_match_hf_evaluate(
output_type="generate_until",
aggregation="mean",
)
def exact_match_fn(**kwargs):
return exact_match_hf_evaluate(**kwargs)
def exact_match_fn(references: list[str], predictions: list[str], **kwargs):
return exact_match_hf_evaluate(predictions, references, **kwargs)
@register_metric(
......
......@@ -3,17 +3,14 @@ import ast
import logging
import random
import re
from collections.abc import Callable
from collections.abc import Callable, Iterable, Iterator, Mapping
from copy import deepcopy
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
Iterator,
List,
Literal,
Mapping,
Optional,
Tuple,
Union,
......@@ -530,8 +527,8 @@ class Task(abc.ABC):
# self.aggregation = lambda: {
# metric_name: get_metric_aggregation(metric_name)
# }
setattr(self._config, "metric_list", [MetricConfig(name=metric_name)])
setattr(self._config, "process_results", lambda *args: {"bypass": 0})
self._config.metric_list = [MetricConfig(name=metric_name)]
self._config.process_results = lambda *args: {"bypass": 0}
def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
self.fewshot_rnd = random.Random(seed)
......@@ -788,7 +785,7 @@ class ConfigurableTask(Task):
return docs
# Fallback to parent implementation
if _num_fewshot := getattr(self.config, "num_fewshot"):
if _num_fewshot := self.config.num_fewshot:
if isinstance(_num_fewshot, int) and _num_fewshot > 0:
eval_logger.warning(
f"[Task: {self.config.task}] "
......@@ -1409,63 +1406,15 @@ class ConfigurableTask(Task):
# it assumes that doc_to_target returns a number.
choices = self.doc_to_choice(doc)
gold = choices[gold]
# we expect multiple_targets to be a list.
elif self.multiple_target:
gold = list(gold)
# TODO: handle this better
elif type(gold) is not type(result) and not (
"bypass" in use_metric or isinstance(result, list)
):
# cast gold to the same type as result
gold = type(result)(gold)
for metric in self.config._metric_list:
if self.multiple_target:
# in the case where we have multiple targets,
# return true if any are true
# TODO: this may break for multipLe_target, non zero-or-1 metrics
scores = []
if not isinstance(gold, list):
# sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
# print(gold)
gold = [gold]
if metric.name == "exact_match":
result = [result for _ in range(len(gold))]
scores = metric.fn(
references=gold,
predictions=result,
**metric.kwargs,
)[metric]
result_score = 1.0 if scores > 0.0 else 0.0
else:
for gold_option in gold:
try:
result_score = metric.fn(
references=[gold_option],
predictions=[result],
**metric.kwargs,
)
except (
TypeError
): # TODO: this is hacky and I don't want to do it
result_score = metric.fn([gold_option, result])
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
result_score = result_score[metric]
scores.append(result_score)
if any(scores):
result_score = 1.0
else:
result_score = 0.0
else:
try:
result_score = metric.fn(
references=[gold],
predictions=[result],
**metric.kwargs,
)
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
result_score = metric.fn([gold, result])
for metric in self._metric_fn_list.keys():
try:
result_score = self._metric_fn_list[metric](
references=[gold] if not isinstance(gold, list) else gold,
predictions=[result],
**self._metric_fn_kwargs[metric],
)
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
result_score = self._metric_fn_list[metric]([gold, result])
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
# This allows for multiple metrics to be returned from the same function
......@@ -1515,7 +1464,7 @@ class MultipleChoiceTask(Task):
Instance(
request_type="loglikelihood",
doc=doc,
arguments=(ctx, " {}".format(choice)),
arguments=(ctx, f" {choice}"),
idx=i,
**kwargs,
)
......
import re
import sys
import unicodedata
from collections.abc import Iterable
from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter
......@@ -32,7 +33,9 @@ class RegexFilter(Filter):
self.group_select = group_select
self.fallback = fallback
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
......@@ -59,59 +62,13 @@ class RegexFilter(Filter):
return filtered_resps
@register_filter("regex_pos")
class POSFilter(Filter):
""" """
def __init__(
self,
regex_pattern: str = r"\['(.*?)'\]",
group_select=0,
fallback=None,
**kwargs,
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
super().__init__(**kwargs)
if fallback is None:
fallback = ["invalid"]
self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern)
self.group_select = group_select
self.fallback = fallback
def apply(self, resps, docs):
def extract_tagged_tokens(text):
# Extract tagged tokens list from text input using regex
tokens = re.findall(r"\('([^']*)', '([^']*)'\)", text)
return [(token, pos) for token, pos in tokens]
def extract_pos_tags(result):
pos_tags = []
if isinstance(result, str):
result = extract_tagged_tokens(result)
pos_tags.extend(pos for _, pos in result)
return pos_tags if pos_tags else self.fallback
def filter_set(inst):
filtered = []
for resp in inst:
match = extract_pos_tags(resp)
filtered.append(match)
return filtered
filtered_resps = map(lambda x: filter_set(x), resps)
return filtered_resps
@register_filter("remove_whitespace")
class WhitespaceFilter(Filter):
"""Filters out leading whitespace from responses."""
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
def filter_set(inst):
filtered_resp = []
for resp in inst:
......@@ -156,7 +113,9 @@ class MultiChoiceRegexFilter(RegexFilter):
self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
......
This diff is collapsed.
......@@ -73,3 +73,5 @@ HomePage: https://github.com/masakhane-io/masakhane-pos
abstract = "In this paper, we present AfricaPOS, the largest part-of-speech (POS) dataset for 20 typologically diverse African languages. We discuss the challenges in annotating POS for these languages using the universal dependencies (UD) guidelines. We conducted extensive POS baseline experiments using both conditional random field and several multilingual pre-trained language models. We applied various cross-lingual transfer models trained with data available in the UD. Evaluating on the AfricaPOS dataset, we show that choosing the best transfer language(s) in both single-source and multi-source setups greatly improves the POS tagging performance of the target languages, in particular when combined with parameter-fine-tuning methods. Crucially, transferring knowledge from a language that matches the language family and morphosyntactic properties seems to be more effective for POS tagging in unseen languages."
}
```
## Changelog
- 2025-07-21: Refactored. Scores should not be affected.
......@@ -14,19 +14,18 @@ validation_split: validation
test_split: test
fewshot_split: train
doc_to_target: !function utils.doc_to_target
process_results: !function utils.process_results
should_decontaminate: true
doc_to_decontamination_query: "Sentence: {{token}}\nOutput:"
filter_list:
- filter:
- function: regex_pos
- function: "custom"
filter_fn: !function utils.extract_pos
- function: "take_first"
name: flexible-extract
metric_list:
- metric: acc
aggregation: !function utils.acc_score
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
regexes_to_ignore:
- ","
metadata:
version: 1.0
from itertools import chain
import re
from collections.abc import Iterable
from typing import Any
from sklearn.metrics import accuracy_score
from lm_eval.utils import weighted_f1_score
def doc_to_target(doc):
pos_tag_map = {
......@@ -29,27 +29,40 @@ def doc_to_target(doc):
return [pos_tag_map[tag] for tag in doc["upos"]]
def acc_score(items):
unzipped_list = list(zip(*items))
def extract_pos(resps: Iterable[list[str]], *args) -> Iterable[list[str]]:
def extract_tagged_tokens(text: str) -> list[tuple[str, str]]:
# Extract tagged tokens list from text input using regex
tokens = re.findall(
r"\('([^']*)', '([^']*)'\)",
"Here are some tuples: ('apple', 'red'), ('banana', 'yellow'), ('grape', 'purple')",
)
return [(token, pos) for token, pos in tokens]
def extract_pos_tags(result: str):
pos_tags = []
if isinstance(result, str):
result_ = extract_tagged_tokens(result)
pos_tags.extend(pos for _, pos in result_)
return pos_tags if pos_tags else ["invalid"]
def filter_set(inst: list[str]) -> list[str]:
filtered = []
for resp in inst:
match = extract_pos_tags(resp)
filtered.append(match)
return filtered
golds, preds = unzipped_list[0], unzipped_list[1]
filtered_resps = map(lambda x: filter_set(x), resps)
# Flatten preds' inner lists
flattened_preds = [list(chain.from_iterable(p)) for p in preds]
return filtered_resps
# Calculate the accuracy for each gold-pred pair
accuracy_scores = []
for gold, pred in zip(golds, flattened_preds):
# Ensure both lists are of the same length, otherwise truncate to match
min_length = min(len(gold), len(pred))
gold = gold[:min_length]
pred = pred[:min_length]
# Calculate accuracy for the current pair and add to the list
accuracy = accuracy_score(gold, pred)
accuracy_scores.append(accuracy)
def process_results(doc: dict[str, Any], results: list[list[str]]):
golds, preds = doc_to_target(doc), results[0]
# Ensure both lists are of the same length, otherwise truncate to match
min_length = min(len(golds), len(preds))
gold = golds[:min_length]
pred = preds[:min_length]
accuracy = accuracy_score(gold, pred)
mean_accuracy = (
sum(accuracy_scores) / len(accuracy_scores) if accuracy_scores else 0
)
return mean_accuracy
return {"acc": accuracy}
......@@ -16,17 +16,16 @@ fewshot_split: train
doc_to_target: !function utils.doc_to_target
should_decontaminate: true
doc_to_decontamination_query: "Sentence: {{token}}\nOutput:"
process_results: !function utils.process_results
filter_list:
- filter:
- function: regex_pos
- function: "custom"
filter_fn: !function utils.extract_pos
- function: "take_first"
name: flexible-extract
metric_list:
- metric: acc
aggregation: !function utils.acc_score
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
regexes_to_ignore:
- ","
metadata:
version: 1.0
from itertools import chain
import re
from collections.abc import Iterable
from typing import Any
from sklearn.metrics import accuracy_score
from lm_eval.utils import weighted_f1_score
def doc_to_target(doc):
pos_tag_map = {
......@@ -29,27 +29,40 @@ def doc_to_target(doc):
return [pos_tag_map[tag] for tag in doc["upos"]]
def acc_score(items):
unzipped_list = list(zip(*items))
def extract_pos(resps: Iterable[list[str]], *args) -> Iterable[list[str]]:
def extract_tagged_tokens(text: str) -> list[tuple[str, str]]:
# Extract tagged tokens list from text input using regex
tokens = re.findall(
r"\('([^']*)', '([^']*)'\)",
"Here are some tuples: ('apple', 'red'), ('banana', 'yellow'), ('grape', 'purple')",
)
return [(token, pos) for token, pos in tokens]
def extract_pos_tags(result: str):
pos_tags = []
if isinstance(result, str):
result_ = extract_tagged_tokens(result)
pos_tags.extend(pos for _, pos in result_)
return pos_tags if pos_tags else ["invalid"]
def filter_set(inst: list[str]) -> list[str]:
filtered = []
for resp in inst:
match = extract_pos_tags(resp)
filtered.append(match)
return filtered
golds, preds = unzipped_list[0], unzipped_list[1]
filtered_resps = map(lambda x: filter_set(x), resps)
# Flatten preds' inner lists
flattened_preds = [list(chain.from_iterable(p)) for p in preds]
return filtered_resps
# Calculate the accuracy for each gold-pred pair
accuracy_scores = []
for gold, pred in zip(golds, flattened_preds):
# Ensure both lists are of the same length, otherwise truncate to match
min_length = min(len(gold), len(pred))
gold = gold[:min_length]
pred = pred[:min_length]
# Calculate accuracy for the current pair and add to the list
accuracy = accuracy_score(gold, pred)
accuracy_scores.append(accuracy)
def process_results(doc: dict[str, Any], results: list[list[str]]):
golds, preds = doc_to_target(doc), results[0]
# Ensure both lists are of the same length, otherwise truncate to match
min_length = min(len(golds), len(preds))
gold = golds[:min_length]
pred = preds[:min_length]
accuracy = accuracy_score(gold, pred)
mean_accuracy = (
sum(accuracy_scores) / len(accuracy_scores) if accuracy_scores else 0
)
return mean_accuracy
return {"acc": accuracy}
......@@ -16,17 +16,16 @@ fewshot_split: train
doc_to_target: !function utils.doc_to_target
should_decontaminate: true
doc_to_decontamination_query: "Sentence: {{token}}\nOutput:"
process_results: !function utils.process_results
filter_list:
- filter:
- function: regex_pos
- function: "custom"
filter_fn: !function utils.extract_pos
- function: "take_first"
name: flexible-extract
metric_list:
- metric: acc
aggregation: !function utils.acc_score
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
regexes_to_ignore:
- ","
metadata:
version: 1.0
from itertools import chain
import re
from collections.abc import Iterable
from typing import Any
from sklearn.metrics import accuracy_score
from lm_eval.utils import weighted_f1_score
def doc_to_target(doc):
pos_tag_map = {
......@@ -29,27 +29,40 @@ def doc_to_target(doc):
return [pos_tag_map[tag] for tag in doc["upos"]]
def acc_score(items):
unzipped_list = list(zip(*items))
def extract_pos(resps: Iterable[list[str]], *args) -> Iterable[list[str]]:
def extract_tagged_tokens(text: str) -> list[tuple[str, str]]:
# Extract tagged tokens list from text input using regex
tokens = re.findall(
r"\('([^']*)', '([^']*)'\)",
"Here are some tuples: ('apple', 'red'), ('banana', 'yellow'), ('grape', 'purple')",
)
return [(token, pos) for token, pos in tokens]
def extract_pos_tags(result: str):
pos_tags = []
if isinstance(result, str):
result_ = extract_tagged_tokens(result)
pos_tags.extend(pos for _, pos in result_)
return pos_tags if pos_tags else ["invalid"]
def filter_set(inst: list[str]) -> list[str]:
filtered = []
for resp in inst:
match = extract_pos_tags(resp)
filtered.append(match)
return filtered
golds, preds = unzipped_list[0], unzipped_list[1]
filtered_resps = map(lambda x: filter_set(x), resps)
# Flatten preds' inner lists
flattened_preds = [list(chain.from_iterable(p)) for p in preds]
return filtered_resps
# Calculate the accuracy for each gold-pred pair
accuracy_scores = []
for gold, pred in zip(golds, flattened_preds):
# Ensure both lists are of the same length, otherwise truncate to match
min_length = min(len(gold), len(pred))
gold = gold[:min_length]
pred = pred[:min_length]
# Calculate accuracy for the current pair and add to the list
accuracy = accuracy_score(gold, pred)
accuracy_scores.append(accuracy)
def process_results(doc: dict[str, Any], results: list[list[str]]):
golds, preds = doc_to_target(doc), results[0]
# Ensure both lists are of the same length, otherwise truncate to match
min_length = min(len(golds), len(preds))
gold = golds[:min_length]
pred = preds[:min_length]
accuracy = accuracy_score(gold, pred)
mean_accuracy = (
sum(accuracy_scores) / len(accuracy_scores) if accuracy_scores else 0
)
return mean_accuracy
return {"acc": accuracy}
......@@ -16,17 +16,16 @@ fewshot_split: train
doc_to_target: !function utils.doc_to_target
should_decontaminate: true
doc_to_decontamination_query: "Sentence: {{token}}\nOutput:"
process_results: !function utils.process_results
filter_list:
- filter:
- function: regex_pos
- function: "custom"
filter_fn: !function utils.extract_pos
- function: "take_first"
name: flexible-extract
metric_list:
- metric: acc
aggregation: !function utils.acc_score
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
regexes_to_ignore:
- ","
metadata:
version: 1.0
from itertools import chain
import re
from collections.abc import Iterable
from typing import Any
from sklearn.metrics import accuracy_score
from lm_eval.utils import weighted_f1_score
def doc_to_target(doc):
pos_tag_map = {
......@@ -29,27 +29,40 @@ def doc_to_target(doc):
return [pos_tag_map[tag] for tag in doc["upos"]]
def acc_score(items):
unzipped_list = list(zip(*items))
def extract_pos(resps: Iterable[list[str]], *args) -> Iterable[list[str]]:
def extract_tagged_tokens(text: str) -> list[tuple[str, str]]:
# Extract tagged tokens list from text input using regex
tokens = re.findall(
r"\('([^']*)', '([^']*)'\)",
"Here are some tuples: ('apple', 'red'), ('banana', 'yellow'), ('grape', 'purple')",
)
return [(token, pos) for token, pos in tokens]
def extract_pos_tags(result: str):
pos_tags = []
if isinstance(result, str):
result_ = extract_tagged_tokens(result)
pos_tags.extend(pos for _, pos in result_)
return pos_tags if pos_tags else ["invalid"]
def filter_set(inst: list[str]) -> list[str]:
filtered = []
for resp in inst:
match = extract_pos_tags(resp)
filtered.append(match)
return filtered
golds, preds = unzipped_list[0], unzipped_list[1]
filtered_resps = map(lambda x: filter_set(x), resps)
# Flatten preds' inner lists
flattened_preds = [list(chain.from_iterable(p)) for p in preds]
return filtered_resps
# Calculate the accuracy for each gold-pred pair
accuracy_scores = []
for gold, pred in zip(golds, flattened_preds):
# Ensure both lists are of the same length, otherwise truncate to match
min_length = min(len(gold), len(pred))
gold = gold[:min_length]
pred = pred[:min_length]
# Calculate accuracy for the current pair and add to the list
accuracy = accuracy_score(gold, pred)
accuracy_scores.append(accuracy)
def process_results(doc: dict[str, Any], results: list[list[str]]):
golds, preds = doc_to_target(doc), results[0]
# Ensure both lists are of the same length, otherwise truncate to match
min_length = min(len(golds), len(preds))
gold = golds[:min_length]
pred = preds[:min_length]
accuracy = accuracy_score(gold, pred)
mean_accuracy = (
sum(accuracy_scores) / len(accuracy_scores) if accuracy_scores else 0
)
return mean_accuracy
return {"acc": accuracy}
......@@ -16,17 +16,16 @@ fewshot_split: train
doc_to_target: !function utils.doc_to_target
should_decontaminate: true
doc_to_decontamination_query: "Sentence: {{token}}\nOutput:"
process_results: !function utils.process_results
filter_list:
- filter:
- function: regex_pos
- function: "custom"
filter_fn: !function utils.extract_pos
- function: "take_first"
name: flexible-extract
metric_list:
- metric: acc
aggregation: !function utils.acc_score
aggregation: mean
higher_is_better: true
ignore_case: true
ignore_punctuation: true
regexes_to_ignore:
- ","
metadata:
version: 1.0
from itertools import chain
import re
from collections.abc import Iterable
from typing import Any
from sklearn.metrics import accuracy_score
from lm_eval.utils import weighted_f1_score
def doc_to_target(doc):
pos_tag_map = {
......@@ -29,27 +29,40 @@ def doc_to_target(doc):
return [pos_tag_map[tag] for tag in doc["upos"]]
def acc_score(items):
unzipped_list = list(zip(*items))
def extract_pos(resps: Iterable[list[str]], *args) -> Iterable[list[str]]:
def extract_tagged_tokens(text: str) -> list[tuple[str, str]]:
# Extract tagged tokens list from text input using regex
tokens = re.findall(
r"\('([^']*)', '([^']*)'\)",
"Here are some tuples: ('apple', 'red'), ('banana', 'yellow'), ('grape', 'purple')",
)
return [(token, pos) for token, pos in tokens]
def extract_pos_tags(result: str):
pos_tags = []
if isinstance(result, str):
result_ = extract_tagged_tokens(result)
pos_tags.extend(pos for _, pos in result_)
return pos_tags if pos_tags else ["invalid"]
def filter_set(inst: list[str]) -> list[str]:
filtered = []
for resp in inst:
match = extract_pos_tags(resp)
filtered.append(match)
return filtered
golds, preds = unzipped_list[0], unzipped_list[1]
filtered_resps = map(lambda x: filter_set(x), resps)
# Flatten preds' inner lists
flattened_preds = [list(chain.from_iterable(p)) for p in preds]
return filtered_resps
# Calculate the accuracy for each gold-pred pair
accuracy_scores = []
for gold, pred in zip(golds, flattened_preds):
# Ensure both lists are of the same length, otherwise truncate to match
min_length = min(len(gold), len(pred))
gold = gold[:min_length]
pred = pred[:min_length]
# Calculate accuracy for the current pair and add to the list
accuracy = accuracy_score(gold, pred)
accuracy_scores.append(accuracy)
def process_results(doc: dict[str, Any], results: list[list[str]]):
golds, preds = doc_to_target(doc), results[0]
# Ensure both lists are of the same length, otherwise truncate to match
min_length = min(len(golds), len(preds))
gold = golds[:min_length]
pred = preds[:min_length]
accuracy = accuracy_score(gold, pred)
mean_accuracy = (
sum(accuracy_scores) / len(accuracy_scores) if accuracy_scores else 0
)
return mean_accuracy
return {"acc": accuracy}
from lm_eval.utils import weighted_f1_score
def doc_to_text(doc):
output = """Please provide the POS tags for each word in the input sentence. The input will be a list of words in
the sentence. The output format should be a list of tuples, where each tuple consists of a word from the input text
......
tag: glue
task: cola
dataset_path: glue
dataset_path: nyu-mll/glue
dataset_name: cola
output_type: multiple_choice
training_split: train
......
tag: glue
task: mnli
dataset_path: glue
dataset_path: nyu-mll/glue
dataset_name: mnli
output_type: multiple_choice
training_split: train
......
tag: glue
task: mrpc
dataset_path: glue
dataset_path: nyu-mll/glue
dataset_name: mrpc
output_type: multiple_choice
training_split: train
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment