Commit f66730c4 authored by lintangsutawika's avatar lintangsutawika
Browse files

fixed how messages are sent to chat completions

parent a2fd682d
...@@ -10,9 +10,8 @@ from lm_eval import utils ...@@ -10,9 +10,8 @@ from lm_eval import utils
from lm_eval.api.model import LM from lm_eval.api.model import LM
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
from openai import OpenAI import asyncio
from openai import OpenAI, AsyncOpenAI
client = OpenAI()
def get_result(response: dict, ctxlen: int) -> Tuple[float, bool]: def get_result(response: dict, ctxlen: int) -> Tuple[float, bool]:
...@@ -314,7 +313,7 @@ class OpenaiCompletionsLM(LM): ...@@ -314,7 +313,7 @@ class OpenaiCompletionsLM(LM):
return loglikelihoods return loglikelihoods
def oa_chat_completion(**kwargs): def oa_chat_completion(client, **kwargs):
"""Query OpenAI API for chat completion. """Query OpenAI API for chat completion.
Retry with back-off until they respond Retry with back-off until they respond
...@@ -327,6 +326,10 @@ def oa_chat_completion(**kwargs): ...@@ -327,6 +326,10 @@ def oa_chat_completion(**kwargs):
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`", please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
) )
async def _get_completions(**kwargs):
    """Issue a single chat-completion request and await the API response.

    Thin async pass-through: forwards all keyword arguments unchanged to
    ``client.chat.completions.create`` (``client`` is the OpenAI client
    from the enclosing scope) and returns the awaited response object.
    """
    return await client.chat.completions.create(**kwargs)
backoff_time = 3 backoff_time = 3
while True: while True:
try: try:
...@@ -341,7 +344,6 @@ please install these via `pip install lm-eval[openai]` or `pip install -e .[open ...@@ -341,7 +344,6 @@ please install these via `pip install lm-eval[openai]` or `pip install -e .[open
@register_model("openai-chat-completions") @register_model("openai-chat-completions")
class OpenaiChatCompletionsLM(LM): class OpenaiChatCompletionsLM(LM):
REQ_CHUNK_SIZE = 20
def __init__( def __init__(
self, model: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1 self, model: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1
...@@ -373,7 +375,8 @@ class OpenaiChatCompletionsLM(LM): ...@@ -373,7 +375,8 @@ class OpenaiChatCompletionsLM(LM):
self.truncate = truncate self.truncate = truncate
self.end_of_text_token_id = self.tokenizer.eot_token self.end_of_text_token_id = self.tokenizer.eot_token
# Read from environment variable OPENAI_API_SECRET_KEY # Read from environment variable OPENAI_API_KEY
self.client = OpenAI() # AsyncOpenAI()
@property @property
def eot_token_id(self): def eot_token_id(self):
...@@ -448,7 +451,10 @@ class OpenaiChatCompletionsLM(LM): ...@@ -448,7 +451,10 @@ class OpenaiChatCompletionsLM(LM):
pbar = tqdm(total=len(requests), disable=(self.rank != 0)) pbar = tqdm(total=len(requests), disable=(self.rank != 0))
for key, re_ord in re_ords.items(): for key, re_ord in re_ords.items():
chunks = utils.chunks(re_ord.get_reordered(), n=self.REQ_CHUNK_SIZE) # n needs to be 1 because messages in
# chat completion are not batch but
# is regarded as a single conversation.
chunks = utils.chunks(re_ord.get_reordered(), n=1)
for chunk in chunks: for chunk in chunks:
contexts, all_gen_kwargs = zip(*chunk) contexts, all_gen_kwargs = zip(*chunk)
inps = [{"role": "user", "content": context} for context in contexts] inps = [{"role": "user", "content": context} for context in contexts]
...@@ -476,6 +482,7 @@ class OpenaiChatCompletionsLM(LM): ...@@ -476,6 +482,7 @@ class OpenaiChatCompletionsLM(LM):
max_gen_toks = self.max_gen_toks max_gen_toks = self.max_gen_toks
response = oa_chat_completion( response = oa_chat_completion(
client=self.client,
messages=inps, messages=inps,
model=self.model, model=self.model,
frequency_penalty=self.frequency_penalty, frequency_penalty=self.frequency_penalty,
...@@ -501,7 +508,7 @@ class OpenaiChatCompletionsLM(LM): ...@@ -501,7 +508,7 @@ class OpenaiChatCompletionsLM(LM):
"generate_until", (context, {"until": until}), s "generate_until", (context, {"until": until}), s
) )
pbar.update(1) pbar.update(1)
# reorder this group of results back to original unsorted form
res[key] = re_ord.get_original(res[key]) res[key] = re_ord.get_original(res[key])
pbar.close() pbar.close()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment