Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
ee5467ff
Commit
ee5467ff
authored
May 03, 2021
by
Jason Phang
Browse files
adding tests, sort of
parent
3a490624
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
231 additions
and
2 deletions
+231
-2
tests/test_evaluator.py
tests/test_evaluator.py
+12
-1
tests/test_models.py
tests/test_models.py
+17
-1
tests/test_utils.py
tests/test_utils.py
+202
-0
No files found.
tests/test_evaluator.py
View file @
ee5467ff
...
@@ -26,7 +26,18 @@ def test_evaluator(taskname, Task):
...
@@ -26,7 +26,18 @@ def test_evaluator(taskname, Task):
res
.
append
((
-
random
.
random
(),
False
))
res
.
append
((
-
random
.
random
(),
False
))
return
res
return
res
def ll_perp_fn(reqs):
    """Stub returning one seeded pseudo-random 1-tuple per request.

    Each request must be a 1-element sequence holding a ``str``; this is
    checked with an assertion before any value is produced. The random
    generator is re-seeded with 42 on every call, so the returned values
    depend only on the number of requests.
    """
    for (text,) in reqs:
        assert isinstance(text, str)

    # Re-seed so repeated calls produce the same sequence of values.
    random.seed(42)
    return [(-random.random(),) for _ in reqs]
lm
.
loglikelihood
=
ll_fn
lm
.
loglikelihood
=
ll_fn
lm
.
loglikelihood_perplexity
=
ll_perp_fn
evaluator
.
evaluate
(
lm
,
task_dict
,
False
,
0
,
10
)
evaluator
.
evaluate
(
lm
,
task_dict
,
False
,
0
,
10
)
tests/test_models.py
View file @
ee5467ff
...
@@ -34,4 +34,20 @@ def test_gpt2():
...
@@ -34,4 +34,20 @@ def test_gpt2():
targets
=
[
-
61.60536193847656
,
-
56.57843780517578
,
-
62.131004333496094
,
-
9.799489974975586
,
-
153.96334838867188
,
-
341.222900390625
,
-
731.1475830078125
,
-
61.60536193847656
,
-
8.682319641113281
]
targets
=
[
-
61.60536193847656
,
-
56.57843780517578
,
-
62.131004333496094
,
-
9.799489974975586
,
-
153.96334838867188
,
-
341.222900390625
,
-
731.1475830078125
,
-
61.60536193847656
,
-
8.682319641113281
]
for
(
pred
,
_
),
tgt
in
zip
(
vals
,
targets
):
for
(
pred
,
_
),
tgt
in
zip
(
vals
,
targets
):
assert
pred
==
pytest
.
approx
(
tgt
)
assert
pred
==
pytest
.
approx
(
tgt
)
\ No newline at end of file
def test_gpt2_perplexity():
    """Compare GPT-2 ``loglikelihood_perplexity`` output with stored references.

    The model is created on CPU and queried with a single test string, once
    with its default ``max_length`` and once with ``max_length`` forced to 5
    (to induce rolling windows). In both cases every returned value must match
    the corresponding reference value, and the number of returned values must
    equal the number of references.
    """
    gpt2 = models.get_model('gpt2').create_from_arg_string("device=cpu")
    test_string = "We study empirical scaling laws for language model performance on the cross-entropy loss."

    perplexity = list(gpt2.loglikelihood_perplexity([(test_string,)])[0])
    targets = [
        -4.9599953, -8.069298, -8.308624, -10.178513, -8.906924, -1.9318912,
        -7.745445, -7.146077, -5.2072, -3.5882986, -1.9957212, -8.044922,
        -0.20841774, -5.1096807, -0.099879116, -8.888423, -4.6180487,
    ]
    # zip() stops at the shorter input, so check the lengths explicitly;
    # otherwise a truncated (or empty) result would pass vacuously.
    assert len(perplexity) == len(targets)
    for pred, tgt in zip(perplexity, targets):
        assert pred == pytest.approx(tgt)

    # Hack: modify gpt2 to have shorter context length to induce rolling windows
    gpt2.max_length = 5
    perplexity = list(gpt2.loglikelihood_perplexity([(test_string,)])[0])
    targets = [
        -4.96001, -8.069275, -8.308612, -10.178482, -8.90691, -4.037338,
        -8.09261, -11.662385, -10.206891, -4.425003, -2.2563353, -7.909143,
        -1.9304147, -7.3610134, -2.3120654, -7.3229, -2.1643813,
    ]
    assert len(perplexity) == len(targets)
    for pred, tgt in zip(perplexity, targets):
        assert pred == pytest.approx(tgt)
tests/test_utils.py
0 → 100644
View file @
ee5467ff
from
lm_eval.utils
import
get_rolling_token_windows
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v1():
    """Check rolling windows for 34 tokens, max_seq_len=10, context_len=1.

    Asserts that the total number of predicted tokens equals the input length
    and that the yielded (input, prediction) pairs match the expected windows.
    """
    tokens = list(range(34))
    expected = [
        # First window starts with the prefix token.
        ([-100] + list(range(0, 9)), list(range(0, 10))),
        (list(range(9, 19)), list(range(10, 20))),
        (list(range(19, 29)), list(range(20, 30))),
        # Last window: full-length input, only the remaining four predictions.
        (list(range(23, 33)), list(range(30, 34))),
    ]

    windows = [
        (input_tokens, pred_tokens)
        for input_tokens, pred_tokens in get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=10,
            context_len=1,
        )
    ]
    total_predicted = sum(len(pred_tokens) for _, pred_tokens in windows)

    assert total_predicted == len(tokens)
    assert expected == windows
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v2():
    """Check rolling windows for 34 tokens, max_seq_len=10, context_len=8.

    After the first (prefix-token) window, each expected window has a
    ten-token input and predicts three tokens, advancing by three each time.
    """
    tokens = list(range(34))
    expected = [([-100] + list(range(0, 9)), list(range(0, 10)))]
    # Inputs start at 2, 5, ..., 23; predictions are the three tokens that
    # begin eight positions after each input start.
    expected += [
        (list(range(start, start + 10)), list(range(start + 8, start + 11)))
        for start in range(2, 24, 3)
    ]

    windows = [
        (input_tokens, pred_tokens)
        for input_tokens, pred_tokens in get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=10,
            context_len=8,
        )
    ]
    total_predicted = sum(len(pred_tokens) for _, pred_tokens in windows)

    assert total_predicted == len(tokens)
    assert expected == windows
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v3():
    """Check rolling windows for 34 tokens, max_seq_len=10, context_len=10.

    After the first (prefix-token) window, each expected window has a
    ten-token input and predicts exactly one token.
    """
    tokens = list(range(34))
    expected = [([-100] + list(range(0, 9)), list(range(0, 10)))]
    # Windows slide by one: input [i, i + 10) predicts the single token i + 10,
    # for i = 0..23 (the final prediction is token 33).
    expected += [(list(range(i, i + 10)), [i + 10]) for i in range(24)]

    windows = [
        (input_tokens, pred_tokens)
        for input_tokens, pred_tokens in get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=10,
            context_len=10,
        )
    ]
    total_predicted = sum(len(pred_tokens) for _, pred_tokens in windows)

    assert total_predicted == len(tokens)
    assert expected == windows
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v4():
    """Check rolling windows for 30 tokens, max_seq_len=10, context_len=10.

    After the first (prefix-token) window, each expected window has a
    ten-token input and predicts exactly one token.
    """
    tokens = list(range(30))
    expected = [([-100] + list(range(0, 9)), list(range(0, 10)))]
    # Windows slide by one: input [i, i + 10) predicts the single token i + 10,
    # for i = 0..19 (the final prediction is token 29).
    expected += [(list(range(i, i + 10)), [i + 10]) for i in range(20)]

    windows = [
        (input_tokens, pred_tokens)
        for input_tokens, pred_tokens in get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=10,
            context_len=10,
        )
    ]
    total_predicted = sum(len(pred_tokens) for _, pred_tokens in windows)

    assert total_predicted == len(tokens)
    assert expected == windows
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v5():
    """Check rolling windows for 30 tokens, max_seq_len=10, context_len=1.

    The input length is a multiple of the window size, so the expected
    output is three windows that each predict ten tokens.
    """
    tokens = list(range(30))
    expected = [
        # First window starts with the prefix token.
        ([-100] + list(range(0, 9)), list(range(0, 10))),
        (list(range(9, 19)), list(range(10, 20))),
        (list(range(19, 29)), list(range(20, 30))),
    ]

    windows = [
        (input_tokens, pred_tokens)
        for input_tokens, pred_tokens in get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=10,
            context_len=1,
        )
    ]
    total_predicted = sum(len(pred_tokens) for _, pred_tokens in windows)

    assert total_predicted == len(tokens)
    assert expected == windows
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v6():
    """Check rolling windows for 9 tokens, max_seq_len=2, context_len=1.

    The expected output is four two-token windows followed by a final window
    that predicts only the single remaining token.
    """
    tokens = list(range(9))
    expected = [
        ([-100, 0], [0, 1]),
        ([1, 2], [2, 3]),
        ([3, 4], [4, 5]),
        ([5, 6], [6, 7]),
        # Final window keeps a two-token input but predicts one token.
        ([6, 7], [8]),
    ]

    windows = [
        (input_tokens, pred_tokens)
        for input_tokens, pred_tokens in get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=2,
            context_len=1,
        )
    ]
    total_predicted = sum(len(pred_tokens) for _, pred_tokens in windows)

    assert total_predicted == len(tokens)
    assert expected == windows
def test_get_rolling_token_windows_empty():
    """Check that an empty token list yields no windows at all."""
    generator = get_rolling_token_windows(
        token_list=[],
        prefix_token=-100,
        max_seq_len=2,
        context_len=1,
    )
    # Count the yielded items without inspecting them.
    window_count = sum(1 for _ in generator)
    assert window_count == 0
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment