Commit f66730c4 authored by lintangsutawika

fixed how messages are sent to chat completions

parent a2fd682d
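In outline, the commit replaces the module-level OpenAI client with one that is constructed in `OpenaiChatCompletionsLM.__init__` and passed explicitly into `oa_chat_completion`. A hedged before/after sketch (the variable names and the commented-out call are illustrative, not part of the diff):

```python
from openai import OpenAI

# Before (sketch): a module-level client created at import time and used implicitly.
# client = OpenAI()
# def oa_chat_completion(**kwargs):
#     return client.chat.completions.create(**kwargs)

# After (sketch): the caller owns the client and passes it in explicitly.
def oa_chat_completion(client, **kwargs):
    return client.chat.completions.create(**kwargs)

my_client = OpenAI()  # requires OPENAI_API_KEY in the environment
# oa_chat_completion(my_client, model="gpt-3.5-turbo",
#                    messages=[{"role": "user", "content": "hello"}])
```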
@@ -10,9 +10,8 @@ from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from openai import OpenAI
client = OpenAI()
import asyncio
from openai import OpenAI, AsyncOpenAI
def get_result(response: dict, ctxlen: int) -> Tuple[float, bool]:
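The import hunk above adds `asyncio` and `AsyncOpenAI` alongside the synchronous `OpenAI` client, though only the synchronous client is used further down (`self.client = OpenAI()`). A minimal sketch of what an async call path could look like, assuming `OPENAI_API_KEY` is set; `ask_async` is an illustrative helper, not something the diff defines:

```python
import asyncio
from openai import AsyncOpenAI

async def ask_async(prompt: str, model: str = "gpt-3.5-turbo") -> str:
    # Illustrative only: the commit imports AsyncOpenAI but keeps calling the
    # synchronous client; this is the shape an async path would take.
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    resp = await client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
    return resp.choices[0].message.content

# asyncio.run(ask_async("Say hi"))
```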
@@ -314,7 +313,7 @@ class OpenaiCompletionsLM(LM):
return loglikelihoods
def oa_chat_completion(**kwargs):
def oa_chat_completion(client, **kwargs):
"""Query OpenAI API for chat completion.
Retry with back-off until they respond
@@ -327,6 +326,10 @@ def oa_chat_completion(**kwargs):
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
)
async def _get_completions(**kwargs):
chat_completions = await client.chat.completions.create(**kwargs)
return chat_completions
backoff_time = 3
while True:
try:
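The body of the `try`/`except` is cut off at the hunk boundary that follows; the docstring above says "Retry with back-off until they respond", so a minimal sketch of that pattern is given here. The exception type, sleep call, and growth factor are assumptions for illustration, not necessarily what the file uses:

```python
import time
import traceback

import openai


def oa_chat_completion_sketch(client, **kwargs):
    # Sketch: retry the chat completion with growing back-off until it succeeds.
    backoff_time = 3
    while True:
        try:
            return client.chat.completions.create(**kwargs)
        except openai.OpenAIError:  # assumed exception class
            traceback.print_exc()
            time.sleep(backoff_time)
            backoff_time *= 1.5  # assumed growth factor
```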
@@ -341,7 +344,6 @@ please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
@register_model("openai-chat-completions")
class OpenaiChatCompletionsLM(LM):
REQ_CHUNK_SIZE = 20
def __init__(
self, model: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1
@@ -373,7 +375,8 @@ class OpenaiChatCompletionsLM(LM):
self.truncate = truncate
self.end_of_text_token_id = self.tokenizer.eot_token
# Read from environment variable OPENAI_API_SECRET_KEY
# Read from environment variable OPENAI_API_KEY
self.client = OpenAI() # AsyncOpenAI()
@property
def eot_token_id(self):
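The constructor comment now points at `OPENAI_API_KEY` (not `OPENAI_API_SECRET_KEY`), and each `OpenaiChatCompletionsLM` instance builds its own client. A short sketch of making that dependency explicit; the early check is illustrative and not in the diff:

```python
import os

from openai import OpenAI

# Fail early with a clear message if the key the client expects is missing.
if "OPENAI_API_KEY" not in os.environ:
    raise RuntimeError("Set OPENAI_API_KEY to use the openai-chat-completions model.")

client = OpenAI()  # picks up OPENAI_API_KEY automatically
```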
@@ -448,60 +451,64 @@ class OpenaiChatCompletionsLM(LM):
pbar = tqdm(total=len(requests), disable=(self.rank != 0))
for key, re_ord in re_ords.items():
chunks = utils.chunks(re_ord.get_reordered(), n=self.REQ_CHUNK_SIZE)
# n needs to be 1 because the messages in a
# chat completion request are not a batch but
# a single conversation.
chunks = utils.chunks(re_ord.get_reordered(), n=1)
for chunk in chunks:
contexts, all_gen_kwargs = zip(*chunk)
inps = [{"role": "user", "content": context} for context in contexts]
gen_kwargs = all_gen_kwargs[0]
until = None
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {kwargs}"
gen_kwargs = all_gen_kwargs[0]
until = None
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {kwargs}"
)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
response = oa_chat_completion(
client=self.client,
messages=inps,
model=self.model,
frequency_penalty=self.frequency_penalty,
# logit_bias=self.logit_bias,
max_tokens=max_gen_toks,
n=self.n,
presence_penalty=self.presence_penalty,
temperature=self.temperature,
top_p=self.top_p,
)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
response = oa_chat_completion(
messages=inps,
model=self.model,
frequency_penalty=self.frequency_penalty,
# logit_bias=self.logit_bias,
max_tokens=max_gen_toks,
n=self.n,
presence_penalty=self.presence_penalty,
temperature=self.temperature,
top_p=self.top_p,
)
for resp, (context, args_) in zip(response.choices, chunk):
s = resp.message.content
if until is not None:
for term in until:
if len(term) > 0:
s = s.split(term)[0]
for resp, (context, args_) in zip(response.choices, chunk):
s = resp.message.content
res[key].append(s)
if until is not None:
for term in until:
if len(term) > 0:
s = s.split(term)[0]
self.cache_hook.add_partial(
"generate_until", (context, {"until": until}), s
)
pbar.update(1)
res[key].append(s)
self.cache_hook.add_partial(
"generate_until", (context, {"until": until}), s
)
pbar.update(1)
# reorder this group of results back to original unsorted form
res[key] = re_ord.get_original(res[key])
pbar.close()
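Putting the hunk together: each request is now its own chunk (`n=1`), so one context becomes one single-turn conversation, and the reply is truncated at the first matching `until` term. A self-contained sketch of that flow; `generate_one` and its defaults are illustrative rather than part of the diff:

```python
from openai import OpenAI


def generate_one(client, context, until=None, model="gpt-3.5-turbo", max_gen_toks=256):
    # One request == one single-turn conversation, mirroring the n=1 chunking above.
    messages = [{"role": "user", "content": context}]
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_gen_toks,
        temperature=0.0,
    )
    s = response.choices[0].message.content
    # Cut the generation at the first stop sequence, as in the loop above.
    if until is not None:
        stops = [until] if isinstance(until, str) else until
        for term in stops:
            if len(term) > 0:
                s = s.split(term)[0]
    return s

# generate_one(OpenAI(), "Q: What is 2 + 2?\nA:", until=["\n"])
```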