gaoqiong / lm-evaluation-harness · Commits

Commit aab91285
authored Feb 09, 2021 by Leo Gao

Update gpt2 for efficiency and allow specifying model size
parent 4d8ed7d5
Showing 1 changed file with 29 additions and 28 deletions
lm_eval/models/gpt2.py (+29, -28)

@@ -7,44 +7,45 @@ from tqdm import tqdm
 class GPT2LM(LM):
-    def __init__(self, device="cpu"):
+    def __init__(self, device="cpu", pretrained='gpt2'):
         self.device = torch.device(device)
-        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
+        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained(pretrained).to(self.device)
         self.gpt2.eval()
-        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
+        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained(pretrained)
         self.tokenizer.pad_token = "<|endoftext|>"
 
     @classmethod
     def create_from_arg_string(cls, arg_string):
         args = utils.simple_parse_args_string(arg_string)
-        return cls(device=args.get("device", "cpu"))
+        return cls(device=args.get("device", "cpu"), pretrained=args.get("pretrained", "gpt2"))
 
     def loglikelihood(self, requests):
         res = []
-        # TODO: vectorize properly
-        for context, continuation in tqdm(requests):
-            # when too long to fit in context, truncate from the left
-            if context == "":
-                # end of text as context
-                context_enc = [50256]
-            else:
-                context_enc = self.tokenizer.encode(context)
-            continuation_enc = self.tokenizer.encode(continuation)
-            inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
-            ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
-            cont_toks = inp[:, ctxlen:]  # [batch, seq]
-            logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
-            greedy_tokens = logits.argmax(dim=-1)
-            max_equal = (greedy_tokens == cont_toks).all()
-            logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]
-            res.append((float(logits.sum()), bool(max_equal)))
+        with torch.no_grad():
+            # TODO: vectorize properly
+            for context, continuation in tqdm(requests):
+                # when too long to fit in context, truncate from the left
+                if context == "":
+                    # end of text as context
+                    context_enc = [50256]
+                else:
+                    context_enc = self.tokenizer.encode(context)
+                continuation_enc = self.tokenizer.encode(continuation)
+                inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
+                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
+                cont_toks = inp[:, ctxlen:]  # [batch, seq]
+                logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
+                greedy_tokens = logits.argmax(dim=-1)
+                max_equal = (greedy_tokens == cont_toks).all()
+                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]
+                res.append((float(logits.sum()), bool(max_equal)))
         return res
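Note on the "allow specifying model size" half of the change: the new pretrained argument is forwarded to both GPT2LMHeadModel.from_pretrained and GPT2TokenizerFast.from_pretrained, so any GPT-2 checkpoint name that transformers accepts can replace the default 'gpt2'. A minimal usage sketch follows; the checkpoint name gpt2-xl and the cuda device are illustrative choices, not taken from the commit, and the comma-separated key=value arg-string format is an assumption about utils.simple_parse_args_string.

    # Direct construction with a larger checkpoint (illustrative name):
    lm = GPT2LM(device="cuda", pretrained="gpt2-xl")

    # Equivalent construction through the arg-string factory, which parses
    # the string into a dict before forwarding the values to __init__
    # (assumed comma-separated key=value format):
    lm = GPT2LM.create_from_arg_string("device=cuda,pretrained=gpt2-xl")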
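Note on the "efficiency" half: the scoring loop now runs under torch.no_grad(). Since loglikelihood only performs forward passes and never calls backward(), recording the autograd graph would hold intermediate activations in memory for no benefit; disabling it reduces memory use and overhead. A self-contained sketch of the same pattern, with a toy module standing in for the GPT-2 forward pass:

    import torch

    model = torch.nn.Linear(8, 8)   # stand-in for self.gpt2
    inp = torch.randn(1, 8)

    with torch.no_grad():           # no autograd graph is recorded here
        out = model(inp)

    assert not out.requires_grad    # outputs carry no gradient history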
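The left-truncation bookkeeping is the subtle part of loglikelihood: the concatenated token ids are clipped to the final 1024 positions (GPT-2's context window), and ctxlen is reduced by however many tokens fell off the left edge, so that inp[:, ctxlen:] still selects exactly the continuation. The logits slice [:, ctxlen - 1:-1] then keeps the positions whose next-token predictions line up with those continuation tokens. A worked check with made-up lengths:

    # Suppose the context encodes to 1040 tokens and the continuation to 10.
    len_ctx, len_cont = 1040, 10

    overflow = max(0, len_ctx + len_cont - 1024)   # 26 tokens clipped on the left
    ctxlen = len_ctx - overflow                    # 1040 - 26 = 1014

    # The clipped input has 1024 tokens; the last 1024 - ctxlen of them
    # are exactly the continuation tokens being scored.
    assert 1024 - ctxlen == len_cont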