gaoqiong / lm-evaluation-harness · Commit e8f9dc71

Implement GPT2 greedy_until

Authored Feb 10, 2021 by Leo Gao
Parent: 9adf18b1

Showing 4 changed files, with 40 additions and 6 deletions:

  lm_eval/base.py         +3  -3
  lm_eval/models/gpt2.py  +29 -2
  lm_eval/models/gpt3.py  +1  -0
  tests/test_models.py    +7  -1
lm_eval/base.py

@@ -38,9 +38,9 @@ class LM(abc.ABC):
             A list of pairs (context, until)
             context: str
                 Context string
-            until: str
-                The string sequence to generate until. This string sequence may
-                span across multiple tokens, or may be part of one token.
+            until: [str]
+                The string sequences to generate until. These string sequences
+                may each span across multiple tokens, or may be part of one token.
         :return: list
             A list of strings continuation
             continuation: str
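The docstring change turns `until` from a single string into a list of stop sequences, so each request can carry several stopping strings. A minimal sketch of the request shape this now describes, assuming `lm` is some concrete `LM` subclass (the prompts and stop lists below are illustrative only, not from the repository):

    # Each request pairs a context string with a list of stop sequences;
    # greedy_until returns one continuation string per request.
    requests = [
        ("Q: What is 2 + 2?\nA:", ["\n"]),            # stop at the first newline
        ("def add(a, b):", ["\ndef ", "\nclass "]),   # stop at the next top-level definition
    ]
    continuations = lm.greedy_until(requests)
    assert len(continuations) == len(requests)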
lm_eval/models/gpt2.py

@@ -7,6 +7,8 @@ from tqdm import tqdm

 class GPT2LM(LM):
+    MAX_GEN_TOKS = 256
+
     def __init__(self, device="cpu", pretrained='gpt2'):
         self.device = torch.device(device)
         self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained(pretrained).to(self.device)

@@ -23,6 +25,7 @@ class GPT2LM(LM):
         res = []
         with torch.no_grad():
             # TODO: vectorize properly
+            # TODO: automatic batch size detection for vectorization
             for context, continuation in tqdm(requests):
                 # when too long to fit in context, truncate from the left

@@ -50,5 +53,29 @@ class GPT2LM(LM):
         return res

     def greedy_until(self, requests):
-        # TODO: implement
-        pass
+        # TODO: implement fully general `until` that handles untils that are
+        # multiple tokens or that span multiple tokens correctly
+        res = []
+
+        for context, until in tqdm(requests):
+            if isinstance(until, str): until = [until]
+
+            context_enc = torch.tensor([self.tokenizer.encode(context)]).to(self.device)
+
+            primary_until, = self.tokenizer.encode(until[0])
+
+            cont = self.gpt2.generate(
+                context_enc,
+                max_length=self.MAX_GEN_TOKS,
+                eos_token_id=primary_until,
+                do_sample=False
+            )
+
+            s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])
+
+            for term in until:
+                s = s.split(term)[0]
+
+            res.append(s)
+
+        return res
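The new GPT2LM.greedy_until only wires the first stop sequence into generation: `primary_until, = self.tokenizer.encode(until[0])` unpacks a single token id (and will raise if the first stop string encodes to more than one token) and passes it to generate() as `eos_token_id`, while every stop sequence, including multi-token ones, is applied afterwards by splitting the decoded continuation. A small illustration of that post-hoc truncation loop (the strings here are made up, not from the repository):

    # Keep only the text before the first occurrence of any stop sequence.
    s = "4\nQ: What is 3 + 3?\nA: 6"
    for term in ["\n", "Q:"]:      # the `until` list
        s = s.split(term)[0]
    assert s == "4"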
lm_eval/models/gpt3.py

@@ -113,6 +113,7 @@ class GPT3LM(LM):
                 max_tokens=self.MAX_GEN_TOKS,
                 temperature=0.,
                 logprobs=10,
+                stop=until
             )

             res.append(response.choices[0]['text'])
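For the GPT-3 backend the same behaviour is delegated to the API: the request now forwards `until` as the `stop` parameter, so the service truncates the completion server-side before it is returned. A hedged sketch of the kind of call this hunk sits inside, assuming the legacy `openai.Completion.create` interface (the engine name and surrounding variables are assumptions, not taken from the diff):

    import openai

    # The legacy completions API accepts `stop` as a list of up to four strings;
    # generation halts before the first one that would appear.
    response = openai.Completion.create(
        engine="davinci",          # assumed engine name
        prompt=context,
        max_tokens=256,            # stands in for self.MAX_GEN_TOKS
        temperature=0.,
        logprobs=10,
        stop=until,                # the line added in this commit
    )
    text = response.choices[0]['text']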
tests/test_models.py

@@ -12,4 +12,10 @@ def test_gpt2():
     assert not ig_cat

     # test empty context
-    gpt2.loglikelihood([('', 'test')])
\ No newline at end of file
+    gpt2.loglikelihood([('', 'test')])
+
+    gen, = gpt2.greedy_until([
+        ('The quick brown fox jumps over the lazy', ['.', '\n'])
+    ])
+
+    assert gen == ', lazy fox and they both fall to the ground'
\ No newline at end of file
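The new test can assert an exact continuation because generation runs with do_sample=False: greedy decoding from a fixed checkpoint is deterministic, so the base gpt2 model reliably yields ', lazy fox and they both fall to the ground' for this prompt once the output is truncated at the first '.' or newline. A hedged standalone sketch of that determinism (not part of the repository's tests):

    import torch, transformers

    tok = transformers.GPT2Tokenizer.from_pretrained('gpt2')
    model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
    ids = torch.tensor([tok.encode('The quick brown fox jumps over the lazy')])
    a = model.generate(ids, max_length=40, do_sample=False)
    b = model.generate(ids, max_length=40, do_sample=False)
    assert torch.equal(a, b)   # greedy decoding is reproducible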