gaoqiong / lm-evaluation-harness

Commit 30296795, authored Nov 24, 2023 by lintangsutawika

    update

parent 12e92616
Showing 2 changed files with 60 additions and 33 deletions:

    lm_eval/models/openai_completions.py    +56 -30
    tests/tests_master/test_models.py        +4  -3
lm_eval/models/openai_completions.py
 import os
 import time
 from typing import List, Tuple
+import copy
+from collections import defaultdict

 from tqdm import tqdm

 from lm_eval import utils
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model

...
@@ -10,6 +14,7 @@ from openai import OpenAI

 client = OpenAI()


 def oa_chat_completion(**kwargs):
     """Query OpenAI API for chat completion.

...
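The body of `oa_chat_completion` is elided in this diff. For orientation, a minimal sketch of one common shape for such a wrapper (retry with growing backoff) follows; the helper name, initial backoff, and growth factor are assumptions, not the commit's actual code:

# Hypothetical sketch only -- the real body of oa_chat_completion is elided above.
import time

from openai import OpenAI

client = OpenAI()


def oa_chat_completion_sketch(**kwargs):
    """Query the OpenAI chat completions API, retrying on transient errors."""
    backoff_time = 3  # assumed initial backoff (seconds)
    while True:
        try:
            return client.chat.completions.create(**kwargs)
        except Exception:
            # the real code may catch narrower OpenAI error types
            time.sleep(backoff_time)
            backoff_time *= 1.5  # assumed growth factor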
@@ -40,7 +45,7 @@ class OpenaiChatCompletionsLM(LM):
     REQ_CHUNK_SIZE = 20

     def __init__(
-        self, model: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1
+        self, model: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1
     ) -> None:
         """

...
@@ -70,7 +75,6 @@ class OpenaiChatCompletionsLM(LM):
         self.end_of_text_token_id = self.tokenizer.eot_token
-        # Read from environment variable OPENAI_API_SECRET_KEY

     @property
     def eot_token_id(self):

...
@@ -102,7 +106,7 @@ class OpenaiChatCompletionsLM(LM):
         return self.tokenizer.decode(tokens)

     def _encode_pair(
-        self, context: str, continuation: str
+        self, context: str, continuation: str
     ) -> Tuple[List[int], List[int]]:
         n_spaces = len(context) - len(context.rstrip())
         if n_spaces > 0:

...
@@ -115,16 +119,20 @@ class OpenaiChatCompletionsLM(LM):
         return context_enc, continuation_enc

     def generate_until(self, requests) -> List[str]:
-        if not requests:
-            return []
-        res = []
-        requests = [req.args for req in requests]
+        res = defaultdict(list)
+        re_ords = {}

         def _collate(x):
             toks = self.tok_encode(x[0])
-            return len(toks), x[0]
+            return -len(toks), x[0]

-        re_ord = utils.Reorderer(requests, _collate)
+        # we group requests by their generation_kwargs,
+        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
+        # in the same batch.
+        grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
+        for key, reqs in grouper.get_grouped().items():
+            # within each set of reqs for given kwargs, we reorder by token length, descending.
+            re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)

         def sameuntil_chunks(xs, size):
             ret = []

...
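A note on the `_collate` change: flipping `len(toks)` to `-len(toks)` makes the `Reorderer` sort each group's requests by token length in descending order, presumably so the longest contexts are dispatched first and any length-related failures surface early rather than at the end of a run.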
@@ -139,25 +147,41 @@ class OpenaiChatCompletionsLM(LM):
             if ret:
                 yield ret, lastuntil

-        # todo: more intelligent batching for heterogeneous `until`
-        for chunk, request_args in tqdm(
-            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
-        ):
-            inps = []
-            for context, _ in chunk:
-                # context_enc = self.tok_encode(context)
-                # inp = context_enc[-(self.max_length - self.max_gen_toks):]
-                inps.append({"role": "user", "content": context})
-
-            # until = request_args.get("until", ["<|endoftext|>"])
-            until = request_args.get("until", None)
+        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
+        for key, re_ord in re_ords.items():
+            chunks = utils.chunks(re_ord.get_reordered(), n=self.REQ_CHUNK_SIZE)
+            for chunk in chunks:
+                contexts, all_gen_kwargs = zip(*chunk)
+                inps = [{"role": "user", "content": context} for context in contexts]
+
+                gen_kwargs = all_gen_kwargs[0]
+                until = None
+                if isinstance(gen_kwargs, dict):
+                    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                    if "until" in kwargs.keys():
+                        until = kwargs.pop("until")
+                        if isinstance(until, str):
+                            until = [kwargs]
+                        elif not isinstance(until, list):
+                            raise ValueError(
+                                f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
+                            )
+                else:
+                    raise ValueError(
+                        f"Expected `kwargs` to be of type `dict` but got {kwargs}"
+                    )
+
+                if "max_gen_toks" in kwargs.keys():
+                    max_gen_toks = kwargs.pop("max_gen_toks")
+                else:
+                    max_gen_toks = self.max_gen_toks

-            response = oa_chat_completion(
-                messages=inps,
-                model=self.model,
-                frequency_penalty=self.frequency_penalty,
-                # logit_bias=self.logit_bias,
-                max_tokens=self.max_gen_toks,
-                n=self.n,
-                presence_penalty=self.presence_penalty,
-                temperature=self.temperature,
+                response = oa_chat_completion(
+                    messages=inps,
+                    model=self.model,
+                    frequency_penalty=self.frequency_penalty,
+                    # logit_bias=self.logit_bias,
+                    max_tokens=max_gen_toks,
+                    n=self.n,
+                    presence_penalty=self.presence_penalty,
+                    temperature=self.temperature,

...
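One detail worth flagging in the hunk above: the `isinstance(until, str)` branch assigns `until = [kwargs]`, which looks like a typo for `until = [until]` (wrapping the single stop string in a list). A standalone sketch of the sanitization as presumably intended; the helper name `split_gen_kwargs` and the default of 256 tokens are hypothetical, not part of the commit:

import copy


def split_gen_kwargs(gen_kwargs, default_max_gen_toks=256):
    """Separate stop sequences and max_gen_toks from the remaining sampling kwargs."""
    if not isinstance(gen_kwargs, dict):
        raise ValueError(f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}")
    kwargs = copy.deepcopy(gen_kwargs)
    until = kwargs.pop("until", None)
    if isinstance(until, str):
        until = [until]  # a single stop string becomes a one-element list
    elif until is not None and not isinstance(until, list):
        raise ValueError(
            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
        )
    max_gen_toks = kwargs.pop("max_gen_toks", default_max_gen_toks)
    return until, max_gen_toks, kwargs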
@@ -167,21 +191,23 @@ class OpenaiChatCompletionsLM(LM):
-        for resp, (context, args_) in zip(response.choices, chunk):
-            s = resp.message.content
+                for resp, (context, args_) in zip(response.choices, chunk):
+                    s = resp.message.content

-            # until_ = args_.get("until", ["<|endoftext|>"])
-            until_ = args_.get("until", None)
-            if until_ is not None:
-                for term in until_:
-                    s = s.split(term)[0]
+                    if until is not None:
+                        for term in until:
+                            if len(term) > 0:
+                                s = s.split(term)[0]

+                    res[key].append(s)
                     # partial caching
-            self.cache_hook.add_partial(
-                "generate_until", (context, {"until": until_}), s
-            )
-
-            res.append(s)
-        return re_ord.get_original(res)
+                    self.cache_hook.add_partial(
+                        "generate_until", (context, {"until": until}), s
+                    )
+                    pbar.update(1)
+            res[key] = re_ord.get_original(res[key])
+
+        pbar.close()
+        return grouper.get_original(res)

     def loglikelihood(self, requests):
         raise NotImplementedError("No support for logits.")

...
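The refactor's core idea is a two-level reordering: bucket requests by their serialized generation kwargs, process each bucket in length-sorted chunks, then scatter results back to their original positions. A minimal self-contained sketch of that round-trip, with simplified stand-ins for `utils.Grouper` and `utils.Reorderer` (illustration only, not the library's implementations):

from collections import defaultdict


def group_and_restore(requests, key_fn, sort_fn, process_group):
    # 1. bucket requests by key, remembering each request's original index
    groups = defaultdict(list)
    for idx, req in enumerate(requests):
        groups[key_fn(req)].append((idx, req))

    out = [None] * len(requests)
    for key, items in groups.items():
        # 2. within a group, sort (e.g. longest context first) for batching
        items.sort(key=lambda pair: sort_fn(pair[1]))
        results = process_group(key, [req for _, req in items])
        # 3. scatter results back to their original positions
        for (idx, _), result in zip(items, results):
            out[idx] = result
    return out


# Usage: group by generation kwargs, sort by descending context length.
reqs = [("short ctx", {"until": ["\n"]}), ("a much longer context", {"until": ["\n"]})]
res = group_and_restore(
    reqs,
    key_fn=lambda r: str(r[1]),
    sort_fn=lambda r: -len(r[0]),
    process_group=lambda key, rs: [f"resp<{r[0]}>" for r in rs],
)
print(res)  # results come back aligned with the original request order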
tests/tests_master/test_models.py
 import hashlib
 import json
-from openai import OpenAI
-
-client = OpenAI()
 import os
 import pickle
 import pytest

...

@@ -10,6 +7,10 @@ import unittest.mock as mock
 import lm_eval.models as models

+from openai import OpenAI
+
+client = OpenAI()
+

 LOGLIKELIHOOD_TEST_CASES = [
     ("The quick brown fox jumps over the lazy", " dog"),

...
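Constructing `OpenAI()` at module import time means this test file cannot be imported unless `OPENAI_API_KEY` is set (or the API boundary is stubbed first). A hedged sketch of stubbing the call with the `unittest.mock` already imported above; the patch target is an assumption about where `oa_chat_completion` is looked up, and the stubbed content is made up:

import unittest.mock as mock

# Build a fake response shaped like openai's ChatCompletion object.
fake_choice = mock.Mock()
fake_choice.message.content = "stubbed completion"
fake_response = mock.Mock(choices=[fake_choice])

# Assumed patch target: the function defined in lm_eval/models/openai_completions.py.
with mock.patch(
    "lm_eval.models.openai_completions.oa_chat_completion",
    return_value=fake_response,
):
    ...  # exercise OpenaiChatCompletionsLM.generate_until without network calls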