"llama/sampling.cpp" did not exist on "de982616f1dde636e46b2cef2edd971b54ef7691"
Commit c3dabb32 authored by lintangsutawika
Browse files

updated filter process

parent 4a3c1f19
...@@ -56,7 +56,7 @@ class TaskConfig(dict): ...@@ -56,7 +56,7 @@ class TaskConfig(dict):
gold_alias: str = None gold_alias: str = None
output_type: str = "greedy_until" output_type: str = "greedy_until"
delimiter: str = "\n\n" delimiter: str = "\n\n"
filters: Union[str, list] = None filter_list: Union[str, list] = None
normalization: str = None # TODO: add length-normalization of various types, mutual info normalization: str = None # TODO: add length-normalization of various types, mutual info
should_decontaminate: bool = False should_decontaminate: bool = False
doc_to_decontamination_query: str = None doc_to_decontamination_query: str = None
...@@ -428,7 +428,11 @@ class ConfigurableTask(Task): ...@@ -428,7 +428,11 @@ class ConfigurableTask(Task):
CONFIG = None CONFIG = None
def __init__( def __init__(
self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None self,
data_dir=None,
cache_dir=None,
download_mode=None,
config: dict=None
): ):
# Get pre-configured attributes # Get pre-configured attributes
self._config = self.CONFIG self._config = self.CONFIG
...@@ -489,10 +493,26 @@ class ConfigurableTask(Task): ...@@ -489,10 +493,26 @@ class ConfigurableTask(Task):
self._filters = [] self._filters = []
for name, components in self._config.get("filters", [["none", ["take_first"]]]):
filter_pipeline = build_filter_ensemble(name, components) if self._config.filter_list != None:
for filter_config in self._config.filter_list:
for filter_pipeline in filter_config:
filter_name = filter_config["name"]
filter_functions = filter_config["filter"]
components = []
for function in filter_functions:
kwargs = {
key: function[key] for key in function if key != "function"
}
components.append([
function['function'],
kwargs
])
filter_pipeline = build_filter_ensemble(filter_name, components)
self._filters.append(filter_pipeline) self._filters.append(filter_pipeline)
if self.fewshot_docs() != None:
self.sampler = samplers.Sampler(list(self.fewshot_docs()), self, rnd=random.Random()) # TODO: pass the correct docs in here self.sampler = samplers.Sampler(list(self.fewshot_docs()), self, rnd=random.Random()) # TODO: pass the correct docs in here
def has_training_docs(self): def has_training_docs(self):
...@@ -526,7 +546,7 @@ class ConfigurableTask(Task): ...@@ -526,7 +546,7 @@ class ConfigurableTask(Task):
return self.dataset[self._config.test_split] return self.dataset[self._config.test_split]
def fewshot_docs(self): def fewshot_docs(self):
if (self.num_fewshot > 0) and (self._config.fewshot_split == None): if (self._config.num_fewshot > 0) and (self._config.fewshot_split == None):
eval_logger.warning( eval_logger.warning(
"num_fewshot > 0 but fewshot_split is None", "num_fewshot > 0 but fewshot_split is None",
"using preconfigured rule." "using preconfigured rule."
......
...@@ -17,16 +17,16 @@ def get_filter(filter_name): ...@@ -17,16 +17,16 @@ def get_filter(filter_name):
return FILTER_REGISTRY[filter_name] return FILTER_REGISTRY[filter_name]
def build_filter_ensemble(name, components): def build_filter_ensemble(filter_name, components):
""" """
Create a filtering pipeline. Create a filtering pipeline.
""" """
filters = [] filters = []
for step in components: for (function, kwargs) in components:
# create a filter given its name in the registry # create a filter given its name in the registry
f = get_filter(step)() # TODO: pass kwargs to filters properly f = get_filter(function)(**kwargs) # TODO: pass kwargs to filters properly
# add the filter as a pipeline step # add the filter as a pipeline step
filters.append(f) filters.append(f)
return FilterEnsemble(name=name, filters=filters) return FilterEnsemble(name=filter_name, filters=filters)
...@@ -9,14 +9,13 @@ class RegexFilter(Filter): ...@@ -9,14 +9,13 @@ class RegexFilter(Filter):
""" """
def __init__(self, regex=r"#### (\-?[0-9\.\,]+)", fallback="[invalid]"): def __init__(self, regex_pattern=r"#### (\-?[0-9\.\,]+)", fallback="[invalid]"):
""" """
pass a string `regex` to run `re.compile(r"regex")` on. pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located. `fallback` defines the output returned if no matches for the regex are located.
""" """
self.regex_pattern = regex self.regex_pattern = regex_pattern
self.regex = re.compile(regex) self.regex = re.compile(regex_pattern)
self.fallback = fallback self.fallback = fallback
def apply(self, resps): def apply(self, resps):
...@@ -30,7 +29,7 @@ class RegexFilter(Filter): ...@@ -30,7 +29,7 @@ class RegexFilter(Filter):
match = self.regex.search(resp) match = self.regex.search(resp)
if match: if match:
match = match.group(1).strip() match = match.group(1).strip()
match_str.replace(",", "") match.replace(",", "")
# TODO: should we assume any other filtering is performed? # TODO: should we assume any other filtering is performed?
else: else:
match = self.fallback match = self.fallback
......
...@@ -40,6 +40,16 @@ metric_list: ...@@ -40,6 +40,16 @@ metric_list:
ignore_case: true ignore_case: true
ignore_punctuation: true ignore_punctuation: true
delimiter: "\n" delimiter: "\n"
# filters: [ filter_list:
# ["regex", ["regex", "take_first"]] - name: "just regex"
# ] filter:
\ No newline at end of file - function: "regex"
regex_pattern: ".*"
- function: "regex"
regex_pattern: ".*"
- name: "another regex"
filter:
- function: "regex"
regex_pattern: ".*"
- function: "regex"
regex_pattern: ".*"
\ No newline at end of file
import os import os
import yaml
import json import json
import fnmatch import fnmatch
import warnings
import argparse import argparse
from pprint import pformat
from lm_eval import evaluator, utils from lm_eval import evaluator, utils
from lm_eval.tasks import ALL_TASKS from lm_eval.tasks import ALL_TASKS
from lm_eval.logger import eval_logger from lm_eval.logger import eval_logger
...@@ -22,7 +18,7 @@ class MultiChoice: ...@@ -22,7 +18,7 @@ class MultiChoice:
for value in values.split(","): for value in values.split(","):
if len(fnmatch.filter(self.choices, value)) == 0: if len(fnmatch.filter(self.choices, value)) == 0:
eval_logger.warning("{} is not in task list.".format(value)) eval_logger.warning("{} is not in task list.".format(value))
# eval_logger.info(f"{ALL_TASKS} is this") # eval_logger.info(f"{choices} is this")
return True return True
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment