gaoqiong / lm-evaluation-harness · Commits · b0cf0163

Commit b0cf0163, authored May 04, 2021 by Leo Gao
Parent: ee5467ff

Begin refactoring perplexity code

Showing 6 changed files with 57 additions and 26 deletions (+57 −26)
lm_eval/models/dummy.py    +8  −0
lm_eval/models/gpt2.py     +31 −19
lm_eval/utils.py           +7  −0
tests/test_evaluator.py    +1  −0
tests/test_models.py       +4  −6
tests/test_utils.py        +6  −1
lm_eval/models/dummy.py (view file @ b0cf0163)

```diff
@@ -26,3 +26,11 @@ class DummyLM(LM):
             assert ctx.strip() != ''
 
         return res
+
+    def loglikelihood_perplexity(self, requests):
+        res = []
+
+        for _ in requests:
+            res.append(-random.random())
+
+        return res
\ No newline at end of file
```
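As a quick illustration (not part of the commit), here is how the stub could be exercised the way the evaluator would, assuming `DummyLM` is constructible with no arguments:

```python
# Hypothetical usage sketch: the stub returns one random negative
# "log-likelihood" per request, which is enough to smoke-test the
# evaluation pipeline without loading a real model.
from lm_eval.models.dummy import DummyLM

lm = DummyLM()
scores = lm.loglikelihood_perplexity([("some text",), ("other text",)])
assert len(scores) == 2 and all(s <= 0 for s in scores)
```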
lm_eval/models/gpt2.py (view file @ b0cf0163)

```diff
@@ -60,24 +60,23 @@ class GPT2LM(LM):
         with torch.no_grad():
             for string, in tqdm(requests):
                 encoded = self.tokenizer.encode_plus(string)["input_ids"]
-                rolling_token_windows = utils.get_rolling_token_windows(
+                rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
                     token_list=encoded,
                     prefix_token=self.EOT_TOKEN_ID,
                     max_seq_len=self.max_length,
                     context_len=1,
-                )
-
-                string_nll = []
-                for input_tokens, pred_tokens in rolling_token_windows:
-                    inp = torch.tensor([input_tokens], dtype=torch.long).to(self.device)
-                    labels = torch.tensor([pred_tokens], dtype=torch.long).to(self.device)
-                    logits = F.log_softmax(self.gpt2(inp)[0][:, :, :self.VOCAB_SIZE], dim=-1)  # [batch, seq, vocab]
-                    # Only score the relevant logits (i.e. the last len(pred_tokens) logits)
-                    scoring_logits = logits[:, -len(pred_tokens):].reshape(len(pred_tokens), self.VOCAB_SIZE)
-                    string_nll.append(
-                        F.cross_entropy(scoring_logits, target=labels.view(-1), reduction="none").cpu().numpy()
-                    )
-                string_nll = np.concatenate(string_nll)
-                loglikelihoods.append(-string_nll)
+                )))
+
+                # todo: figure out partial caching
+                rolling_token_windows = [(None,) + x for x in rolling_token_windows]
+
+                string_nll = self._loglikelihood_tokens(rolling_token_windows)
+
+                # discard is_greedy
+                string_nll = [x[0] for x in string_nll]
+
+                string_nll = sum(string_nll)
+                loglikelihoods.append(string_nll)
 
         return loglikelihoods
```
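To see what the new windowing pipeline produces, a minimal sketch on a toy token list; `prefix_token=0` stands in for `self.EOT_TOKEN_ID`, and the import assumes the package layout at this commit:

```python
# Toy run of the windowing composition the refactor now routes scoring
# through: rolling windows, then overlap trimmed from each context.
from lm_eval import utils

tokens = list(range(1, 13))  # 12 arbitrary token ids
windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
    token_list=tokens,
    prefix_token=0,
    max_seq_len=5,
    context_len=1,
)))

# Each pair is (context, continuation). After make_disjoint_window the
# context no longer overlaps the continuation, so every token ends up
# scored exactly once across windows.
for ctx, cont in windows:
    print(ctx, cont)
```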
```diff
@@ -94,12 +93,25 @@ class GPT2LM(LM):
         reord = utils.Reorderer(requests, _collate)
         for cache_key, context_enc, continuation_enc in tqdm(reord.get_reordered()):
+            assert len(context_enc) > 0
+            assert len(continuation_enc) > 0
+            assert len(continuation_enc) <= self.max_length
+
+            # how this all works:
+            #          CTX      CONT
+            # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
+            # gpt2    \               \
+            # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the [:, -len(continuation_enc):, :self.VOCAB_SIZE] slice
+            # cont_toks      4 5 6 7 8 9
+
+            # when too long to fit in context, truncate from the left
-            inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
-            ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
-
-            cont_toks = inp[:, ctxlen:]  # [batch, seq]
-
-            logits = F.log_softmax(self.gpt2(inp)[0][:, :, :self.VOCAB_SIZE], dim=-1)[:, ctxlen-1:-1]  # [batch, seq, vocab]
+            inp = torch.tensor([
+                (context_enc + continuation_enc)[-(self.max_length+1):]
+            ], dtype=torch.long).to(self.device)
+
+            cont_toks = inp[:, -len(continuation_enc):]  # [batch, seq]
+
+            logits = F.log_softmax(self.gpt2(inp[:, :-1])[0][:, -len(continuation_enc):, :self.VOCAB_SIZE], dim=-1)  # [batch, seq, vocab] - vocab size is clipped to exclude padding tokens or whatever
 
             greedy_tokens = logits.argmax(dim=-1)
             max_equal = (greedy_tokens == cont_toks).all()
```
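The arithmetic in the diagram above can be checked in isolation. A toy sketch with a random tensor standing in for the model (all names below are local to the example):

```python
# Re-creating the slicing from the diagram with a stand-in model output.
import torch
import torch.nn.functional as F

max_length = 9
context_enc = [0, 1, 2, 3]
continuation_enc = [4, 5, 6, 7, 8, 9]

# keep max_length + 1 tokens: the extra token exists because the last
# input token is dropped before the forward pass (inp[:, :-1])
inp = torch.tensor([(context_enc + continuation_enc)[-(max_length + 1):]], dtype=torch.long)
cont_toks = inp[:, -len(continuation_enc):]   # tensor([[4, 5, 6, 7, 8, 9]])

vocab_size = 16
fake_out = torch.randn(1, inp.shape[1] - 1, vocab_size)  # stands in for self.gpt2(inp[:, :-1])[0]
logprobs = F.log_softmax(fake_out[:, -len(continuation_enc):, :], dim=-1)
assert logprobs.shape == (1, len(continuation_enc), vocab_size)
```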
```diff
@@ -108,7 +120,7 @@ class GPT2LM(LM):
             logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]
 
-            answer = (float(logits.sum()), bool(max_equal))
+            answer = (float(logits.cpu().to(torch.float64).sum()), bool(max_equal))
 
             # partial caching
             if cache_key is not None:
```
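For context on the change above: summing in float64 on the CPU reduces floating-point accumulation error over long sequences. A rough illustration (the values and the size of the drift here are arbitrary):

```python
# The float32 total typically differs slightly from the float64 total
# when many values are accumulated.
import torch

x = torch.full((1_000_000,), 0.1, dtype=torch.float32)
total32 = float(x.sum())
total64 = float(x.to(torch.float64).sum())
print(total32, total64, abs(total32 - total64))
```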
lm_eval/utils.py (view file @ b0cf0163)

```diff
@@ -97,12 +97,19 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len
     while predicted < len(token_list):
         window_pred_len = min(len(token_list) - predicted, pred_len)
         window_end = predicted + window_pred_len
 
         yield (
             token_list[window_end - max_seq_len - 1:window_end - 1],
             token_list[window_end - window_pred_len:window_end],
         )
         predicted += window_pred_len
 
 
+def make_disjoint_window(pair):
+    """ Takes output from get_rolling_token_windows and makes the context not overlap with the continuation """
+    a, b = pair
+    return a[:-(len(b) - 1)], b
+
+
 class Reorderer:
     def __init__(self, arr, fn):
```
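A quick gloss on the new slice, including one edge case worth knowing about; both assertions below are plain Python behavior rather than anything this commit exercises:

```python
# a[:-(len(b) - 1)] drops the last len(b) - 1 context tokens -- exactly
# the ones the continuation will be scored on.
assert [1, 2, 3, 4, 5][:-(3 - 1)] == [1, 2, 3]

# Edge case: when len(b) == 1 the slice is a[:-0], and since -0 == 0
# this is a[:0] -- an empty context rather than the full one.
assert [1, 2, 3, 4, 5][:-0] == []
```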
tests/test_evaluator.py (view file @ b0cf0163)

```diff
@@ -15,6 +15,7 @@ def test_evaluator(taskname, Task):
     def ll_fn(reqs):
         for ctx, cont in reqs:
+            if len(ctx) == 0: continue
             # space convention
             assert ctx[-1] != ' '
             assert cont[0] == ' ' or ctx[-1] == '\n'
```
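For readers unfamiliar with the convention these assertions encode, a standalone example with made-up strings:

```python
# The space belongs to the continuation, never to the context, unless
# the context ends in a newline.
ctx, cont = "Question: 2+2 =", " 4"
assert ctx[-1] != ' '
assert cont[0] == ' ' or ctx[-1] == '\n'
```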
tests/test_models.py (view file @ b0cf0163)

```diff
@@ -41,13 +41,11 @@ def test_gpt2_perplexity():
     gpt2 = models.get_model('gpt2').create_from_arg_string("device=cpu")
     test_string = "We study empirical scaling laws for language model performance on the cross-entropy loss."
     perplexity = gpt2.loglikelihood_perplexity([(test_string,)])[0]
-    targets = [-4.9599953, -8.069298, -8.308624, -10.178513, -8.906924, -1.9318912, -7.745445, -7.146077, -5.2072, -3.5882986, -1.9957212, -8.044922, -0.20841774, -5.1096807, -0.099879116, -8.888423, -4.6180487]
-    for pred, tgt in zip(perplexity, targets):
-        assert pred == pytest.approx(tgt)
+    tgt = sum([-4.9599953, -8.069298, -8.308624, -10.178513, -8.906924, -1.9318912, -7.745445, -7.146077, -5.2072, -3.5882986, -1.9957212, -8.044922, -0.20841774, -5.1096807, -0.099879116, -8.888423, -4.6180487])
+    assert perplexity == pytest.approx(tgt, abs=1e-3)
 
     # Hack: modify gpt2 to have shorter context length to induce rolling windows
     gpt2.max_length = 5
     perplexity = gpt2.loglikelihood_perplexity([(test_string,)])[0]
-    targets = [-4.96001, -8.069275, -8.308612, -10.178482, -8.90691, -4.037338, -8.09261, -11.662385, -10.206891, -4.425003, -2.2563353, -7.909143, -1.9304147, -7.3610134, -2.3120654, -7.3229, -2.1643813]
-    for pred, tgt in zip(perplexity, targets):
-        assert pred == pytest.approx(tgt)
+    tgt = sum([-4.96001, -8.069275, -8.308612, -10.178482, -8.90691, -4.037338, -8.09261, -11.662385, -10.206891, -4.425003, -2.2563353, -7.909143, -1.9304147, -7.3610134, -2.3120654, -7.3229, -2.1643813])
+    assert perplexity == pytest.approx(tgt, abs=1e-3)
```
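Why the second block expects different values: the 17 targets correspond to 17 scored tokens, and with `max_length = 5` the string no longer fits in one window, so later tokens are conditioned on less context. A toy check of the window count (token ids arbitrary, `prefix_token=-1` standing in for the EOT id):

```python
from lm_eval import utils

windows = list(utils.get_rolling_token_windows(
    token_list=list(range(17)),
    prefix_token=-1,
    max_seq_len=5,
    context_len=1,
))
print(len(windows))  # several windows once the sequence exceeds max_seq_len
```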
tests/test_utils.py (view file @ b0cf0163)

```diff
-from lm_eval.utils import get_rolling_token_windows
+from lm_eval.utils import get_rolling_token_windows, make_disjoint_window
 
 
 # noinspection DuplicatedCode
```

```diff
@@ -200,3 +200,8 @@ def test_get_rolling_token_windows_empty():
     for _ in generator:
         n += 1
     assert n == 0
+
+
+def test_make_disjoint_window():
+    assert make_disjoint_window(([1, 2, 3, 4, 5], [2, 3, 4, 5, 6])) == ([1], [2, 3, 4, 5, 6])
+    assert make_disjoint_window(([1, 2, 3, 4, 5], [4, 5, 6])) == ([1, 2, 3], [4, 5, 6])
\ No newline at end of file
```