Unverified Commit d27c0c08 authored by LSinev's avatar LSinev Committed by GitHub
Browse files

Apply code autoformatting with Ruff to tasks/*.py an *__init__.py (#1469)

parent f78e2da4
from typing import List, Union
from functools import partial from functools import partial
from typing import List, Union
from lm_eval.api.filter import FilterEnsemble from lm_eval.api.filter import FilterEnsemble
from . import selection
from . import extraction from . import extraction, selection, transformation
from . import transformation
FILTER_REGISTRY = { FILTER_REGISTRY = {
......
from . import huggingface from . import (
from . import openai_completions anthropic_llms,
from . import textsynth dummy,
from . import dummy gguf,
from . import anthropic_llms huggingface,
from . import gguf mamba_lm,
from . import vllm_causallms neuron_optimum,
from . import mamba_lm openai_completions,
from . import optimum_lm optimum_lm,
from . import neuron_optimum textsynth,
vllm_causallms,
)
# TODO: implement __all__ # TODO: implement __all__
......
import os
import ast import ast
import os
from typing import Dict from typing import Dict
from lm_eval import utils from lm_eval import utils
from lm_eval.utils import eval_logger from lm_eval.utils import eval_logger
# Prompt library. # Prompt library.
# Stores prompts in a dictionary indexed by 2 levels: # Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name. # prompt category name, and prompt name.
......
import os
import abc import abc
import collections import collections
import logging
import os
from functools import partial from functools import partial
from typing import List, Union, Dict from typing import Dict, List, Union
from lm_eval import utils from lm_eval import utils
from lm_eval.api.task import Task, ConfigurableTask from lm_eval.api.task import ConfigurableTask, Task
import logging
class TaskManager: class TaskManager:
...@@ -16,20 +14,14 @@ class TaskManager: ...@@ -16,20 +14,14 @@ class TaskManager:
and an optional directory if provided. and an optional directory if provided.
""" """
def __init__(
self,
verbosity="INFO",
include_path=None
) -> None:
def __init__(self, verbosity="INFO", include_path=None) -> None:
self.verbosity = verbosity self.verbosity = verbosity
self.include_path = include_path self.include_path = include_path
self.logger = utils.eval_logger self.logger = utils.eval_logger
self.logger.setLevel(getattr(logging, f"{verbosity}")) self.logger.setLevel(getattr(logging, f"{verbosity}"))
self._task_index = self.initialize_tasks( self._task_index = self.initialize_tasks(include_path=include_path)
include_path=include_path
)
self._all_tasks = sorted(list(self._task_index.keys())) self._all_tasks = sorted(list(self._task_index.keys()))
self.task_group_map = collections.defaultdict(list) self.task_group_map = collections.defaultdict(list)
...@@ -65,27 +57,29 @@ class TaskManager: ...@@ -65,27 +57,29 @@ class TaskManager:
return self._task_index return self._task_index
def match_tasks(self, task_list): def match_tasks(self, task_list):
return utils.pattern_match( return utils.pattern_match(task_list, self.all_tasks)
task_list, self.all_tasks
)
def _name_is_registered(self, name): def _name_is_registered(self, name):
if name in self.all_tasks: if name in self.all_tasks:
return True return True
return False return False
def _name_is_task(self, name): def _name_is_task(self, name) -> bool:
if self._name_is_registered(name) and ("task" in self.task_index[name]["type"]): if self._name_is_registered(name) and ("task" in self.task_index[name]["type"]):
return True return True
return False return False
def _name_is_group(self, name): def _name_is_group(self, name):
if self._name_is_registered(name) and (self.task_index[name]["type"] == "group"): if self._name_is_registered(name) and (
self.task_index[name]["type"] == "group"
):
return True return True
return False return False
def _name_is_python_task(self, name): def _name_is_python_task(self, name):
if self._name_is_registered(name) and (self.task_index[name]["type"] == "python_task"): if self._name_is_registered(name) and (
self.task_index[name]["type"] == "python_task"
):
return True return True
return False return False
...@@ -117,7 +111,7 @@ class TaskManager: ...@@ -117,7 +111,7 @@ class TaskManager:
return utils.load_yaml_config(yaml_path, mode="full") return utils.load_yaml_config(yaml_path, mode="full")
def _get_tasklist(self, name): def _get_tasklist(self, name):
assert self._name_is_task(name) == False assert self._name_is_task(name) is False
return self.task_index[name]["task"] return self.task_index[name]["task"]
def _process_alias(self, config, group=None): def _process_alias(self, config, group=None):
...@@ -130,12 +124,12 @@ class TaskManager: ...@@ -130,12 +124,12 @@ class TaskManager:
return config return config
def _load_individual_task_or_group( def _load_individual_task_or_group(
self, self,
name_or_config: Union[str, dict] = None, name_or_config: Union[str, dict] = None,
parent_name: str = None, parent_name: str = None,
update_config: dict = None, update_config: dict = None,
yaml_path: str = None, yaml_path: str = None,
) -> ConfigurableTask: ) -> ConfigurableTask:
def load_task(config, task, group=None, yaml_path=None): def load_task(config, task, group=None, yaml_path=None):
if "include" in config: if "include" in config:
assert yaml_path is not None assert yaml_path is not None
...@@ -174,7 +168,9 @@ class TaskManager: ...@@ -174,7 +168,9 @@ class TaskManager:
group_config = self._get_config(name_or_config) group_config = self._get_config(name_or_config)
if set(group_config.keys()) > set(["task", "group"]): if set(group_config.keys()) > set(["task", "group"]):
update_config = { update_config = {
k:v for k,v in group_config.items() if k not in ["task", "group"] k: v
for k, v in group_config.items()
if k not in ["task", "group"]
} }
yaml_path = self._get_yaml_path(group_name) yaml_path = self._get_yaml_path(group_name)
...@@ -183,9 +179,8 @@ class TaskManager: ...@@ -183,9 +179,8 @@ class TaskManager:
update_config.pop("group_alias") update_config.pop("group_alias")
if isinstance(name_or_config, dict): if isinstance(name_or_config, dict):
if update_config is not None: if update_config is not None:
name_or_config={ name_or_config = {
**name_or_config, **name_or_config,
**update_config, **update_config,
} }
...@@ -196,7 +191,9 @@ class TaskManager: ...@@ -196,7 +191,9 @@ class TaskManager:
# if self._name_is_task(name) is False: # if self._name_is_task(name) is False:
if self._name_is_group(name): if self._name_is_group(name):
group_name = name group_name = name
update_config = {k:v for k,v in name_or_config.items() if k != "task"} update_config = {
k: v for k, v in name_or_config.items() if k != "task"
}
subtask_list = self._get_tasklist(name) subtask_list = self._get_tasklist(name)
if subtask_list == -1: if subtask_list == -1:
subtask_list = self._get_config(name)["task"] subtask_list = self._get_config(name)["task"]
...@@ -207,36 +204,53 @@ class TaskManager: ...@@ -207,36 +204,53 @@ class TaskManager:
# Check if this is a duplicate. # Check if this is a duplicate.
if parent_name is not None: if parent_name is not None:
name_or_config["group"] = parent_name name_or_config["group"] = parent_name
num_duplicate = len(list(filter(lambda x: x.startswith(name), self.task_group_map[parent_name]))) num_duplicate = len(
list(
filter(
lambda x: x.startswith(name),
self.task_group_map[parent_name],
)
)
)
if num_duplicate > 0: if num_duplicate > 0:
name = f"{name}-{num_duplicate}" name = f"{name}-{num_duplicate}"
self.task_group_map[parent_name].append(name) self.task_group_map[parent_name].append(name)
task_config={ task_config = {
**base_task_config, **base_task_config,
**name_or_config, **name_or_config,
} }
else: else:
task_config = name_or_config task_config = name_or_config
return load_task(task_config, task=name, group=parent_name, yaml_path=yaml_path) return load_task(
task_config, task=name, group=parent_name, yaml_path=yaml_path
)
else: else:
group_name = name_or_config["group"] group_name = name_or_config["group"]
subtask_list = name_or_config["task"] subtask_list = name_or_config["task"]
# update_config = {k:v for k,v in name_or_config.items() if k != "task"}
if set(name_or_config.keys()) > set(["task", "group"]): if set(name_or_config.keys()) > set(["task", "group"]):
update_config = { update_config = {
k:v for k,v in name_or_config.items() if k not in ["task", "group"] k: v
for k, v in name_or_config.items()
if k not in ["task", "group"]
} }
all_subtasks = {} all_subtasks = {}
if (parent_name is not None): if parent_name is not None:
all_subtasks = {group_name: (parent_name, None)} all_subtasks = {group_name: (parent_name, None)}
fn = partial(self._load_individual_task_or_group, parent_name=group_name, update_config=update_config, yaml_path=yaml_path) fn = partial(
all_subtasks = {**all_subtasks, **dict(collections.ChainMap(*map(fn, subtask_list)))} self._load_individual_task_or_group,
parent_name=group_name,
update_config=update_config,
yaml_path=yaml_path,
)
all_subtasks = {
**all_subtasks,
**dict(collections.ChainMap(*map(fn, subtask_list))),
}
return all_subtasks return all_subtasks
def load_task_or_group(self, task_list: Union[str, list] = None) -> dict: def load_task_or_group(self, task_list: Union[str, list] = None) -> dict:
"""Loads a dictionary of task objects from a list """Loads a dictionary of task objects from a list
...@@ -250,12 +264,7 @@ class TaskManager: ...@@ -250,12 +264,7 @@ class TaskManager:
task_list = [task_list] task_list = [task_list]
all_loaded_tasks = dict( all_loaded_tasks = dict(
collections.ChainMap( collections.ChainMap(*map(self._load_individual_task_or_group, task_list))
*map(
self._load_individual_task_or_group,
task_list
)
)
) )
return all_loaded_tasks return all_loaded_tasks
...@@ -299,11 +308,11 @@ class TaskManager: ...@@ -299,11 +308,11 @@ class TaskManager:
# This is a group config # This is a group config
tasks_and_groups[config["group"]] = { tasks_and_groups[config["group"]] = {
"type": "group", "type": "group",
"task": -1, # This signals that "task": -1, # This signals that
# we don't need to know # we don't need to know
# the task list for indexing # the task list for indexing
# as it can be loaded # as it can be loaded
# when called. # when called.
"yaml_path": yaml_path, "yaml_path": yaml_path,
} }
...@@ -322,7 +331,7 @@ class TaskManager: ...@@ -322,7 +331,7 @@ class TaskManager:
tasks_and_groups[task] = { tasks_and_groups[task] = {
"type": "task", "type": "task",
"yaml_path": yaml_path, "yaml_path": yaml_path,
} }
if "group" in config: if "group" in config:
groups = config["group"] groups = config["group"]
...@@ -343,6 +352,7 @@ class TaskManager: ...@@ -343,6 +352,7 @@ class TaskManager:
return tasks_and_groups return tasks_and_groups
def include_path(task_dir): def include_path(task_dir):
logger = utils.eval_logger logger = utils.eval_logger
logger.setLevel(getattr(logging, "INFO")) logger.setLevel(getattr(logging, "INFO"))
...@@ -352,6 +362,7 @@ def include_path(task_dir): ...@@ -352,6 +362,7 @@ def include_path(task_dir):
) )
return 0 return 0
def initialize_tasks(verbosity="INFO"): def initialize_tasks(verbosity="INFO"):
logger = utils.eval_logger logger = utils.eval_logger
logger.setLevel(getattr(logging, f"{verbosity}")) logger.setLevel(getattr(logging, f"{verbosity}"))
...@@ -362,6 +373,7 @@ def initialize_tasks(verbosity="INFO"): ...@@ -362,6 +373,7 @@ def initialize_tasks(verbosity="INFO"):
) )
return 0 return 0
def get_task_name_from_config(task_config: Dict[str, str]) -> str: def get_task_name_from_config(task_config: Dict[str, str]) -> str:
if "task" in task_config: if "task" in task_config:
return task_config["task"] return task_config["task"]
...@@ -370,6 +382,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str: ...@@ -370,6 +382,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str:
else: else:
return "{dataset_path}".format(**task_config) return "{dataset_path}".format(**task_config)
def get_task_name_from_object(task_object): def get_task_name_from_object(task_object):
if hasattr(task_object, "config"): if hasattr(task_object, "config"):
return task_object._config["task"] return task_object._config["task"]
...@@ -382,7 +395,10 @@ def get_task_name_from_object(task_object): ...@@ -382,7 +395,10 @@ def get_task_name_from_object(task_object):
else type(task_object).__name__ else type(task_object).__name__
) )
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], task_manager: TaskManager = None):
def get_task_dict(
task_name_list: List[Union[str, Dict, Task]], task_manager: TaskManager = None
):
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object. """Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
:param task_name_list: List[Union[str, Dict, Task]] :param task_name_list: List[Union[str, Dict, Task]]
...@@ -409,7 +425,9 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], task_manager: Ta ...@@ -409,7 +425,9 @@ def get_task_dict(task_name_list: List[Union[str, Dict, Task]], task_manager: Ta
if task_manager is None: if task_manager is None:
task_manager = TaskManager() task_manager = TaskManager()
task_name_from_string_dict = task_manager.load_task_or_group(string_task_name_list) task_name_from_string_dict = task_manager.load_task_or_group(
string_task_name_list
)
for task_element in others_task_name_list: for task_element in others_task_name_list:
if isinstance(task_element, dict): if isinstance(task_element, dict):
......
""" """
Take in a YAML, and output all other splits with this YAML Take in a YAML, and output all other splits with this YAML
""" """
import argparse
import os import os
import re import re
import yaml
import requests
import argparse
import datasets import datasets
import requests
import yaml
from tqdm import tqdm from tqdm import tqdm
from lm_eval import utils from lm_eval import utils
......
import collections import collections
import re import re
import sys import sys
import unicodedata import unicodedata
from lm_eval.filters.extraction import RegexFilter, Filter from lm_eval.filters.extraction import Filter, RegexFilter
class ExtendedRegexFilter(RegexFilter): class ExtendedRegexFilter(RegexFilter):
punct_tbl = dict.fromkeys(i for i in range(sys.maxunicode) punct_tbl = dict.fromkeys(
if unicodedata.category(chr(i)).startswith('P')) i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P")
)
def __init__( def __init__(
self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", group_select=0, fallback: str = "[invalid]", self,
ignore_case=False, ignore_punctuation=False, regexes_to_ignore=None, regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select=0,
fallback: str = "[invalid]",
ignore_case=False,
ignore_punctuation=False,
regexes_to_ignore=None,
) -> None: ) -> None:
super().__init__(regex_pattern, group_select, fallback) super().__init__(regex_pattern, group_select, fallback)
self.ignore_case = ignore_case self.ignore_case = ignore_case
...@@ -47,8 +52,13 @@ class ExtendedRegexFilter(RegexFilter): ...@@ -47,8 +52,13 @@ class ExtendedRegexFilter(RegexFilter):
class MapRegexFilter(ExtendedRegexFilter): class MapRegexFilter(ExtendedRegexFilter):
def __init__( def __init__(
self, regex_pattern_to_value: dict = {}, group_select=0, fallback: str = "[invalid]", self,
ignore_case=False, ignore_punctuation=False, regexes_to_ignore=None, regex_pattern_to_value: dict = {},
group_select=0,
fallback: str = "[invalid]",
ignore_case=False,
ignore_punctuation=False,
regexes_to_ignore=None,
) -> None: ) -> None:
""" """
regex_pattern_to_value: Match the regex pattern and change the result into the value regex_pattern_to_value: Match the regex pattern and change the result into the value
...@@ -57,8 +67,17 @@ class MapRegexFilter(ExtendedRegexFilter): ...@@ -57,8 +67,17 @@ class MapRegexFilter(ExtendedRegexFilter):
ignore_punctuation: Remove the punctuation before matching with the given regex ignore_punctuation: Remove the punctuation before matching with the given regex
regexes_to_ignore: Remove these regexes before matching with the given regex regexes_to_ignore: Remove these regexes before matching with the given regex
""" """
super().__init__('|'.join(list(regex_pattern_to_value.keys())), group_select, fallback, ignore_case, ignore_punctuation, regexes_to_ignore) super().__init__(
self.regex_to_value = {re.compile(r): v for r, v in regex_pattern_to_value.items()} "|".join(list(regex_pattern_to_value.keys())),
group_select,
fallback,
ignore_case,
ignore_punctuation,
regexes_to_ignore,
)
self.regex_to_value = {
re.compile(r): v for r, v in regex_pattern_to_value.items()
}
def apply(self, resps, docs): def apply(self, resps, docs):
filtered_resps = [] filtered_resps = []
...@@ -66,10 +85,15 @@ class MapRegexFilter(ExtendedRegexFilter): ...@@ -66,10 +85,15 @@ class MapRegexFilter(ExtendedRegexFilter):
for r in resps: for r in resps:
filtered = [] filtered = []
for resp in r: for resp in r:
whole_match_considering_group_select = self.find_match(self.regex, self.filter_ignores(resp)) whole_match_considering_group_select = self.find_match(
self.regex, self.filter_ignores(resp)
)
if whole_match_considering_group_select: if whole_match_considering_group_select:
for regex, mapped_value in self.regex_to_value.items(): for regex, mapped_value in self.regex_to_value.items():
match = self.find_match(regex, self.filter_ignores(whole_match_considering_group_select)) match = self.find_match(
regex,
self.filter_ignores(whole_match_considering_group_select),
)
if match: if match:
match = mapped_value match = mapped_value
break break
...@@ -91,9 +115,11 @@ class NumberParseRegexFilter(ExtendedRegexFilter): ...@@ -91,9 +115,11 @@ class NumberParseRegexFilter(ExtendedRegexFilter):
filtered_resps = [] filtered_resps = []
import regex import regex
from word2number import w2n from word2number import w2n
# https://www.reddit.com/r/regex/comments/11a38uk/parsing_numbers_written_out_as_english_words # https://www.reddit.com/r/regex/comments/11a38uk/parsing_numbers_written_out_as_english_words
english_number_regex = regex.compile( english_number_regex = regex.compile(
"((?:(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?:|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion)(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?:|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion|[^\S\r\n]|,|and|&)+)?(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion))") "((?:(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?:|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion)(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?:|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion|[^\S\r\n]|,|and|&)+)?(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion))"
)
for r in resps: for r in resps:
filtered = [] filtered = []
...@@ -118,21 +144,22 @@ class WordSortFilter(Filter): ...@@ -118,21 +144,22 @@ class WordSortFilter(Filter):
filtered_resps = [] filtered_resps = []
for r, doc in zip(resps, docs): for r, doc in zip(resps, docs):
words = doc['input'].split("List:")[1].strip().split() words = doc["input"].split("List:")[1].strip().split()
regex = re.compile('|'.join([f"\\b{w}\\b" for w in words])) regex = re.compile("|".join([f"\\b{w}\\b" for w in words]))
filtered = [] filtered = []
for resp in r: for resp in r:
match = regex.findall(resp) match = regex.findall(resp)
match.reverse() match.reverse()
ordered_words = reversed(collections.OrderedDict(zip(match, [None] * len(match)))) ordered_words = reversed(
filtered.append(' '.join(ordered_words)) collections.OrderedDict(zip(match, [None] * len(match)))
)
filtered.append(" ".join(ordered_words))
filtered_resps.append(filtered) filtered_resps.append(filtered)
return filtered_resps return filtered_resps
class MultiChoiceRegexFilter(ExtendedRegexFilter): class MultiChoiceRegexFilter(ExtendedRegexFilter):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
""" """
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
...@@ -156,13 +183,13 @@ class MultiChoiceRegexFilter(ExtendedRegexFilter): ...@@ -156,13 +183,13 @@ class MultiChoiceRegexFilter(ExtendedRegexFilter):
for r, doc in zip(resps, docs): for r, doc in zip(resps, docs):
fallback_regexes = [] fallback_regexes = []
choice_to_alpha = {} choice_to_alpha = {}
next_alpha = 'A' next_alpha = "A"
without_paren_fallback_regexes = [] without_paren_fallback_regexes = []
without_paren_to_target = {} without_paren_to_target = {}
multiple_choices_regex = re.compile(r"\([A-Z]\)([^\n^(]*)") multiple_choices_regex = re.compile(r"\([A-Z]\)([^\n^(]*)")
match = multiple_choices_regex.findall(doc['input']) match = multiple_choices_regex.findall(doc["input"])
for m in match: for m in match:
m = self.filter_ignores(m.strip()) m = self.filter_ignores(m.strip())
fallback_regexes.append(f"{re.escape(m)}") fallback_regexes.append(f"{re.escape(m)}")
...@@ -172,17 +199,23 @@ class MultiChoiceRegexFilter(ExtendedRegexFilter): ...@@ -172,17 +199,23 @@ class MultiChoiceRegexFilter(ExtendedRegexFilter):
without_paren_to_target[next_alpha] = f"({next_alpha})" without_paren_to_target[next_alpha] = f"({next_alpha})"
next_alpha = chr(ord(next_alpha) + 1) next_alpha = chr(ord(next_alpha) + 1)
fallback_regex = re.compile('|'.join(fallback_regexes)) fallback_regex = re.compile("|".join(fallback_regexes))
without_paren_fallback_regex = '|'.join(without_paren_fallback_regexes) without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
without_paren_fallback_regex = re.compile(f":[\s]*({without_paren_fallback_regex})") without_paren_fallback_regex = re.compile(
f":[\s]*({without_paren_fallback_regex})"
)
filtered = [] filtered = []
for resp in r: for resp in r:
match = self.find_match(self.regex, resp) match = self.find_match(self.regex, resp)
if not match: if not match:
match = self.find_match(fallback_regex, self.filter_ignores(resp), choice_to_alpha) match = self.find_match(
fallback_regex, self.filter_ignores(resp), choice_to_alpha
)
if not match: if not match:
match = self.find_match(without_paren_fallback_regex, resp, without_paren_to_target) match = self.find_match(
without_paren_fallback_regex, resp, without_paren_to_target
)
if not match: if not match:
match = self.fallback match = self.fallback
filtered.append(match) filtered.append(match)
......
import collections import collections
import re import re
import sys import sys
import unicodedata import unicodedata
from lm_eval.filters.extraction import RegexFilter, Filter from lm_eval.filters.extraction import Filter, RegexFilter
class ExtendedRegexFilter(RegexFilter): class ExtendedRegexFilter(RegexFilter):
punct_tbl = dict.fromkeys(i for i in range(sys.maxunicode) punct_tbl = dict.fromkeys(
if unicodedata.category(chr(i)).startswith('P')) i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P")
)
def __init__( def __init__(
self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", group_select=0, fallback: str = "[invalid]", self,
ignore_case=False, ignore_punctuation=False, regexes_to_ignore=None, regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select=0,
fallback: str = "[invalid]",
ignore_case=False,
ignore_punctuation=False,
regexes_to_ignore=None,
) -> None: ) -> None:
super().__init__(regex_pattern, group_select, fallback) super().__init__(regex_pattern, group_select, fallback)
self.ignore_case = ignore_case self.ignore_case = ignore_case
...@@ -47,8 +52,13 @@ class ExtendedRegexFilter(RegexFilter): ...@@ -47,8 +52,13 @@ class ExtendedRegexFilter(RegexFilter):
class MapRegexFilter(ExtendedRegexFilter): class MapRegexFilter(ExtendedRegexFilter):
def __init__( def __init__(
self, regex_pattern_to_value: dict = {}, group_select=0, fallback: str = "[invalid]", self,
ignore_case=False, ignore_punctuation=False, regexes_to_ignore=None, regex_pattern_to_value: dict = {},
group_select=0,
fallback: str = "[invalid]",
ignore_case=False,
ignore_punctuation=False,
regexes_to_ignore=None,
) -> None: ) -> None:
""" """
regex_pattern_to_value: Match the regex pattern and change the result into the value regex_pattern_to_value: Match the regex pattern and change the result into the value
...@@ -57,8 +67,17 @@ class MapRegexFilter(ExtendedRegexFilter): ...@@ -57,8 +67,17 @@ class MapRegexFilter(ExtendedRegexFilter):
ignore_punctuation: Remove the punctuation before matching with the given regex ignore_punctuation: Remove the punctuation before matching with the given regex
regexes_to_ignore: Remove these regexes before matching with the given regex regexes_to_ignore: Remove these regexes before matching with the given regex
""" """
super().__init__('|'.join(list(regex_pattern_to_value.keys())), group_select, fallback, ignore_case, ignore_punctuation, regexes_to_ignore) super().__init__(
self.regex_to_value = {re.compile(r): v for r, v in regex_pattern_to_value.items()} "|".join(list(regex_pattern_to_value.keys())),
group_select,
fallback,
ignore_case,
ignore_punctuation,
regexes_to_ignore,
)
self.regex_to_value = {
re.compile(r): v for r, v in regex_pattern_to_value.items()
}
def apply(self, resps, docs): def apply(self, resps, docs):
filtered_resps = [] filtered_resps = []
...@@ -66,10 +85,15 @@ class MapRegexFilter(ExtendedRegexFilter): ...@@ -66,10 +85,15 @@ class MapRegexFilter(ExtendedRegexFilter):
for r in resps: for r in resps:
filtered = [] filtered = []
for resp in r: for resp in r:
whole_match_considering_group_select = self.find_match(self.regex, self.filter_ignores(resp)) whole_match_considering_group_select = self.find_match(
self.regex, self.filter_ignores(resp)
)
if whole_match_considering_group_select: if whole_match_considering_group_select:
for regex, mapped_value in self.regex_to_value.items(): for regex, mapped_value in self.regex_to_value.items():
match = self.find_match(regex, self.filter_ignores(whole_match_considering_group_select)) match = self.find_match(
regex,
self.filter_ignores(whole_match_considering_group_select),
)
if match: if match:
match = mapped_value match = mapped_value
break break
...@@ -91,9 +115,11 @@ class NumberParseRegexFilter(ExtendedRegexFilter): ...@@ -91,9 +115,11 @@ class NumberParseRegexFilter(ExtendedRegexFilter):
filtered_resps = [] filtered_resps = []
import regex import regex
from word2number import w2n from word2number import w2n
# https://www.reddit.com/r/regex/comments/11a38uk/parsing_numbers_written_out_as_english_words # https://www.reddit.com/r/regex/comments/11a38uk/parsing_numbers_written_out_as_english_words
english_number_regex = regex.compile( english_number_regex = regex.compile(
"((?:(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?:|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion)(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?:|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion|[^\S\r\n]|,|and|&)+)?(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion))") "((?:(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?:|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion)(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?:|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion|[^\S\r\n]|,|and|&)+)?(?:zero|one|two|three|four|five|(?:twen|thir|for|fif|six|seven|nine)(?|teen|ty)|eight(?|een|y)|ten|eleven|twelve|fourteen|hundred|thousand|(?:m|b|tr)illion))"
)
for r in resps: for r in resps:
filtered = [] filtered = []
...@@ -118,21 +144,22 @@ class WordSortFilter(Filter): ...@@ -118,21 +144,22 @@ class WordSortFilter(Filter):
filtered_resps = [] filtered_resps = []
for r, doc in zip(resps, docs): for r, doc in zip(resps, docs):
words = doc['input'].split("List:")[1].strip().split() words = doc["input"].split("List:")[1].strip().split()
regex = re.compile('|'.join([f"\\b{w}\\b" for w in words])) regex = re.compile("|".join([f"\\b{w}\\b" for w in words]))
filtered = [] filtered = []
for resp in r: for resp in r:
match = regex.findall(resp) match = regex.findall(resp)
match.reverse() match.reverse()
ordered_words = reversed(collections.OrderedDict(zip(match, [None] * len(match)))) ordered_words = reversed(
filtered.append(' '.join(ordered_words)) collections.OrderedDict(zip(match, [None] * len(match)))
)
filtered.append(" ".join(ordered_words))
filtered_resps.append(filtered) filtered_resps.append(filtered)
return filtered_resps return filtered_resps
class MultiChoiceRegexFilter(ExtendedRegexFilter): class MultiChoiceRegexFilter(ExtendedRegexFilter):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
""" """
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
...@@ -156,13 +183,13 @@ class MultiChoiceRegexFilter(ExtendedRegexFilter): ...@@ -156,13 +183,13 @@ class MultiChoiceRegexFilter(ExtendedRegexFilter):
for r, doc in zip(resps, docs): for r, doc in zip(resps, docs):
fallback_regexes = [] fallback_regexes = []
choice_to_alpha = {} choice_to_alpha = {}
next_alpha = 'A' next_alpha = "A"
without_paren_fallback_regexes = [] without_paren_fallback_regexes = []
without_paren_to_target = {} without_paren_to_target = {}
multiple_choices_regex = re.compile(r"\([A-Z]\)([^\n^(]*)") multiple_choices_regex = re.compile(r"\([A-Z]\)([^\n^(]*)")
match = multiple_choices_regex.findall(doc['input']) match = multiple_choices_regex.findall(doc["input"])
for m in match: for m in match:
m = self.filter_ignores(m.strip()) m = self.filter_ignores(m.strip())
fallback_regexes.append(f"{re.escape(m)}") fallback_regexes.append(f"{re.escape(m)}")
...@@ -172,17 +199,23 @@ class MultiChoiceRegexFilter(ExtendedRegexFilter): ...@@ -172,17 +199,23 @@ class MultiChoiceRegexFilter(ExtendedRegexFilter):
without_paren_to_target[next_alpha] = f"({next_alpha})" without_paren_to_target[next_alpha] = f"({next_alpha})"
next_alpha = chr(ord(next_alpha) + 1) next_alpha = chr(ord(next_alpha) + 1)
fallback_regex = re.compile('|'.join(fallback_regexes)) fallback_regex = re.compile("|".join(fallback_regexes))
without_paren_fallback_regex = '|'.join(without_paren_fallback_regexes) without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
without_paren_fallback_regex = re.compile(f":[\s]*({without_paren_fallback_regex})") without_paren_fallback_regex = re.compile(
f":[\s]*({without_paren_fallback_regex})"
)
filtered = [] filtered = []
for resp in r: for resp in r:
match = self.find_match(self.regex, resp) match = self.find_match(self.regex, resp)
if not match: if not match:
match = self.find_match(fallback_regex, self.filter_ignores(resp), choice_to_alpha) match = self.find_match(
fallback_regex, self.filter_ignores(resp), choice_to_alpha
)
if not match: if not match:
match = self.find_match(without_paren_fallback_regex, resp, without_paren_to_target) match = self.find_match(
without_paren_fallback_regex, resp, without_paren_to_target
)
if not match: if not match:
match = self.fallback match = self.fallback
filtered.append(match) filtered.append(match)
......
""" """
Take in a YAML, and output all other splits with this YAML Take in a YAML, and output all other splits with this YAML
""" """
import os
import yaml
import argparse import argparse
import requests import os
import requests
import yaml
from tqdm import tqdm from tqdm import tqdm
from lm_eval.utils import logging from lm_eval.utils import logging
API_URL = "https://datasets-server.huggingface.co/splits?dataset=facebook/belebele" API_URL = "https://datasets-server.huggingface.co/splits?dataset=facebook/belebele"
...@@ -39,6 +40,7 @@ if __name__ == "__main__": ...@@ -39,6 +40,7 @@ if __name__ == "__main__":
def query(): def query():
response = requests.get(API_URL) response = requests.get(API_URL)
return response.json()["splits"] return response.json()["splits"]
print(query()) print(query())
languages = [split["split"] for split in query()] languages = [split["split"] for split in query()]
...@@ -49,7 +51,7 @@ if __name__ == "__main__": ...@@ -49,7 +51,7 @@ if __name__ == "__main__":
if args.task_prefix != "" if args.task_prefix != ""
else f"belebele_{lang}", else f"belebele_{lang}",
"test_split": lang, "test_split": lang,
"fewshot_split":lang, "fewshot_split": lang,
} }
file_save_path = args.save_prefix_path + f"_{lang}.yaml" file_save_path = args.save_prefix_path + f"_{lang}.yaml"
......
import os import os
import yaml import yaml
all_subtasks = [ all_subtasks = [
"abstract_narrative_understanding", "abstract_narrative_understanding",
"anachronisms", "anachronisms",
......
...@@ -8,10 +8,9 @@ Requires the installation of ...@@ -8,10 +8,9 @@ Requires the installation of
`pip install "bigbench @ https://storage.googleapis.com/public_research_data/bigbench/bigbench-0.0.1.tar.gz"` `pip install "bigbench @ https://storage.googleapis.com/public_research_data/bigbench/bigbench-0.0.1.tar.gz"`
and is included so that the bigbench dependency can be avoided. and is included so that the bigbench dependency can be avoided.
""" """
from tqdm import tqdm
import datasets
import bigbench.api.util as bb_utils import bigbench.api.util as bb_utils
import datasets
from tqdm import tqdm
all_task_names = bb_utils.get_all_json_task_names() all_task_names = bb_utils.get_all_json_task_names()
......
import yaml import yaml
all_subtasks = [ all_subtasks = [
"adjunct_island", "adjunct_island",
"anaphor_gender_agreement", "anaphor_gender_agreement",
......
""" """
Take in a YAML, and output all other splits with this YAML Take in a YAML, and output all other splits with this YAML
""" """
import os
import yaml
import argparse import argparse
import os
import yaml
from tqdm import tqdm from tqdm import tqdm
from lm_eval.logger import eval_logger from lm_eval.logger import eval_logger
SUBJECTS = { SUBJECTS = {
"computer_network": "计算机网络", "computer_network": "计算机网络",
"operating_system": "操作系统", "operating_system": "操作系统",
......
""" """
Take in a YAML, and output all other splits with this YAML Take in a YAML, and output all other splits with this YAML
""" """
import os
import yaml
import argparse import argparse
import os
import yaml
from tqdm import tqdm from tqdm import tqdm
from lm_eval.logger import eval_logger from lm_eval.logger import eval_logger
SUBJECTS = { SUBJECTS = {
"agronomy": "农学", "agronomy": "农学",
"anatomy": "解剖学", "anatomy": "解剖学",
......
#!/usr/bin/python #!/usr/bin/python
import math
import re import re
import sys import sys
import math
import xml.sax.saxutils import xml.sax.saxutils
from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
from typing import List, Pattern, Tuple, Union, Dict, Any, Optional
""" """
This script was adapted from the original version by hieuhoang1972 which is part of MOSES. This script was adapted from the original version by hieuhoang1972 which is part of MOSES.
...@@ -60,7 +60,7 @@ def normalize(s): ...@@ -60,7 +60,7 @@ def normalize(s):
# Added to bypass NIST-style pre-processing of hyp and ref files -- wade # Added to bypass NIST-style pre-processing of hyp and ref files -- wade
if nonorm: if nonorm:
return s.split() return s.split()
if type(s) is not str: if not isinstance(s, str):
s = " ".join(s) s = " ".join(s)
# language-independent part: # language-independent part:
for pattern, replace in normalize1: for pattern, replace in normalize1:
......
""" """
Take in a YAML, and output all other splits with this YAML Take in a YAML, and output all other splits with this YAML
""" """
import os
import yaml
import argparse import argparse
import os
import yaml
from tqdm import tqdm from tqdm import tqdm
from lm_eval.logger import eval_logger from lm_eval.logger import eval_logger
SUBSETS = ["WR", "GR", "RCS", "RCSS", "RCH", "LI"] SUBSETS = ["WR", "GR", "RCS", "RCSS", "RCH", "LI"]
......
...@@ -4,6 +4,7 @@ import string ...@@ -4,6 +4,7 @@ import string
import numpy as np import numpy as np
from scipy.optimize import linear_sum_assignment from scipy.optimize import linear_sum_assignment
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE) _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
......
import yaml import yaml
from tqdm import tqdm from tqdm import tqdm
...@@ -22,5 +21,6 @@ def main() -> None: ...@@ -22,5 +21,6 @@ def main() -> None:
except FileExistsError: except FileExistsError:
pass pass
if __name__ == "__main__": if __name__ == "__main__":
main() main()
import datasets
import re
import random import random
import re
import datasets
def preprocess(text): def preprocess(text):
if text is None: if text is None:
...@@ -11,8 +13,10 @@ def preprocess(text): ...@@ -11,8 +13,10 @@ def preprocess(text):
text = text.replace(" ", " ") text = text.replace(" ", " ")
return text return text
rng = random.Random(42) rng = random.Random(42)
def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
def _process_doc(doc): def _process_doc(doc):
choices = [ choices = [
...@@ -30,7 +34,7 @@ def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: ...@@ -30,7 +34,7 @@ def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
"choice2": choices[1], "choice2": choices[1],
"choice3": choices[2], "choice3": choices[2],
"choice4": choices[3], "choice4": choices[3],
"answer": f"({chr(65 + correct_answer_index)})" "answer": f"({chr(65 + correct_answer_index)})",
} }
return out_doc return out_doc
......
import yaml import yaml
from tqdm import tqdm from tqdm import tqdm
...@@ -22,5 +21,6 @@ def main() -> None: ...@@ -22,5 +21,6 @@ def main() -> None:
except FileExistsError: except FileExistsError:
pass pass
if __name__ == "__main__": if __name__ == "__main__":
main() main()
import datasets
import re
import random import random
import re
import datasets
def preprocess(text): def preprocess(text):
if text is None: if text is None:
...@@ -29,7 +31,7 @@ def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: ...@@ -29,7 +31,7 @@ def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
"choice2": choices[1], "choice2": choices[1],
"choice3": choices[2], "choice3": choices[2],
"choice4": choices[3], "choice4": choices[3],
"answer": f"({chr(65 + correct_answer_index)})" "answer": f"({chr(65 + correct_answer_index)})",
} }
return out_doc return out_doc
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment