"...gpu/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "2ee59b2bfe607e275ae0df1a9eb536f63aefb5e7"
Commit 458342e2 authored by lintangsutawika's avatar lintangsutawika
Browse files

format

parent b8122d98
...@@ -11,7 +11,9 @@ doc_to_target: answer ...@@ -11,7 +11,9 @@ doc_to_target: answer
filter_list: filter_list:
- name: "custom-extract" - name: "custom-extract"
filter: filter:
- function: !function utils.CustomRegexFilter - function: "regex"
regex_pattern: r"answer is \(?([ABCDEFGHIJ])\)?"
# regex_pattern: r".*[aA]nswer:\s*([A-J])",
- function: "take_first" - function: "take_first"
generation_kwargs: generation_kwargs:
until: until:
......
import re
from functools import partial from functools import partial
from lm_eval.api.filter import Filter
choices = [ choices = [
"A", "A",
...@@ -64,43 +61,3 @@ process_other = partial(process_docs, subject="other") ...@@ -64,43 +61,3 @@ process_other = partial(process_docs, subject="other")
process_philosophy = partial(process_docs, subject="philosophy") process_philosophy = partial(process_docs, subject="philosophy")
process_physics = partial(process_docs, subject="physics") process_physics = partial(process_docs, subject="physics")
process_psychology = partial(process_docs, subject="psychology") process_psychology = partial(process_docs, subject="psychology")
class CustomRegexFilter(Filter):
""" """
def __init__(
self,
regex_pattern: list = [
r"answer is \(?([ABCDEFGHIJ])\)?",
r".*[aA]nswer:\s*([A-J])",
],
group_select=0,
fallback: str = "[invalid]",
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
self.regex_pattern = regex_pattern
self.regex = [re.compile(pattern) for pattern in regex_pattern]
self.group_select = group_select
self.fallback = fallback
def apply(self, resps, docs):
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
# independently (and keep them a list.)
filtered_resps = []
for resp in resps:
for pattern in self.regex:
match = pattern.search(resp)
if match:
filtered_resps.append(match.group(1))
break
if len(filtered_resps) == 0:
filtered_resps = [None]
return filtered_resps
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment