Commit c3dabb32 authored by lintangsutawika's avatar lintangsutawika
Browse files

updated filter process

parent 4a3c1f19
......@@ -56,7 +56,7 @@ class TaskConfig(dict):
gold_alias: str = None
output_type: str = "greedy_until"
delimiter: str = "\n\n"
filters: Union[str, list] = None
filter_list: Union[str, list] = None
normalization: str = None # TODO: add length-normalization of various types, mutual info
should_decontaminate: bool = False
doc_to_decontamination_query: str = None
......@@ -428,7 +428,11 @@ class ConfigurableTask(Task):
CONFIG = None
def __init__(
self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None
self,
data_dir=None,
cache_dir=None,
download_mode=None,
config: dict=None
):
# Get pre-configured attributes
self._config = self.CONFIG
......@@ -489,11 +493,27 @@ class ConfigurableTask(Task):
self._filters = []
for name, components in self._config.get("filters", [["none", ["take_first"]]]):
filter_pipeline = build_filter_ensemble(name, components)
if self._config.filter_list != None:
for filter_config in self._config.filter_list:
for filter_pipeline in filter_config:
filter_name = filter_config["name"]
filter_functions = filter_config["filter"]
components = []
for function in filter_functions:
kwargs = {
key: function[key] for key in function if key != "function"
}
components.append([
function['function'],
kwargs
])
filter_pipeline = build_filter_ensemble(filter_name, components)
self._filters.append(filter_pipeline)
self.sampler = samplers.Sampler(list(self.fewshot_docs()), self, rnd=random.Random()) # TODO: pass the correct docs in here
if self.fewshot_docs() != None:
self.sampler = samplers.Sampler(list(self.fewshot_docs()), self, rnd=random.Random()) # TODO: pass the correct docs in here
def has_training_docs(self):
if self._config.training_split is not None:
......@@ -526,7 +546,7 @@ class ConfigurableTask(Task):
return self.dataset[self._config.test_split]
def fewshot_docs(self):
if (self.num_fewshot > 0) and (self._config.fewshot_split == None):
if (self._config.num_fewshot > 0) and (self._config.fewshot_split == None):
eval_logger.warning(
"num_fewshot > 0 but fewshot_split is None",
"using preconfigured rule."
......
......@@ -17,16 +17,16 @@ def get_filter(filter_name):
return FILTER_REGISTRY[filter_name]
def build_filter_ensemble(name, components):
def build_filter_ensemble(filter_name, components):
"""
Create a filtering pipeline.
"""
filters = []
for step in components:
# create a filter given its name in the registry
f = get_filter(step)() # TODO: pass kwargs to filters properly
for (function, kwargs) in components:
# create a filter given its name in the registry
f = get_filter(function)(**kwargs) # TODO: pass kwargs to filters properly
# add the filter as a pipeline step
filters.append(f)
# add the filter as a pipeline step
filters.append(f)
return FilterEnsemble(name=name, filters=filters)
return FilterEnsemble(name=filter_name, filters=filters)
......@@ -9,14 +9,13 @@ class RegexFilter(Filter):
"""
def __init__(self, regex=r"#### (\-?[0-9\.\,]+)", fallback="[invalid]"):
def __init__(self, regex_pattern=r"#### (\-?[0-9\.\,]+)", fallback="[invalid]"):
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
self.regex_pattern = regex
self.regex = re.compile(regex)
self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern)
self.fallback = fallback
def apply(self, resps):
......@@ -30,7 +29,7 @@ class RegexFilter(Filter):
match = self.regex.search(resp)
if match:
match = match.group(1).strip()
match_str.replace(",", "")
match.replace(",", "")
# TODO: should we assume any other filtering is performed?
else:
match = self.fallback
......
......@@ -40,6 +40,16 @@ metric_list:
ignore_case: true
ignore_punctuation: true
delimiter: "\n"
# filters: [
# ["regex", ["regex", "take_first"]]
# ]
\ No newline at end of file
filter_list:
- name: "just regex"
filter:
- function: "regex"
regex_pattern: ".*"
- function: "regex"
regex_pattern: ".*"
- name: "another regex"
filter:
- function: "regex"
regex_pattern: ".*"
- function: "regex"
regex_pattern: ".*"
\ No newline at end of file
import os
import yaml
import json
import fnmatch
import warnings
import argparse
from pprint import pformat
from lm_eval import evaluator, utils
from lm_eval.tasks import ALL_TASKS
from lm_eval.logger import eval_logger
......@@ -22,7 +18,7 @@ class MultiChoice:
for value in values.split(","):
if len(fnmatch.filter(self.choices, value)) == 0:
eval_logger.warning("{} is not in task list.".format(value))
# eval_logger.info(f"{ALL_TASKS} is this")
# eval_logger.info(f"{choices} is this")
return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment