gaoqiong / lm-evaluation-harness · Commits

Commit e41a082c
Authored Dec 27, 2020 by Leo Gao

    Update

parent 76e65788

Showing 4 changed files with 81 additions and 34 deletions:

    lm_eval/base.py             +16  -9
    lm_eval/models/gpt2.py       +3  -2
    lm_eval/tasks/superglue.py  +27 -15
    main.py                     +35  -8
lm_eval/base.py  (view file @ e41a082c)
@@ -26,7 +26,7 @@ class LM(abc.ABC):
         pass
 
     @abc.abstractmethod
-    def gen_greedy(self, requests):
+    def greedy_until(self, requests):
         """Generate greedily until a stopping sequence
 
         :param requests: list
@@ -104,7 +104,11 @@ class Dataset(abc.ABC):
         return random.sample(self._traindocs, k)
 
     @abc.abstractmethod
-    def doc_to_text(self, doc, include_target=True):
+    def doc_to_text(self, doc):
+        pass
+
+    @abc.abstractmethod
+    def doc_to_target(self, doc):
         pass
 
     @abc.abstractmethod
@@ -113,13 +117,14 @@ class Dataset(abc.ABC):
...
@@ -113,13 +117,14 @@ class Dataset(abc.ABC):
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
process_results
(
self
,
doc
,
results
):
def
process_results
(
self
,
doc
,
results
):
"""Take a single document and the LM results and evaluates, returning a dict with the following format:
"""Take a single document and the LM results and evaluates, returning a
list of dicts, each with the following format:
{
{
"submetric": str,
"submetric": str,
"value": float,
"value": float,
"higher_is_better": bool,
"higher_is_better": bool,
"aggregation": (
list
-> float),
"aggregation": (
[float]
-> float),
}
}
* `submetric` should be the name of the metric
* `submetric` should be the name of the metric
...
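As a hypothetical illustration of the contract this docstring describes, the per-document entries can later be grouped by submetric and reduced with their aggregation callable. The names below (per_doc_results, grouped, scores) are illustrative only, and statistics.mean stands in for lm_eval's own mean helper:

    from statistics import mean

    # Two per-document result entries in the format described in the docstring above
    # (illustrative values; `mean` stands in for lm_eval.base.mean).
    per_doc_results = [
        {"submetric": "acc", "value": 1.0, "higher_is_better": True, "aggregation": mean},
        {"submetric": "acc", "value": 0.0, "higher_is_better": True, "aggregation": mean},
    ]

    # Group values by submetric, then reduce each group with its aggregation function.
    grouped = {}
    for entry in per_doc_results:
        grouped.setdefault(entry["submetric"], []).append(entry)

    scores = {name: entries[0]["aggregation"]([e["value"] for e in entries])
              for name, entries in grouped.items()}
    print(scores)  # {'acc': 0.5}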
@@ -138,10 +143,12 @@ class Dataset(abc.ABC):
...
@@ -138,10 +143,12 @@ class Dataset(abc.ABC):
def
fewshot_context
(
self
,
doc
,
num_fewshot
,
provide_description
):
def
fewshot_context
(
self
,
doc
,
num_fewshot
,
provide_description
):
raw_description
=
self
.
fewshot_description
()
raw_description
=
self
.
fewshot_description
()
description
=
(
raw_description
+
"
\n
===
\n\n
"
)
if
provide_description
and
raw_description
else
""
description
=
(
raw_description
+
"
\n
===
\n\n
"
)
if
provide_description
and
raw_description
else
""
labeled_examples
=
"
\n\n
"
.
join
(
labeled_examples
=
"
\n\n
"
.
join
(
map
(
self
.
doc_to_text
,
self
.
fewshot_examples
(
k
=
num_fewshot
)
)
[
self
.
doc_to_text
(
doc
)
+
self
.
doc_to_target
(
doc
)
for
doc
in
self
.
fewshot_examples
(
k
=
num_fewshot
)
]
)
+
"
\n\n
"
)
+
"
\n\n
"
example
=
self
.
doc_to_text
(
doc
,
include_target
=
False
).
strip
()
example
=
self
.
doc_to_text
(
doc
).
strip
()
return
description
+
labeled_examples
+
example
return
description
+
labeled_examples
+
example
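For intuition, here is a small self-contained sketch (not code from this commit; the documents and the yes/no target helper are made up) of the prompt that the reworked fewshot_context assembles by joining doc_to_text(doc) + doc_to_target(doc) for each few-shot example and appending the un-answered target document:

    def doc_to_text(doc):
        # Prompt format borrowed from the BoolQ task changed later in this commit.
        return f"{doc['passage']}\nquestion: {doc['question']}\nanswer: "

    def doc_to_target(doc):
        return "yes" if doc["label"] else "no"

    fewshot_docs = [
        {"passage": "Water boils at 100 degrees Celsius at sea level.",
         "question": "does water boil at 100 degrees at sea level?", "label": True},
    ]
    target_doc = {"passage": "The Moon has essentially no atmosphere.",
                  "question": "can humans breathe unaided on the moon?", "label": False}

    description = "Read the following passages and answer each question with a yes or a no." + "\n===\n\n"
    labeled_examples = "\n\n".join(doc_to_text(d) + doc_to_target(d) for d in fewshot_docs) + "\n\n"
    prompt = description + labeled_examples + doc_to_text(target_doc).strip()
    print(prompt)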
@@ -153,12 +160,12 @@ def median(arr):
     return arr[len(arr) // 2]
 
 
-Request = collections.namedtuple('Request', ('type', 'args', 'kwargs'))
+Request = collections.namedtuple('Request', ('type', 'args'))
 
 
 class RequestFactory:
     def __getattr__(self, attr):
-        def fn(*args, **kwargs):
-            return Request(attr, args, kwargs)
+        def fn(*args):
+            return Request(attr, args)
         return fn
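A quick standalone reproduction of the slimmed-down Request/RequestFactory pair above: any attribute access on the factory yields a function that packages its positional arguments into a deferred Request named tuple, which a task can hand back without ever touching the LM. The example prompt strings are illustrative only:

    import collections

    Request = collections.namedtuple('Request', ('type', 'args'))

    class RequestFactory:
        def __getattr__(self, attr):
            def fn(*args):
                return Request(attr, args)
            return fn

    rf = RequestFactory()

    # Tasks can describe the LM calls they need without executing them:
    req = rf.loglikelihood("The capital of France is", " Paris")
    print(req.type)  # 'loglikelihood'
    print(req.args)  # ('The capital of France is', ' Paris')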
lm_eval/models/gpt2.py  (view file @ e41a082c)
@@ -3,6 +3,7 @@ import torch
 import torch.nn.functional as F
 from lm_eval.base import LM
 from lm_eval import utils
+from tqdm import tqdm
 
 
 class GPT2LM(LM):
@@ -20,7 +21,7 @@ class GPT2LM(LM):
     def loglikelihood(self, requests):
         res = []
         # TODO: vectorize properly
-        for context, continuation in requests:
+        for context, continuation in tqdm(requests):
             # when too long to fit in context, truncate from the left
             context_enc = self.tokenizer.encode(context)
             continuation_enc = self.tokenizer.encode(continuation)
@@ -35,6 +36,6 @@ class GPT2LM(LM):
         return res
 
-    def gen_greedy(self, requests):
+    def greedy_until(self, requests):
        # TODO: implement
        pass
\ No newline at end of file
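The loop above encodes each (context, continuation) pair separately; per the comment, an over-long pair is truncated from the left so the continuation and the most recent context tokens are kept. Below is a rough self-contained sketch of that idea only, with a whitespace stand-in for the tokenizer and an assumed max_length; it is not the code in gpt2.py:

    def encode(text):
        # Whitespace stand-in for self.tokenizer.encode (the real code uses the GPT-2 BPE tokenizer).
        return text.split()

    max_length = 8  # assumed model context window, in tokens

    context = "a very long context that would not fit inside the model window at all"
    continuation = "final answer"

    context_enc = encode(context)
    continuation_enc = encode(continuation)

    # Keep the continuation intact and drop tokens from the left of the context
    # so the combined sequence fits within max_length.
    combined = (context_enc + continuation_enc)[-max_length:]
    print(combined)  # last 6 context tokens + the 2 continuation tokens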
lm_eval/tasks/superglue.py  (view file @ e41a082c)
 import numpy as np
 from tqdm import auto as tqdm_lib
 
 from .common import HFTask, simple_accuracy_metric, yesno
+from lm_eval.base import rf, mean
 
 
 class BoolQ(HFTask):
     DATASET_PATH = "super_glue"
...
@@ -19,21 +19,33 @@ class BoolQ(HFTask):
     def fewshot_description(self):
         return "Read the following passages and answer each question with a yes or a no."
 
-    def doc_to_text(self, doc, include_target=True):
-        return f"{doc['passage']}\nquestion: {doc['question']}\nanswer: " \
-            + (yesno(doc['label']) if include_target else "")
+    def doc_to_text(self, doc):
+        return f"{doc['passage']}\nquestion: {doc['question']}\nanswer: "
+
+    def doc_to_target(self, doc):
+        return yesno(doc['label'])
 
-    def evaluate(self, docs, lm, provide_description, num_fewshot):
-        golds = [doc["label"] for doc in docs]
-        preds = []
-
-        for doc in docs:
-            ctx = self.fewshot_context(
-                doc=doc,
-                provide_description=provide_description,
-                num_fewshot=num_fewshot,
-            )
-            preds.append(lm.loglikelihood(ctx, ' yes') > lm.loglikelihood(ctx, ' no'))
-
-        return simple_accuracy_metric(preds=preds, golds=golds)
+    def construct_requests(self, ctx):
+        ll_yes = rf.loglikelihood(ctx, ' yes')
+        ll_no = rf.loglikelihood(ctx, ' no')
+
+        return ll_yes, ll_no
+
+    def process_results(self, doc, results):
+        ll_yes, ll_no = results
+        gold = doc["label"]
+
+        acc = 1. if (ll_yes > ll_no) == gold else 0.
+
+        return [{
+            "submetric": "acc",
+            "value": acc,
+            "higher_is_better": True,
+            "aggregation": mean
+        }]
 
 
 class CommitmentBank(HFTask):
...
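Putting the new pieces together, a single BoolQ document now goes through a build-run-score round trip: construct_requests emits deferred loglikelihood requests via rf, the evaluator runs them against the LM, and process_results turns the two log-likelihoods into an accuracy entry. Below is a self-contained sketch with a stubbed LM; the stub scores, the example document, and the free-standing helper functions are illustrative stand-ins, not code from this commit:

    import collections
    from statistics import mean  # stands in for lm_eval.base.mean

    Request = collections.namedtuple('Request', ('type', 'args'))

    class RequestFactory:
        def __getattr__(self, attr):
            return lambda *args: Request(attr, args)

    rf = RequestFactory()

    def construct_requests(ctx):
        # Same shape as BoolQ.construct_requests above.
        return rf.loglikelihood(ctx, ' yes'), rf.loglikelihood(ctx, ' no')

    def process_results(doc, results):
        # Same shape as BoolQ.process_results above.
        ll_yes, ll_no = results
        gold = doc["label"]
        acc = 1. if (ll_yes > ll_no) == gold else 0.
        return [{"submetric": "acc", "value": acc,
                 "higher_is_better": True, "aggregation": mean}]

    doc = {"passage": "Mount Everest is Earth's highest mountain.",
           "question": "is mount everest the highest mountain on earth?", "label": True}
    ctx = f"{doc['passage']}\nquestion: {doc['question']}\nanswer: "

    reqs = construct_requests(ctx)

    # Stub LM: pretend ' yes' is more likely than ' no' for this prompt.
    fake_loglikelihoods = {' yes': -0.2, ' no': -2.5}
    results = [fake_loglikelihoods[req.args[1]] for req in reqs]

    print(process_results(doc, results))
    # [{'submetric': 'acc', 'value': 1.0, 'higher_is_better': True, 'aggregation': <function mean ...>}]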
main.py  (view file @ e41a082c)
@@ -3,6 +3,7 @@ import json
 import numpy as np
 import random
 import itertools
+import collections
 
 from lm_eval import models, tasks
@@ -30,17 +31,43 @@ def main():
     else:
         task_names = args.tasks.split(",")
     task_dict = tasks.get_task_dict(task_names)
+    task_dict_items = list(task_dict.items())
 
     results = {}
-    for task_name, task in task_dict.items():
+
+    requests = collections.defaultdict(list)
+    requests_lengths = collections.defaultdict(list)
+
+    for task_name, task in task_dict_items:
+        # TODO: fall back to test docs
         if not task.has_validation_docs():
             continue
-        result = task.evaluate(
-            docs=itertools.isslice(task.validation_docs(), 0, args.limit),
-            lm=lm,
-            provide_description=args.provide_description,
-            num_fewshot=args.num_fewshot,
-        )
-        results[task_name] = result
+
+        for doc in itertools.islice(task.validation_docs(), 0, args.limit):
+            ctx = task.fewshot_context(
+                doc=doc,
+                provide_description=args.provide_description,
+                num_fewshot=args.num_fewshot,
+            )
+            reqs = task.construct_requests(ctx)
+            lengths = collections.defaultdict(int)
+            for req in reqs:
+                requests[req.type].append(req)
+                lengths[req.type] += 1
+
+            for type, ct in lengths.items():
+                requests_lengths[type].append(ct)
+
+    # TODO: finish implementation
+    for reqname, reqs in requests.items():
+        lm_res = getattr(lm, reqname)([req.args for req in reqs])
+
+    for task_name, task in task_dict_items:
+        if not task.has_validation_docs():
+            continue
 
     dumped = json.dumps(results, indent=2)
     print(dumped)
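The reworked main loop defers all LM calls: every task/document contributes Request tuples into a per-type pool, and requests_lengths records how many requests each document produced so the flat result list can later be handed back per document (that second half is still marked TODO in this commit). Below is a minimal self-contained sketch of that batching-and-redistribution pattern, with a stub LM and one plausible way the redistribution could be finished; all names here are illustrative:

    import collections

    Request = collections.namedtuple('Request', ('type', 'args'))

    # Pretend two documents each emitted two loglikelihood requests.
    requests = collections.defaultdict(list)
    requests_lengths = collections.defaultdict(list)
    for doc_id in range(2):
        reqs = [Request('loglikelihood', (f"ctx{doc_id}", ' yes')),
                Request('loglikelihood', (f"ctx{doc_id}", ' no'))]
        lengths = collections.defaultdict(int)
        for req in reqs:
            requests[req.type].append(req)
            lengths[req.type] += 1
        for type_, ct in lengths.items():
            requests_lengths[type_].append(ct)

    class StubLM:
        def loglikelihood(self, args_list):
            # One batched call per request type; returns one score per request.
            return [-float(len(ctx + cont)) for ctx, cont in args_list]

    lm = StubLM()

    # Execute each request type in a single batched call...
    lm_res = {name: getattr(lm, name)([req.args for req in reqs])
              for name, reqs in requests.items()}

    # ...then slice the flat results back out per document using the recorded counts.
    per_doc = []
    offset = 0
    for count in requests_lengths['loglikelihood']:
        per_doc.append(lm_res['loglikelihood'][offset:offset + count])
        offset += count
    print(per_doc)  # [[-8.0, -7.0], [-8.0, -7.0]]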