Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
1e7f884d
".github/vscode:/vscode.git/clone" did not exist on "19e5a890f70b95a55c9de6a55357d78fc0a4ff81"
Commit
1e7f884d
authored
May 05, 2021
by
Leo Gao
Browse files
Refactor PerplexityTask
parent
b0cf0163
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
12 deletions
+28
-12
lm_eval/base.py
lm_eval/base.py
+21
-11
lm_eval/metrics.py
lm_eval/metrics.py
+5
-0
lm_eval/models/gpt2.py
lm_eval/models/gpt2.py
+2
-1
No files found.
lm_eval/base.py
View file @
1e7f884d
import
abc
import
random
import
numpy
as
np
import
re
from
lm_eval.metrics
import
mean
from
lm_eval.metrics
import
mean
,
perplexity
,
weighted_mean
class
LM
(
abc
.
ABC
):
...
...
@@ -307,14 +308,17 @@ class PerplexityTask(Task, abc.ABC):
return
""
def
higher_is_better
(
self
):
return
False
return
{
"word_perplexity"
:
False
,
"byte_perplexity"
:
False
,
"bits_per_byte"
:
False
,
}
def
doc_to_text
(
self
,
doc
):
return
doc
def
doc_to_target
(
self
,
doc
):
raise
NotImplementedError
()
return
doc
def
construct_requests
(
self
,
doc
,
ctx
):
assert
not
ctx
...
...
@@ -324,20 +328,26 @@ class PerplexityTask(Task, abc.ABC):
def
process_results
(
self
,
doc
,
results
):
loglikelihood
,
=
results
return
{
"perplexity"
:
loglikelihood
,
"word_perplexity"
:
loglikelihood
/
self
.
count_words
(
self
.
doc_to_text
(
doc
)),
"byte_perplexity"
:
loglikelihood
/
self
.
count_bytes
(
self
.
doc_to_text
(
doc
)),
"bits_per_byte"
:
(
-
loglikelihood
,
self
.
count_bytes
(
self
.
doc_to_text
(
doc
)))
}
def
aggregation
(
self
):
return
{
"perplexity"
:
self
.
compute_perplexity_from_loglikelihood
,
"word_perplexity"
:
perplexity
,
"byte_perplexity"
:
perplexity
,
"bits_per_byte"
:
weighted_mean
}
@
classmethod
def
compute_perplexity_from_loglikelihood
(
cls
,
loglikelihoods
):
aggregate_logprobs
=
np
.
concatenate
(
loglikelihoods
)
perplexity
=
np
.
exp
(
-
aggregate_logprobs
.
mean
())
return
float
(
perplexity
)
def
count_bytes
(
self
,
s
):
return
len
(
s
.
encode
(
"utf-8"
))
def
count_words
(
self
,
s
):
""" Downstream tasks with custom word boundaries should override this! """
return
len
(
re
.
split
(
r
"\s+"
,
s
))
def
req_ret_lens
=
{
'loglikelihood'
:
2
,
...
...
lm_eval/metrics.py
View file @
1e7f884d
...
...
@@ -62,6 +62,11 @@ def perplexity(items):
return
math
.
exp
(
-
mean
(
items
))
def
weighted_mean
(
items
):
a
,
b
=
zip
(
*
items
)
return
sum
(
a
)
/
sum
(
b
)
def
bleu
(
items
):
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
for evaluating a generated sentence to a reference sentence. It counts matching
...
...
lm_eval/models/gpt2.py
View file @
1e7f884d
...
...
@@ -60,6 +60,7 @@ class GPT2LM(LM):
with
torch
.
no_grad
():
for
string
,
in
tqdm
(
requests
):
encoded
=
self
.
tokenizer
.
encode_plus
(
string
)[
"input_ids"
]
rolling_token_windows
=
list
(
map
(
utils
.
make_disjoint_window
,
utils
.
get_rolling_token_windows
(
token_list
=
encoded
,
prefix_token
=
self
.
EOT_TOKEN_ID
,
...
...
@@ -67,9 +68,9 @@ class GPT2LM(LM):
context_len
=
1
,
)))
# todo: figure out partial caching
rolling_token_windows
=
[(
None
,)
+
x
for
x
in
rolling_token_windows
]
# TODO: extract out this call so it only gets called once and also somehow figure out partial caching for that
string_nll
=
self
.
_loglikelihood_tokens
(
rolling_token_windows
)
# discard is_greedy
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment