gaoqiong / lm-evaluation-harness · Commits

Commit 12c2ee1e
Authored Dec 23, 2020 by Leo Gao

Merge branch 'master' of github.com:EleutherAI/lm_evaluation_harness

Parents: cf69ba9c, 61ff104e

Showing 4 changed files with 19 additions and 89 deletions:
lm_eval/base.py           +5  -29
lm_eval/models/dummy.py   +0   -3
lm_eval/models/gpt2.py    +5  -20
lm_eval/models/gpt3.py    +9  -37
lm_eval/base.py

```diff
@@ -3,44 +3,20 @@ import random
 class LM(abc.ABC):
     @abc.abstractmethod
-    def generate(self, context, max_gen_length):
-        """Conditional text generation with an LM
-
-        :param context: str
-            Context string for conditional generation
-        :param max_gen_length: int
-            Maximum number of tokens to generate
-        :return: str
-        """
-        pass
-
-    @abc.abstractmethod
     def loglikelihood(self, context, continuation):
-        """Compute log-likelihood of a generation a continuation from a context
-        Assume that the final text will simple be
-        context + continuation
+        """Compute log-likelihood of generating a continuation from a context

         :param context: str
-            Context string for conditional generation
+            Context string
         :param continuation: str
-            Maximum number of tokens to generate
+            The continuation over which log likelihood will be calculated. If
+            there is a word boundary, the space should be in the continuation.
+            For example, context="hello" continuation=" world" is correct.
         :return: float
         """
         pass

-    @classmethod
-    def num_tokens(cls, string):
-        """Return the number of tokens in a string, based on tokenization
-
-        :param string: str
-            Input string
-        :return: int
-        """
-        pass
-
     @classmethod
     def create_from_arg_string(cls, arg_string):
         """Constructor method, in case models need additional arguments
```
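For reference, the convention the new docstring pins down can be exercised against the `DummyLM` defined in `lm_eval/models/dummy.py`; a minimal sketch (the example strings are illustrative, not from the commit):

```python
# Minimal check of the LM.loglikelihood contract, using DummyLM from
# lm_eval/models/dummy.py, which returns 0.0 for any input.
from lm_eval.models.dummy import DummyLM

lm = DummyLM()
# Per the docstring, the word-boundary space belongs to the continuation:
# context="hello", continuation=" world" -- not "hello " / "world".
print(lm.loglikelihood("hello", " world"))  # -> 0.0
```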
lm_eval/models/dummy.py

```diff
@@ -5,8 +5,5 @@ from . import MODEL_REGISTRY
 @MODEL_REGISTRY.register("dummy")
 class DummyLM(LM):
-    def generate(self, context, max_gen_length):
-        return "lol"
-
     def loglikelihood(self, context, continuation):
         return 0.0
```
lm_eval/models/gpt2.py

```diff
@@ -10,36 +10,21 @@ class GPT2LM(LM):
         self.device = torch.device(device)
         self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
         self.gpt2.eval()
-        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
+        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

     @classmethod
     def create_from_arg_string(cls, arg_string):
         args = utils.simple_parse_args_string(arg_string)
         return cls(device=args.get("device", "cpu"))

-    def generate(self, context, max_gen_length, truncate=True):
-        # when too long to fit in context, truncate from the left
-        context_tensor = torch.tensor([self.tokenizer.encode(context.strip())[max_gen_length - 1024:]], dtype=torch.long).to(self.device)
-        res = self.gpt2.generate(
-            context_tensor,
-            # TODO: change to have until rather than using eos_token_id
-            eos_token_id=self.tokenizer.eos_token_id,
-            do_sample=False,
-            max_length=self.num_tokens(context) + max_gen_length,
-        )
-
-        # chop off the prompt and the final eos token
-        return self.tokenizer.decode(res[0][min(1024 - max_gen_length, len(context_tensor[0])):-1]).strip()
-
     def loglikelihood(self, context, continuation, truncate=True):
         # when too long to fit in context, truncate from the left
-        inp = torch.tensor([self.tokenizer.encode(context + continuation)[-1024:]], dtype=torch.long).to(self.device)
-        ctxlen = len(self.tokenizer.encode(context.strip()))
+        context_enc = self.tokenizer.encode(context)
+        continuation_enc = self.tokenizer.encode(continuation)
+        inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
+        ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)

         cont_toks = inp[:, ctxlen:]  # [batch, seq]
         logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]

         return torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
-
-    def num_tokens(self, string):
-        return len(self.tokenizer.tokenize(string))
```
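Two pieces of the rewritten `loglikelihood` are worth unpacking: the `ctxlen` expression trims any overflow beyond GPT-2's 1024-token window from the context side, never the continuation, and the `torch.gather` call picks each continuation token's log-probability out of the `[batch, seq, vocab]` logits. A small self-contained sketch with illustrative numbers (not from the commit):

```python
import torch
import torch.nn.functional as F

# Worked check of the ctxlen arithmetic: with 1020 context tokens and
# 10 continuation tokens, 6 tokens overflow the 1024 window and are
# dropped from the context.
len_ctx, len_cont = 1020, 10  # hypothetical token counts
ctxlen = len_ctx - max(0, len_ctx + len_cont - 1024)
assert ctxlen == 1014 and ctxlen + len_cont == 1024

# Toy version of the gather step: select each target token's
# log-probability from a [batch, seq, vocab] tensor.
logits = F.log_softmax(torch.randn(1, 3, 5), dim=-1)  # fake vocab of 5
cont_toks = torch.tensor([[2, 0, 4]])                 # fake token ids
per_token = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
print(per_token.shape)  # torch.Size([1, 3]): one log-prob per token
```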
lm_eval/models/gpt3.py

```diff
@@ -18,7 +18,7 @@ class GPT3LM(LM):
         """
         import openai
         self.engine = engine
-        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
+        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
         self.truncate = truncate

         # Read from environment variable OPENAI_API_SECRET_KEY
@@ -29,49 +29,21 @@ class GPT3LM(LM):
         args = utils.simple_parse_args_string(arg_string)
         return cls(engine=args.get("engine", "davinci"))

-    def generate(self, context, max_gen_length):
-        import openai
-
-        if self.truncate:
-            prompt = self.smart_truncate(context, buffer=max_gen_length)
-        else:
-            prompt = context
-
-        response = openai.Completion.create(
-            engine=self.engine,
-            prompt=prompt,
-            max_tokens=max_gen_length,
-            temperature=0.0,
-        )
-
-        return response.choices[0]["text"]
-
     def loglikelihood(self, context, continuation):
         import openai
-        full_text = context + continuation
-        full_text_length = len(self.tokenizer.tokenize(full_text))
-        context_length = len(self.tokenizer.tokenize(context))
-        continuation_length = len(self.tokenizer.tokenize(continuation))
-        assert full_text_length == context_length + continuation_length
-
-        if self.truncate:
-            prompt = self.smart_truncate(full_text, buffer=0)
-        else:
-            prompt = full_text
+        context_enc = self.tokenizer.encode(context)
+        continuation_enc = self.tokenizer.encode(continuation)
+        inp = (context_enc + continuation_enc)[-1024:]
+        ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)

         response = openai.Completion.create(
             engine=self.engine,
-            prompt=prompt,
+            prompt=inp,
             echo=True,
             max_tokens=0,
             temperature=0.0,
             logprobs=0,
         )

         logprobs = response.choices[0]["logprobs"]["token_logprobs"]
-        continuation_logprobs = logprobs[-continuation_length:]
+        continuation_logprobs = logprobs[ctxlen:]

         return sum(continuation_logprobs)

-    def smart_truncate(self, string, buffer=1):
-        tokens = self.tokenizer.tokenize(string)
-        available_length = self.MAX_LENGTH - 1 - buffer  # OpenAI adds 1 token
-        kept_tokens = tokens[-available_length:]
-        new_string = self.tokenizer.convert_tokens_to_string(kept_tokens)
-        return new_string
-
-    def num_tokens(self, string):
-        return len(self.tokenizer.tokenize(string))
```
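The rewritten `loglikelihood` leans on a scoring trick in the completions API this file targets: with `echo=True` and `max_tokens=0`, the API generates nothing and instead returns `token_logprobs` for the prompt tokens themselves, and passing token ids as the prompt keeps the slice arithmetic aligned with the local tokenizer. A minimal sketch, assuming the pre-1.0 `openai` client used here, a valid key in `OPENAI_API_SECRET_KEY`, and illustrative token ids:

```python
import os
import openai  # pre-1.0 client, as used in this file

openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

response = openai.Completion.create(
    engine="davinci",
    prompt=[31373, 995],  # illustrative GPT-2 token ids, not from the commit
    echo=True,            # score the prompt instead of generating
    max_tokens=0,         # produce zero new tokens
    temperature=0.0,
    logprobs=0,
)
token_logprobs = response.choices[0]["logprobs"]["token_logprobs"]
# The first entry is None (the first token has nothing to condition on);
# loglikelihood above sums the slice from ctxlen onward.
print(sum(lp for lp in token_logprobs[1:] if lp is not None))
```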