gaoqiong / lm-evaluation-harness
Commit 8f5b2295 (unverified), authored Dec 16, 2023 by Baber Abbasi, committed via GitHub on Dec 16, 2023.

openai nits (#1139)

* fixed syntactic nits
* fix temperature and seed
* fix logprobs
* fixup merge
Parent: f7c67f0e

Showing 1 changed file with 29 additions and 18 deletions (+29 / -18):

lm_eval/models/openai_completions.py
```diff
 import os
 import time
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 import copy
 from collections import defaultdict
```
```diff
@@ -11,7 +11,7 @@ from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 
 
-def get_result(response: dict, ctxlen: int) -> Tuple[float, bool]:
+def get_result(response, ctxlen: int) -> Tuple[float, bool]:
     """Process results from OpenAI API response.
 
     :param response: dict
```
```diff
@@ -25,12 +25,12 @@ def get_result(response: dict, ctxlen: int) -> Tuple[float, bool]:
         whether argmax matches given continuation exactly
     """
     is_greedy = True
-    logprobs = response["logprobs"]["token_logprobs"]
+    logprobs = response.logprobs.token_logprobs
     continuation_logprobs = sum(logprobs[ctxlen:])
 
-    for i in range(ctxlen, len(response["logprobs"]["tokens"])):
-        token = response["logprobs"]["tokens"][i]
-        top_tokens = response["logprobs"]["top_logprobs"][i]
+    for i in range(ctxlen, len(response.logprobs.token_logprobs)):
+        token = response.logprobs.token_logprobs[i]
+        top_tokens = response.logprobs.top_logprobs[i]
         top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
         if top_token != token:
             is_greedy = False
```
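The change in `get_result` tracks the OpenAI Python SDK v1 migration: completion choices are now typed objects, so fields are read with attribute access rather than dict subscripting. A minimal sketch of the v1-style access pattern; the client setup, prompt, and `ctxlen` value are illustrative, and `text-davinci-003` is simply the class default at this commit:

```python
# Sketch only: v1-style attribute access on a completions response.
# Assumes the openai>=1.0 SDK; model, prompt, and ctxlen are illustrative.
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
response = client.completions.create(
    model="text-davinci-003",  # the class default at this commit (since deprecated)
    prompt="The quick brown fox",
    max_tokens=0,
    echo=True,      # ask for logprobs of the prompt tokens themselves
    logprobs=5,
    temperature=0.0,
)
choice = response.choices[0]
# old style: choice["logprobs"]["token_logprobs"]
# new style: attribute access on the typed response object
ctxlen = 2  # hypothetical: number of context tokens to exclude from the score
continuation_logprobs = sum(choice.logprobs.token_logprobs[ctxlen:])
```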
```diff
@@ -67,12 +67,16 @@ please install these via `pip install lm-eval[openai]` or `pip install -e .[open
 @register_model("openai-completions")
 class OpenaiCompletionsLM(LM):
     REQ_CHUNK_SIZE = 20
+    _DEFAULT_MAX_LENGTH = 2048
 
     def __init__(
         self,
         model: str = "text-davinci-003",
         truncate: bool = False,
+        max_gen_toks: int = 256,
         batch_size: int = 1,
+        seed: int = 1234,
+        max_length: Optional[int] = None,
     ) -> None:
         """
```
```diff
@@ -82,6 +86,7 @@ class OpenaiCompletionsLM(LM):
         Truncate input if too long (if False and input is too long, throw error)
         """
         super().__init__()
+        self.seed = seed
         try:
             import openai, tiktoken  # noqa: E401
         except ModuleNotFoundError:
```
```diff
@@ -89,14 +94,16 @@ class OpenaiCompletionsLM(LM):
                 "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
 please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
             )
         self.model = model
         self.tokenizer = tiktoken.encoding_for_model(self.model)
         self.vocab_size = self.tokenizer.n_vocab
         self.truncate = truncate
         self.end_of_text_token_id = self.tokenizer.eot_token
+        self._max_gen_toks = max_gen_toks
+        self._max_length = max_length
 
         # Read from environment variable OPENAI_API_SECRET_KEY
-        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
+        openai.api_key = os.environ["OPENAI_API_KEY"]
 
     @property
     def eot_token_id(self):
```
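With this change the key is read from the conventional `OPENAI_API_KEY` variable; setups that still export only `OPENAI_API_SECRET_KEY` will now raise a `KeyError`. A small defensive check, not part of the harness, that makes the failure mode explicit:

```python
# Not from the commit: an explicit guard for the renamed environment variable.
import os

if "OPENAI_API_KEY" not in os.environ:
    raise EnvironmentError(
        "OPENAI_API_KEY is unset; note this commit replaces the old "
        "OPENAI_API_SECRET_KEY name."
    )
```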
```diff
@@ -104,12 +111,14 @@ class OpenaiCompletionsLM(LM):
     @property
     def max_length(self) -> int:
-        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
-        return 2048
+        if self._max_length:
+            return self._max_length
+        else:
+            return self._DEFAULT_MAX_LENGTH
 
     @property
     def max_gen_toks(self) -> int:
-        return 256
+        return self._max_gen_toks
 
     @property
     def batch_size(self):
```
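`max_length` now honors a user-supplied value and only falls back to the class default, and `max_gen_toks` likewise returns the constructor argument. A self-contained toy mirroring the fallback logic (not the harness class itself):

```python
# Toy reproduction of the new max_length fallback behavior.
from typing import Optional

class Toy:
    _DEFAULT_MAX_LENGTH = 2048

    def __init__(self, max_length: Optional[int] = None):
        self._max_length = max_length

    @property
    def max_length(self) -> int:
        # a user-supplied max_length wins; otherwise use the class default
        return self._max_length if self._max_length else self._DEFAULT_MAX_LENGTH

assert Toy().max_length == 2048
assert Toy(max_length=4096).max_length == 4096
```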
```diff
@@ -187,12 +196,13 @@ class OpenaiCompletionsLM(LM):
                 ctxlens.append(ctxlen)
 
             response = oa_completion(
-                engine=self.engine,
+                model=self.model,
                 prompt=inps,
                 echo=True,
                 max_tokens=0,
                 temperature=0.0,
                 logprobs=10,
+                seed=self.seed,
             )
 
             for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
```
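In the scoring path, the deprecated `engine=` parameter gives way to the v1 `model=` parameter, and the new `seed` is threaded through for best-effort reproducibility. A sketch of the resulting call, assuming (as the surrounding diff suggests) that `oa_completion` forwards its keyword arguments to the completions endpoint; the prompt is hypothetical:

```python
# Sketch only: the post-commit scoring call.
from lm_eval.models.openai_completions import oa_completion

response = oa_completion(
    model="text-davinci-003",       # replaces the deprecated engine=...
    prompt=["The quick brown fox"],
    echo=True,
    max_tokens=0,    # generate nothing; only score the echoed prompt
    temperature=0.0,
    logprobs=10,
    seed=1234,       # new: reproducible to the extent the API honors it
)
```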
```diff
@@ -242,21 +252,22 @@ class OpenaiCompletionsLM(LM):
                 inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                 inps.append(inp)
 
-            until = request_args.get("until", ["<|endoftext|>"])
+            until = request_args.pop("until", ["<|endoftext|>"])
+            request_args.pop("do_sample", None)
+            request_args["temperature"] = request_args.get("temperature", 0)
             response = oa_completion(
                 model=self.model,
                 prompt=inps,
                 max_tokens=self.max_gen_toks,
-                temperature=0.0,
-                logprobs=10,
                 stop=until,
+                seed=self.seed,
+                **request_args,
             )
 
             for resp, (context, args_) in zip(response.choices, chunk):
-                s = getattr(resp, 'text')
-                until_ = args_.get("until", ["<|endoftext|>"])
+                s = getattr(resp, "text")
+                until_ = until
                 for term in until_:
                     if len(term) > 0:
```
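The generation path now strips harness-internal keys from `request_args` before forwarding the remainder to the API, and defaults `temperature` to 0 rather than hard-coding it. Note also that `until_` now comes from the chunk-level `until` instead of each request's `args_`. A self-contained sketch of the sanitization step, with a hypothetical `request_args` dict:

```python
# Sketch of the new request_args handling; the dict contents are hypothetical.
request_args = {"until": ["\n\n"], "do_sample": False, "top_p": 0.9}

until = request_args.pop("until", ["<|endoftext|>"])  # harness-internal key
request_args.pop("do_sample", None)                   # not an OpenAI parameter
request_args["temperature"] = request_args.get("temperature", 0)

# Everything left is forwarded verbatim: oa_completion(..., stop=until, **request_args)
assert request_args == {"top_p": 0.9, "temperature": 0}
```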