gaoqiong / lm-evaluation-harness

Commit 8ebe36b2, authored Dec 21, 2021 by Jonathan Tow
Parent: d34ae3cf

Add positional arg deprecation decorator

Showing 8 changed files with 85 additions and 9 deletions.
Changed files:

  lm_eval/base.py               +1   -0
  lm_eval/evaluator.py          +11  -1
  lm_eval/tasks/prost.py        +7   -1
  lm_eval/tasks/truthfulqa.py   +13  -2
  lm_eval/utils.py              +16  -1
  scripts/cost_estimate.py      +9   -1
  tests/test_evaluator.py       +18  -2
  tests/test_version_stable.py  +10  -1
lm_eval/base.py

```diff
@@ -457,6 +457,7 @@ class Task(abc.ABC):
             DeprecationWarning
         )
         return ""
 
+    @utils.positional_deprecated
     def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None):
         assert not provide_description, (
             "The `provide_description` arg will be removed in future versions. To prepend "
```
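Because the decorator now wraps `fewshot_context`, callers are nudged toward the keyword form. A hedged sketch of that call style, assuming a constructed `Task` instance; the task choice and document selection are illustrative, not from the commit (`tasks.get_task` is grounded in scripts/cost_estimate.py below, while `validation_docs()` is assumed):

```python
import random
from lm_eval import tasks

task = tasks.get_task("lambada")()        # illustrative task choice
doc = next(iter(task.validation_docs()))  # assumed accessor for one document
ctx = task.fewshot_context(
    doc=doc,
    num_fewshot=0,
    provide_description=False,  # must stay falsy: the assert above rejects it
    rnd=random.Random(42),
    description=None,
)
```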
lm_eval/evaluator.py

```diff
@@ -6,8 +6,10 @@ import lm_eval.models
 import lm_eval.tasks
 import lm_eval.base
 import numpy as np
+from lm_eval.utils import positional_deprecated
 
 
+@positional_deprecated
 def simple_evaluate(model, model_args, task_names,
                     num_fewshot=0, batch_size=None, device=None,
                     no_cache=False, limit=None, bootstrap_iters=100000,
@@ -51,7 +53,14 @@ def simple_evaluate(model, model_args, task_names,
     task_dict = lm_eval.tasks.get_task_dict(task_names)
 
-    results = evaluate(lm, task_dict, False, num_fewshot, limit, description_dict=description_dict)
+    results = evaluate(
+        lm=lm,
+        task_dict=task_dict,
+        provide_description=False,
+        num_fewshot=num_fewshot,
+        limit=limit,
+        description_dict=description_dict
+    )
 
     # add info about the model and few shot config
     results["config"] = {
@@ -69,6 +78,7 @@ def simple_evaluate(model, model_args, task_names,
     return results
 
 
+@positional_deprecated
 def evaluate(lm, task_dict, provide_description, num_fewshot, limit, bootstrap_iters=100000, description_dict=None):
     """Instantiate and evaluate a model on a list of tasks.
```
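Both entry points now tolerate positional calls but steer users toward keywords. A minimal sketch of the intended call style, assuming the signature shown above; the model name, task name, and `limit` value are illustrative placeholders, not taken from the commit:

```python
from lm_eval import evaluator

# Keyword-only call: the decorator's *args tuple stays empty, so no warning.
results = evaluator.simple_evaluate(
    model="gpt2",            # illustrative model backend name
    model_args="",
    task_names=["lambada"],  # illustrative task
    num_fewshot=0,
    limit=10,                # small cap to keep the run cheap
)

# Positional call: still works, but prints the deprecation warning first.
results = evaluator.simple_evaluate("gpt2", "", ["lambada"])
```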
lm_eval/tasks/prost.py

```diff
@@ -38,7 +38,13 @@ class PROST(HFTask, MultipleChoiceTask):
     def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None):
         assert num_fewshot == 0, 'PROST is designed to probe models in a zero-shot fashion only.'
-        return super().fewshot_context(doc, num_fewshot, provide_description, rnd, description)
+        return super().fewshot_context(
+            doc=doc,
+            num_fewshot=num_fewshot,
+            provide_description=provide_description,
+            rnd=rnd,
+            description=description
+        )
 
     def _convert_standard(self, doc):
         out_doc = {
```
lm_eval/tasks/truthfulqa.py

```diff
@@ -87,7 +87,13 @@ class TruthfulQAMultipleChoice(Task):
     def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None):
         assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting."
-        return super().fewshot_context(doc, num_fewshot, provide_description, rnd, description)
+        return super().fewshot_context(
+            doc=doc,
+            num_fewshot=num_fewshot,
+            provide_description=provide_description,
+            rnd=rnd,
+            description=description
+        )
 
     def construct_requests(self, doc, ctx):
         """ Uses RequestFactory to construct Requests and returns an iterable of
@@ -219,7 +225,12 @@ class TruthfulQAGeneration(Task):
     def fewshot_context(self, doc, num_fewshot, provide_description, rnd, description=None):
         assert num_fewshot == 0, "TruthfulQA is intended only for the zero-shot setting."
-        return super().fewshot_context(doc, num_fewshot, provide_description, rnd, description)
+        return super().fewshot_context(
+            doc=doc,
+            num_fewshot=num_fewshot,
+            provide_description=provide_description,
+            rnd=rnd, description=description
+        )
 
     def construct_requests(self, doc, ctx):
         """ Uses RequestFactory to construct Requests and returns an iterable of
```
lm_eval/utils.py

```diff
 import os
 import re
 import collections
+import functools
 
 
 class ExitCodeError(Exception):
@@ -138,4 +139,18 @@ class Reorderer:
         assert all(cov)
 
-        return res
\ No newline at end of file
+        return res
+
+
+def positional_deprecated(fn):
+    """
+    A decorator to nudge users into passing only keyword args (`kwargs`) to the
+    wrapped function, `fn`.
+    """
+    @functools.wraps(fn)
+    def _wrapper(*args, **kwargs):
+        if len(args) != 0:
+            print(f"WARNING: using {fn.__name__} with positional arguments is "
+                  "deprecated and will be disallowed in a future version of "
+                  "lm-evaluation-harness!")
+        return fn(*args, **kwargs)
+    return _wrapper
```
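The decorator only inspects `*args`: any non-empty tuple triggers the warning, and the call is then forwarded unchanged either way. A standalone sketch of the behavior, reproducing the decorator above and using a hypothetical function `scale` purely for illustration:

```python
import functools

def positional_deprecated(fn):
    # Reproduction of the decorator added above, so this demo runs standalone.
    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        if len(args) != 0:
            print(f"WARNING: using {fn.__name__} with positional arguments is "
                  "deprecated and will be disallowed in a future version of "
                  "lm-evaluation-harness!")
        return fn(*args, **kwargs)
    return _wrapper

@positional_deprecated
def scale(value, factor=2):
    # Hypothetical function, used only to exercise the decorator.
    return value * factor

print(scale(value=3, factor=4))  # args == (): silent, prints 12
print(scale(3, 4))               # args == (3, 4): warns first, then prints 12
```

One caveat of the `len(args) != 0` check: when the decorated function is a method (as with `fewshot_context` in base.py), the bound `self` still lands in `args`, so those calls warn even when every explicit argument is passed by keyword.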
scripts/cost_estimate.py

```diff
@@ -51,7 +51,15 @@ def main():
     values = []
     for taskname in task_list.split(","):
         lm.tokencost = 0
-        evaluator.evaluate(lm, {taskname: tasks.get_task(taskname)()}, False, 0, None, bootstrap_iters=10)
+        evaluator.evaluate(
+            lm=lm,
+            task_dict={taskname: tasks.get_task(taskname)()},
+            provide_description=False,
+            num_fewshot=0,
+            limit=None,
+            bootstrap_iters=10,
+            description_dict=None
+        )
         print(taskname, lm.tokencost)
         values.append([taskname, lm.tokencost, lm.tokencost / 1000 * 0.0008, lm.tokencost / 1000 * 0.0012,
                        lm.tokencost / 1000 * 0.006, lm.tokencost / 1000 * 0.06])
```
tests/test_evaluator.py

```diff
@@ -48,8 +48,24 @@ def test_evaluator(taskname, task_class):
     lm.loglikelihood_rolling = ll_perp_fn
     limit = 10
 
-    e1 = evaluator.evaluate(lm, task_dict, False, 0, limit, bootstrap_iters=10, description_dict=None)
-    e2 = evaluator.evaluate(lm, task_dict, False, 0, limit, bootstrap_iters=10, description_dict=None)
+    e1 = evaluator.evaluate(
+        lm=lm,
+        task_dict=task_dict,
+        provide_description=False,
+        num_fewshot=0,
+        limit=limit,
+        bootstrap_iters=10,
+        description_dict=None
+    )
+    e2 = evaluator.evaluate(
+        lm=lm,
+        task_dict=task_dict,
+        provide_description=False,
+        num_fewshot=0,
+        limit=limit,
+        bootstrap_iters=10,
+        description_dict=None
+    )
 
     # check that caching is working
     assert e1 == e2
```
tests/test_version_stable.py

```diff
@@ -99,5 +99,14 @@ def test_versions_stable(taskname, task_class):
     lm.greedy_until = greedy_until
     limit = None
-    result = evaluator.evaluate(lm, task_dict, False, 0, limit, bootstrap_iters=10)
+    result = evaluator.evaluate(
+        lm=lm,
+        task_dict=task_dict,
+        provide_description=False,
+        num_fewshot=0,
+        limit=limit,
+        bootstrap_iters=10,
+        description_dict=None
+    )
     assert_target(f"{taskname}-v{task_class.VERSION}-res", result)
```