Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
1554066c
Unverified
Commit
1554066c
authored
Jan 30, 2024
by
Baber Abbasi
Committed by
GitHub
Jan 30, 2024
Browse files
delay filter init; remove `*args` (#1369)
* delay filter init; remove `*args` * bugfix * optimize * type hint
parent
7fc43656
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
27 additions
and
15 deletions
+27
-15
lm_eval/api/filter.py
lm_eval/api/filter.py
+5
-5
lm_eval/api/instance.py
lm_eval/api/instance.py
+6
-1
lm_eval/api/task.py
lm_eval/api/task.py
+6
-1
lm_eval/filters/__init__.py
lm_eval/filters/__init__.py
+6
-6
lm_eval/filters/selection.py
lm_eval/filters/selection.py
+4
-2
No files found.
lm_eval/api/filter.py
View file @
1554066c
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
List
from
typing
import
Callable
,
Iterable
,
List
,
Union
from
lm_eval.api.instance
import
Instance
...
...
@@ -14,13 +14,13 @@ class Filter(ABC):
"""
def
__init__
(
self
,
*
args
,
**
kwargs
)
->
None
:
def
__init__
(
self
,
**
kwargs
)
->
None
:
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
@
abstractmethod
def
apply
(
self
,
resps
,
docs
)
:
def
apply
(
self
,
resps
:
Union
[
List
,
Iterable
],
docs
:
List
[
dict
])
->
Iterable
:
"""
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
...
...
@@ -40,7 +40,7 @@ class FilterEnsemble:
"""
name
:
str
filters
:
List
[
Filter
]
filters
:
List
[
Callable
[[],
Filter
]
]
def
apply
(
self
,
instances
:
List
[
Instance
])
->
None
:
resps
,
docs
=
zip
(
*
((
inst
.
resps
,
inst
.
doc
)
for
inst
in
instances
))
...
...
@@ -48,7 +48,7 @@ class FilterEnsemble:
for
f
in
self
.
filters
:
# apply filters in sequence
resps
=
f
.
apply
(
resps
,
docs
)
resps
=
f
()
.
apply
(
resps
,
docs
)
# add the end results after filtering to filtered_requests of their respective source instances.
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
...
...
lm_eval/api/instance.py
View file @
1554066c
...
...
@@ -4,7 +4,12 @@ from typing import Literal, Tuple
@
dataclass
class
Instance
:
request_type
:
Literal
[
"loglikelihood"
,
"loglikelihood_rolling"
,
"generate_until"
]
request_type
:
Literal
[
"loglikelihood"
,
"loglikelihood_rolling"
,
"generate_until"
,
"multiple_choice"
,
]
doc
:
dict
arguments
:
tuple
idx
:
int
...
...
lm_eval/api/task.py
View file @
1554066c
...
...
@@ -74,7 +74,12 @@ class TaskConfig(dict):
num_fewshot
:
int
=
None
# scoring options
metric_list
:
list
=
None
output_type
:
str
=
"generate_until"
output_type
:
Literal
[
"loglikelihood"
,
"loglikelihood_rolling"
,
"generate_until"
,
"multiple_choice"
,
]
=
"generate_until"
generation_kwargs
:
dict
=
None
repeats
:
int
=
1
filter_list
:
Union
[
str
,
list
]
=
None
...
...
lm_eval/filters/__init__.py
View file @
1554066c
from
typing
import
List
from
typing
import
List
,
Union
from
functools
import
partial
from
lm_eval.api.filter
import
FilterEnsemble
from
.
import
selection
...
...
@@ -22,7 +23,7 @@ FILTER_REGISTRY = {
}
def
get_filter
(
filter_name
)
:
def
get_filter
(
filter_name
:
str
)
->
Union
[
type
,
str
]
:
if
filter_name
in
FILTER_REGISTRY
:
return
FILTER_REGISTRY
[
filter_name
]
else
:
...
...
@@ -38,10 +39,9 @@ def build_filter_ensemble(
filters
=
[]
for
function
,
kwargs
in
components
:
if
kwargs
is
None
:
f
=
get_filter
(
function
)()
else
:
# create a filter given its name in the registry
f
=
get_filter
(
function
)(
**
kwargs
)
# TODO: pass kwargs to filters properly
kwargs
=
{}
# create a filter given its name in the registry
f
=
partial
(
get_filter
(
function
),
**
kwargs
)
# add the filter as a pipeline step
filters
.
append
(
f
)
...
...
lm_eval/filters/selection.py
View file @
1554066c
...
...
@@ -17,12 +17,14 @@ class TakeFirstFilter(Filter):
class
TakeKFilter
(
Filter
):
def
__init__
(
self
,
*
args
,
**
kwargs
)
->
None
:
def
__init__
(
self
,
**
kwargs
)
->
None
:
self
.
k
=
kwargs
.
pop
(
"k"
)
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
apply
(
self
,
resps
,
docs
):
# need resp to be subscriptable to check below
resps
=
list
(
resps
)
# check we have at least k responses per doc, else we can't take the first k
assert
(
len
(
resps
[
0
])
>=
self
.
k
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment