gaoqiong / lm-evaluation-harness · Commits

Commit f66730c4, authored Nov 27, 2023 by lintangsutawika

fixed how messages are sent to chat completions

Parent: a2fd682d
Showing 1 changed file with 58 additions and 51 deletions:

lm_eval/models/openai_completions.py (+58, −51)
lm_eval/models/openai_completions.py @ f66730c4
@@ -10,9 +10,8 @@ from lm_eval import utils
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
-from openai import OpenAI
-client = OpenAI()
+import asyncio
+from openai import OpenAI, AsyncOpenAI


 def get_result(response: dict, ctxlen: int) -> Tuple[float, bool]:
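The import change replaces a module-level client = OpenAI() (which ran, and demanded credentials, at import time) with plain imports of both client classes, leaving client construction to the code that needs it. A minimal sketch of what that enables, assuming the openai v1 SDK:

    # Sketch only, not part of the diff: with no module-level client,
    # importing the module no longer requires OPENAI_API_KEY to be set.
    from openai import OpenAI, AsyncOpenAI

    sync_client = OpenAI()        # reads OPENAI_API_KEY from the environment
    async_client = AsyncOpenAI()  # same, for asyncio-based callers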
@@ -314,7 +313,7 @@ class OpenaiCompletionsLM(LM):
         return loglikelihoods


-def oa_chat_completion(**kwargs):
+def oa_chat_completion(client, **kwargs):
     """Query OpenAI API for chat completion.

     Retry with back-off until they respond
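Threading the client through as an argument, rather than reading a global, lets each model instance own its own configured client. A hedged usage sketch (the model name and prompt are placeholders, not from the commit):

    from openai import OpenAI

    client = OpenAI()
    response = oa_chat_completion(
        client=client,
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello."}],
    )
    print(response.choices[0].message.content)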
@@ -327,6 +326,10 @@ def oa_chat_completion(**kwargs):
 please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
         )

+    async def _get_completions(**kwargs):
+        chat_completions = await client.chat.completions.create(**kwargs)
+        return chat_completions
+
     backoff_time = 3
     while True:
         try:
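The new _get_completions coroutine closes over the client argument, and the surrounding backoff_time = 3 / while True loop is the retry-with-back-off the docstring promises. A self-contained sketch of that pattern, assuming the v1 SDK's openai.OpenAIError base exception (the harness's actual exception handling may differ):

    import time

    import openai
    from openai import OpenAI

    def chat_with_backoff(client: OpenAI, **kwargs):
        """Retry chat.completions.create with exponential back-off."""
        backoff_time = 3
        while True:
            try:
                return client.chat.completions.create(**kwargs)
            except openai.OpenAIError:
                time.sleep(backoff_time)
                backoff_time *= 1.5  # grow the wait between retries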
@@ -341,7 +344,6 @@ please install these via `pip install lm-eval[openai]` or `pip install -e .[open
 @register_model("openai-chat-completions")
 class OpenaiChatCompletionsLM(LM):
-    REQ_CHUNK_SIZE = 20

     def __init__(
         self,
         model: str = "gpt-3.5-turbo",
         truncate: bool = False,
         batch_size: int = 1,
@@ -373,7 +375,8 @@ class OpenaiChatCompletionsLM(LM):
         self.truncate = truncate
         self.end_of_text_token_id = self.tokenizer.eot_token

-        # Read from environment variable OPENAI_API_SECRET_KEY
+        # Read from environment variable OPENAI_API_KEY
+        self.client = OpenAI()  # AsyncOpenAI()

     @property
     def eot_token_id(self):
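The updated comment matches the v1 SDK's behavior: OpenAI() reads OPENAI_API_KEY from the environment itself, so no explicit key assignment is needed, and the stale OPENAI_API_SECRET_KEY reference goes away. A small sketch of the precondition this creates for callers:

    import os

    from openai import OpenAI

    # Client construction fails if no key is available, so it can be
    # friendlier to check the environment up front.
    assert "OPENAI_API_KEY" in os.environ, "set OPENAI_API_KEY before running"
    client = OpenAI()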
@@ -448,60 +451,64 @@ class OpenaiChatCompletionsLM(LM):
         pbar = tqdm(total=len(requests), disable=(self.rank != 0))

         for key, re_ord in re_ords.items():
-            chunks = utils.chunks(re_ord.get_reordered(), n=self.REQ_CHUNK_SIZE)
+            # n needs to be 1 because messages in
+            # chat completion are not batch but
+            # is regarded as a single conversation.
+            chunks = utils.chunks(re_ord.get_reordered(), n=1)
             for chunk in chunks:
                 contexts, all_gen_kwargs = zip(*chunk)
                 inps = [{"role": "user", "content": context} for context in contexts]

                 gen_kwargs = all_gen_kwargs[0]
                 until = None
                 if isinstance(gen_kwargs, dict):
                     kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                     if "until" in kwargs.keys():
                         until = kwargs.pop("until")
                         if isinstance(until, str):
                             until = [kwargs]
                         elif not isinstance(until, list):
                             raise ValueError(
                                 f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                             )
                 else:
                     raise ValueError(
                         f"Expected `kwargs` to be of type `dict` but got {kwargs}"
                     )

                 if "max_gen_toks" in kwargs.keys():
                     max_gen_toks = kwargs.pop("max_gen_toks")
                 else:
                     max_gen_toks = self.max_gen_toks

                 response = oa_chat_completion(
+                    client=self.client,
                     messages=inps,
                     model=self.model,
                     frequency_penalty=self.frequency_penalty,
                     # logit_bias=self.logit_bias,
                     max_tokens=max_gen_toks,
                     n=self.n,
                     presence_penalty=self.presence_penalty,
                     temperature=self.temperature,
                     top_p=self.top_p,
                 )

                 for resp, (context, args_) in zip(response.choices, chunk):
                     s = resp.message.content
-                    res[key].append(s)

                     if until is not None:
                         for term in until:
                             if len(term) > 0:
                                 s = s.split(term)[0]

+                    res[key].append(s)
+
                     self.cache_hook.add_partial(
                         "generate_until", (context, {"until": until}), s
                     )
                     pbar.update(1)

             # reorder this group of results back to original unsorted form
             res[key] = re_ord.get_original(res[key])

         pbar.close()
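The n=1 chunking reflects how the chat endpoint works: a /chat/completions request carries one conversation in messages, not a batch of independent prompts, so each context must become its own request. A sketch of the resulting payload shape (the contexts are made-up examples, and the API call is commented out so the snippet runs without credentials):

    contexts = ["What is 2+2?", "Name the capital of France."]

    for context in contexts:
        # one request per context; the messages list is a single conversation
        messages = [{"role": "user", "content": context}]
        # response = oa_chat_completion(client=client, model="gpt-3.5-turbo",
        #                               messages=messages)
        print(messages)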