gaoqiong / lm-evaluation-harness

Commit 9454c839, authored May 02, 2021 by Jason Phang
gpt2 perplexity
parent 8846bec0

Showing 7 changed files with 231 additions and 9 deletions
    lm_eval/base.py            +95  -2
    lm_eval/evaluator.py        +1  -1
    lm_eval/models/gpt2.py     +35  -3
    lm_eval/models/gpt3.py      +3  -3
    lm_eval/tasks/__init__.py   +5  -0
    lm_eval/tasks/pile.py      +49  -0
    lm_eval/utils.py           +43  -0
lm_eval/base.py
@@ -27,9 +27,51 @@ class LM(abc.ABC):
        :return: list
            A list of pairs (logprob, isgreedy)
                logprob: float
-                   The log probability of `contination`
+                   The log probability of `continuation`
                isgreedy:
-                   Whether `contination` would be generated by greedy sampling from `context`
+                   Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass

    @abc.abstractmethod
    def loglikelihood_perplexity(self, requests):
        """Compute full log-likelihood of a string, with no truncation, for perplexity computation
        - We will use the full max context length of the model.
        - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
          the max context length.
        - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
          which may simply concatenate multiple documents together.
        - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
          multiple chunks, the last input will still have a full-sized context.
          Example:
            Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
            Prefix: EOT
            Max context length: 4
            Resulting input/prediction pairs:
                INPUT:  EOT   0   1   2
                PRED:     0   1   2   3
                INPUT:    3   4   5   6
                PRED:     4   5   6   7
                INPUT:    5   6   7   8
                PRED:               8   9
          Observe that:
            1. Each token is predicted exactly once
            2. For the last pair, we provide the full context, but only score the last two tokens

        :param requests: list
            A list of strings
                string: str
                    String for which we are computing per-token loglikelihood
        :return: list
            A list of pairs (logprob, isgreedy)
                logprob: float
                    The log probability of `continuation`
                isgreedy:
                    Whether `continuation` would be generated by greedy sampling from `context`
        """
        pass
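To make the worked example in the docstring concrete, here is a small sanity check of the window scheme it describes (an illustration for this page, not part of the commit; the "EOT" string stands in for the end-of-text token):

# Window pairs from the docstring example: tokens 0..9, max context length 4.
windows = [
    (["EOT", 0, 1, 2], [0, 1, 2, 3]),
    ([3, 4, 5, 6],     [4, 5, 6, 7]),
    ([5, 6, 7, 8],     [8, 9]),
]
predicted = [tok for _, preds in windows for tok in preds]
assert predicted == list(range(10))               # each token is predicted exactly once
assert all(len(inp) <= 4 for inp, _ in windows)   # every input fits in the max context length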
@@ -247,9 +289,60 @@ class MultipleChoiceTask(Task):
        }


class PerplexityTask(Task, abc.ABC):

    def has_training_docs(self):
        return False

    def fewshot_description(self):
        return ""

    def fewshot_examples(self, k, rnd):
        assert k == 0
        return []

    def fewshot_context(self, doc, num_fewshot, provide_description, rnd):
        assert num_fewshot == 0
        assert not provide_description
        return ""

    def higher_is_better(self):
        return False

    def doc_to_text(self, doc):
        return doc

    def doc_to_target(self, doc):
        raise NotImplementedError()
        return doc

    def construct_requests(self, doc, ctx):
        assert not ctx
        req = rf.loglikelihood_perplexity(doc)
        return req

    def process_results(self, doc, results):
        loglikelihood, = results
        return {
            "perplexity": loglikelihood,
        }

    def aggregation(self):
        return {
            "perplexity": self.compute_perplexity_from_loglikelihood,
        }

    @classmethod
    def compute_perplexity_from_loglikelihood(cls, loglikelihoods):
        aggregate_logprobs = np.concatenate(loglikelihoods)
        perplexity = np.exp(-aggregate_logprobs.mean())
        return float(perplexity)


req_ret_lens = {
    'loglikelihood': 2,
    'greedy_until': None,
    'loglikelihood_perplexity': 1,
}

import os
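PerplexityTask aggregates the per-token log-likelihoods of all documents and exponentiates the negative token-level mean, as in compute_perplexity_from_loglikelihood above. A small worked example (illustrative numbers only, not part of the commit):

import numpy as np

# Per-token logprobs for two documents, in the format loglikelihood_perplexity
# returns (one numpy array per document):
loglikelihoods = [
    np.array([-2.0, -1.0, -3.0]),   # document 1: 3 tokens
    np.array([-1.5, -2.5]),         # document 2: 2 tokens
]
aggregate_logprobs = np.concatenate(loglikelihoods)     # all 5 token logprobs
perplexity = float(np.exp(-aggregate_logprobs.mean()))  # exp(2.0) ~= 7.39

Note that the mean is taken over tokens rather than over documents, so longer documents contribute proportionally more to the aggregate perplexity.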
lm_eval/evaluator.py
@@ -34,7 +34,7 @@ def evaluate(lm, task_dict, provide_description, num_fewshot, limit):
        task_docs = list(task_doc_func())
        rnd = random.Random()
        rnd.seed(42)
-       rnd.shuffle(task_docs)
+       # rnd.shuffle(task_docs)

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
            docs[(task_name, doc_id)] = doc
lm_eval/models/gpt2.py
@@ -4,10 +4,13 @@ import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm
import numpy as np


class GPT2LM(LM):
    MAX_GEN_TOKS = 256
    VOCAB_SIZE = 50257
    EOT_TOKEN_ID = 50256

    def __init__(self, device=None, pretrained='gpt2'):
        super().__init__()
@@ -39,7 +42,7 @@ class GPT2LM(LM):
        for context, continuation in requests:
            if context == "":
                # end of text as context
-               context_enc = [50256]
+               context_enc = [self.EOT_TOKEN_ID]
            else:
                context_enc = self.tokenizer.encode(context)
@@ -49,6 +52,35 @@ class GPT2LM(LM):

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_perplexity(self, requests):
        # TODO: Implement caching once we've confirmed the perplexity implementation
        # TODO: automatic batch size detection for vectorization
        loglikelihoods = []
        with torch.no_grad():
            for string, in tqdm(requests):
                encoded = self.tokenizer.encode_plus(string)["input_ids"]
                rolling_token_windows = utils.get_rolling_token_windows(
                    token_list=encoded,
                    prefix_token=self.EOT_TOKEN_ID,
                    max_seq_len=self.max_length,
                    context_len=1,
                )
                string_nll = []
                for input_tokens, pred_tokens in rolling_token_windows:
                    inp = torch.tensor([input_tokens], dtype=torch.long).to(self.device)
                    labels = torch.tensor([pred_tokens], dtype=torch.long).to(self.device)
                    logits = F.log_softmax(
                        self.gpt2(inp)[0][:, :, :self.VOCAB_SIZE], dim=-1
                    )  # [batch, seq, vocab]
                    # Only score the relevant logits (i.e. the last len(pred_tokens) logits)
                    scoring_logits = logits[:, -len(pred_tokens):].reshape(len(pred_tokens), self.VOCAB_SIZE)
                    string_nll.append(F.cross_entropy(
                        scoring_logits,
                        target=labels.view(-1),
                        reduction="none",
                    ).cpu().numpy())
                string_nll = np.concatenate(string_nll)
                loglikelihoods.append(-string_nll)
        return loglikelihoods

    def _loglikelihood_tokens(self, requests):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []
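A rough sketch of calling the new method directly (illustrative only, not part of the commit; the device string and example text are assumptions):

lm = GPT2LM(device="cuda", pretrained="gpt2")
requests = [("This is a made-up example document.",)]   # one 1-tuple of text per document
per_doc_logprobs = lm.loglikelihood_perplexity(requests)
# per_doc_logprobs[0] is a numpy array with one logprob per token of the document;
# PerplexityTask.compute_perplexity_from_loglikelihood turns these arrays into a perplexity.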
@@ -59,7 +91,7 @@ class GPT2LM(LM):
        def _collate(x):
            toks = x[1] + x[2]
            return (len(toks), tuple(toks))

        reord = utils.Reorderer(requests, _collate)
        for cache_key, context_enc, continuation_enc in tqdm(reord.get_reordered()):
            # when too long to fit in context, truncate from the left
@@ -67,7 +99,7 @@ class GPT2LM(LM):
            ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)

            cont_toks = inp[:, ctxlen:]  # [batch, seq]
-           logits = F.log_softmax(self.gpt2(inp)[0][:, :, :50257], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
+           logits = F.log_softmax(self.gpt2(inp)[0][:, :, :self.VOCAB_SIZE], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

            greedy_tokens = logits.argmax(dim=-1)
            max_equal = (greedy_tokens == cont_toks).all()
lm_eval/models/gpt3.py
@@ -92,9 +92,9 @@ class GPT3LM(LM):
            # we care about and so we need some kind of backup for when it isn't
            toks = x[1] + x[2]
            return (len(toks), tuple(toks))

        reord = utils.Reorderer(requests, _collate)

        for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
            inps = []
            ctxlens = []
@@ -121,7 +121,7 @@ class GPT3LM(LM):
                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

        return reord.get_original(res)

    def greedy_until(self, requests):
lm_eval/tasks/__init__.py
@@ -37,6 +37,7 @@ from . import hendrycks_test
from . import hendrycks_math
from . import cbt
from . import lambada_cloze
from . import pile


########################################
# Translation tasks
@@ -171,6 +172,10 @@ TASK_REGISTRY = {
    "cycle_letters": unscramble.CycleLetters,
    "random_insertion": unscramble.RandomInsertion,
    "reversed_words": unscramble.ReversedWords,

    # Pile
    "pile_enron": pile.PileEnronPerplexityTask,
    "pile_ubuntu": pile.PileUbuntuPerplexityTask,
}
lm_eval/tasks/pile.py (new file, 0 → 100644)
import os
import lm_dataformat
import abc
import numpy as np
from lm_eval.base import rf, PerplexityTask
from ..metrics import mean, matthews_corrcoef, f1_score
from ..utils import general_detokenize
from best_download import download_file


class PilePerplexityTask(PerplexityTask, abc.ABC):
    PILE_SET_NAME = None
    VAL_PATH = 'data/pile/val.jsonl.zst'
    TEST_PATH = 'data/pile/test.jsonl.zst'

    def download(self):
        os.makedirs("data/pile/", exist_ok=True)
        if not os.path.exists(self.VAL_PATH):
            download_file("https://the-eye.eu/public/AI/pile/val.jsonl.zst", self.VAL_PATH)
        if not os.path.exists(self.TEST_PATH):
            download_file("https://the-eye.eu/public/AI/pile/test.jsonl.zst", self.TEST_PATH)

    def validation_docs(self):
        rdr = lm_dataformat.Reader(self.VAL_PATH)
        for doc, metadata in rdr.stream_data(get_meta=True):
            if metadata["pile_set_name"] == self.PILE_SET_NAME:
                yield doc

    def test_docs(self):
        rdr = lm_dataformat.Reader(self.TEST_PATH)
        for doc, metadata in rdr.stream_data(get_meta=True):
            if metadata["pile_set_name"] == self.PILE_SET_NAME:
                yield doc

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True


class PileEnronPerplexityTask(PilePerplexityTask):
    PILE_SET_NAME = "Enron Emails"


class PileUbuntuPerplexityTask(PilePerplexityTask):
    PILE_SET_NAME = "Ubuntu IRC"
lm_eval/utils.py
@@ -61,6 +61,49 @@ def general_detokenize(string):
    return string


def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
    """
    - context_len allows for a rolling window context, allowing each prediction window to potentially
      condition on some context

    :param token_list: list
        List of tokens to be PREDICTED
    :param max_seq_len: int
        max_seq_len of model (or max_seq_len we want to use)
    :param context_len: int
        Amount of desired token context for prediction. Needs to be at least 1.
    :param prefix_token: token
        Dummy token like <eos> so the first token has something to condition on
    :return: generator
        Generator of tuples
            (input_tokens, pred_tokens)
        Note: Score only the last len(pred_tokens) logits of the LM
    """
    assert 1 <= context_len <= max_seq_len
    if not token_list:
        return
    # +1 offset, going from input->preds
    pred_len = max_seq_len - context_len + 1
    predicted = 0

    # Special handling for first window: predict all tokens
    first_seq_len = min(max_seq_len, len(token_list))
    yield (
        [prefix_token] + token_list[:first_seq_len - 1],
        token_list[:first_seq_len],
    )
    predicted += first_seq_len

    while predicted < len(token_list):
        window_pred_len = min(len(token_list) - predicted, pred_len)
        window_end = predicted + window_pred_len

        yield (
            token_list[window_end - max_seq_len - 1:window_end - 1],
            token_list[window_end - window_pred_len:window_end],
        )
        predicted += window_pred_len


class Reorderer:
    def __init__(self, arr, fn):
        self.size = len(arr)
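To illustrate the context_len parameter (GPT2LM above calls this function with context_len=1), here is the sequence of windows the generator yields for a short made-up input; the expected output below was traced by hand from the code above and is illustrative, not part of the commit:

windows = list(get_rolling_token_windows(
    token_list=list(range(7)),   # 7 tokens: 0..6
    prefix_token=-1,             # stand-in for the EOT token id
    max_seq_len=3,
    context_len=2,
))
assert windows == [
    ([-1, 0, 1], [0, 1, 2]),   # first window predicts every token it covers
    ([1, 2, 3],  [3, 4]),      # later windows predict pred_len = max_seq_len - context_len + 1 = 2 tokens
    ([3, 4, 5],  [5, 6]),      # each scored token sees at least context_len = 2 tokens of context
]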