gaoqiong / lm-evaluation-harness · Commit 7b649ded

Fixes to make greedy_until work

Authored Feb 10, 2021 by Leo Gao
Parent: eb4c8407

Showing 7 changed files with 44 additions and 8 deletions:
- lm_eval/base.py (+5 −0)
- lm_eval/evaluator.py (+1 −0)
- lm_eval/models/dummy.py (+6 −2)
- lm_eval/models/gpt2.py (+26 −2)
- lm_eval/tasks/arithmetic.py (+1 −1)
- lm_eval/tasks/squad.py (+4 −2)
- tests/test_evaluator.py (+1 −1)
lm_eval/base.py

```diff
@@ -269,6 +269,7 @@ def perplexity(items):
 req_ret_lens = {
     'loglikelihood': 2,
+    'greedy_until': None,
 }
 
 
 import os
@@ -335,11 +336,15 @@ class Request:
         self.index = index
 
     def __iter__(self):
+        if req_ret_lens[self.type] is None:
+            raise IndexError('This request type does not return multiple arguments!')
         i = 0
         for i in range(req_ret_lens[self.type]):
             yield Request(self.type, self.args, i)
 
     def __getitem__(self, i):
+        if req_ret_lens[self.type] is None:
+            raise IndexError('This request type does not return multiple arguments!')
         return Request(self.type, self.args, i)
 
     def __eq__(self, other):
```
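Here req_ret_lens records how many values each request type returns. greedy_until yields a single generated string, so it is marked None, and the new guards in __iter__ and __getitem__ raise a clear IndexError instead of letting range(None) blow up with a TypeError. A minimal standalone sketch of that behavior (the real Request class in lm_eval/base.py carries more state and methods):

```python
# Standalone sketch of the multi-return bookkeeping in lm_eval/base.py.
req_ret_lens = {
    'loglikelihood': 2,    # returns two values per request
    'greedy_until': None,  # returns a single generated string
}

class Request:
    def __init__(self, type, args, index=None):
        self.type, self.args, self.index = type, args, index

    def __iter__(self):
        if req_ret_lens[self.type] is None:
            raise IndexError('This request type does not return multiple arguments!')
        for i in range(req_ret_lens[self.type]):
            yield Request(self.type, self.args, i)

ll, is_greedy = Request('loglikelihood', ('ctx', 'cont'))  # unpacks into two sub-requests
try:
    list(Request('greedy_until', ('ctx', ['\n'])))
except IndexError as e:
    print(e)  # the guard fires instead of range(None) raising a TypeError
```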
lm_eval/evaluator.py

```diff
@@ -39,6 +39,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
             )
 
             reqs = task.construct_requests(doc, ctx)
+            if not isinstance(reqs, (list, tuple)): reqs = [reqs]
             for i, req in enumerate(reqs):
                 requests[req.type].append(req)
 
```
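With the added normalization, a task's construct_requests may return either a single Request or a list/tuple of them. A minimal sketch of the collection loop, with Request pared down to a stand-in for the class in lm_eval/base.py and the two collect calls standing in for hypothetical tasks:

```python
import collections

# Pared-down stand-in for lm_eval.base.Request, just enough for the sketch.
class Request:
    def __init__(self, type, args):
        self.type, self.args = type, args

requests = collections.defaultdict(list)

def collect(reqs):
    # The new guard: one Request or a list/tuple of them both work.
    if not isinstance(reqs, (list, tuple)):
        reqs = [reqs]
    for i, req in enumerate(reqs):
        requests[req.type].append(req)

collect(Request('greedy_until', ('ctx', ['\n'])))     # a single request
collect([Request('loglikelihood', ('ctx', ' A')),
         Request('loglikelihood', ('ctx', ' B'))])    # a list of requests
print({k: len(v) for k, v in requests.items()})       # {'greedy_until': 1, 'loglikelihood': 2}
```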
lm_eval/models/dummy.py

```diff
@@ -19,5 +19,9 @@ class DummyLM(LM):
         return res
 
     def greedy_until(self, requests):
-        # TODO: implement
-        pass
+        res = []
+
+        for _ in requests:
+            res.append("lol")
+
+        return res
```
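DummyLM now answers every greedy_until request with a fixed string, which lets the integration test below drive the whole request pipeline without loading a real model. Constructed the way tests/test_evaluator.py does it:

```python
# 'dummy' resolves to DummyLM via the model registry, as in the test below.
from lm_eval import models

lm = models.get_model('dummy')()
print(lm.greedy_until([("Some context", ['\n'])]))  # -> ['lol']
```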
lm_eval/models/gpt2.py

```diff
@@ -49,5 +49,29 @@ class GPT2LM(LM):
         return res
 
     def greedy_until(self, requests):
-        # TODO: implement
-        pass
+        # TODO: implement fully general `until` that handles untils that are
+        # multiple tokens or that span multiple tokens correctly
+        res = []
+
+        for context, until in tqdm(requests):
+            if isinstance(until, str): until = [until]
+
+            context_enc = torch.tensor([self.tokenizer.encode(context)]).to(self.device)
+
+            primary_until, = self.tokenizer.encode(until[0])
+
+            cont = self.gpt2.generate(
+                context_enc,
+                max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
+                eos_token_id=primary_until,
+                do_sample=False
+            )
+
+            s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])
+
+            for term in until:
+                s = s.split(term)[0]
+
+            res.append(s)
+
+        return res
```
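Hugging Face's generate stops only on a single token id, so the code takes the first token of the first stop string as eos_token_id (the `primary_until, =` unpack also asserts that the stop string encodes to exactly one token), then trims the decoded continuation at every stop string after the fact; the TODO records that multi-token stop sequences are not yet enforced during generation. A standalone sketch of just the trimming step:

```python
# Post-hoc truncation: generation halts on one token id, then every stop
# string is applied to the decoded text, keeping only what precedes it.
def truncate_at_stops(s, until):
    if isinstance(until, str):
        until = [until]
    for term in until:
        s = s.split(term)[0]  # text before the first occurrence of the stop
    return s

print(truncate_at_stops("42\n\nQuestion: next", ['\n\n']))  # -> 42
```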
lm_eval/tasks/arithmetic.py

```diff
@@ -63,7 +63,7 @@ class Arithmetic(Task):
         return is_prediction
 
     def process_results(self, doc, results):
-        ll, is_prediction = results
+        is_prediction, = results
        return {
             "acc": is_prediction
         }
```
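The trailing-comma unpack extracts the single result and doubles as an arity check, failing loudly if the request pipeline ever hands back the wrong number of values:

```python
results = [True]
is_prediction, = results           # ok: exactly one result
try:
    is_prediction, = [True, -1.5]
except ValueError as e:
    print(e)                       # too many values to unpack (expected 1)
```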
lm_eval/tasks/squad.py

```diff
@@ -26,7 +26,7 @@ class SQuAD(HFTask):
         return ""
 
     def doc_to_text(self, doc):
-        return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Q: ' + doc['question'] + '\n\n' + 'A:'
+        return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:'
 
     def doc_to_target(self, doc):
         answer_list = doc['answers']['text']
@@ -62,9 +62,11 @@ class SQuAD(HFTask):
         """
         squad_metric = datasets.load_metric("squad_v2")
 
+        continuation, = results
+
         predictions = {
             'id': doc['id'],
-            'prediction_text': results[0],
+            'prediction_text': continuation,
         }
 
         references = {
```
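With the Q:/A: labels spelled out, the revised prompt rendered for a toy document (values hypothetical) reads:

```python
doc = {'title': 'Normans', 'context': 'The Normans were a people...',
       'question': 'Where is Normandy?'}
print('Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context']
      + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:')
# Title: Normans
#
# Background: The Normans were a people...
#
# Question: Where is Normandy?
#
# Answer:
```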
tests/test_evaluator.py

```diff
@@ -8,7 +8,7 @@ import pytest
 
 # TODO: more fine grained unit tests rather than this big honking integration
 # test once we break evaluator into smaller, more manageable pieces
-@pytest.mark.parametrize("taskname,Task", tasks.TASK_REGISTRY.items())
+@pytest.mark.parametrize("taskname,Task", [('squad', tasks.squad.SQuAD)])
 def test_evaluator(taskname, Task):
     task_dict = tasks.get_task_dict([taskname])
     lm = models.get_model('dummy')()
```