Commit 5ebc7c85 authored by David Okpare

Added support for ChatCompletions

parent 3839125a
@@ -60,6 +60,44 @@ please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`
            backoff_time *= 1.5
def oa_chat_completion(is_async: bool = False, **kwargs):
    """Query OpenAI API for chat completion.

    Retry with back-off until they respond.
    """
    try:
        import openai, tiktoken  # noqa: E401
    except ModuleNotFoundError:
        raise Exception(
            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
        )

    backoff_time = 3
    while True:
        try:
            if is_async:
                # note: `acreate` returns a coroutine immediately, so errors
                # raised when it is awaited are not caught by this retry loop
                return openai.ChatCompletion.acreate(**kwargs)
            else:
                return openai.ChatCompletion.create(**kwargs)
        except openai.error.OpenAIError:
            import traceback

            traceback.print_exc()
            time.sleep(backoff_time)
            backoff_time *= 1.5
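
# Example (sketch, not part of this commit): a single synchronous call through
# the retry wrapper, assuming the openai 0.x SDK and an API key already set.
#
#     resp = oa_chat_completion(
#         model="gpt-3.5-turbo",
#         messages=[{"role": "user", "content": "Say hello."}],
#         max_tokens=16,
#         temperature=0.0,
#     )
#     print(resp["choices"][0]["message"]["content"])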
async def oa_chat_completion_async(**kwargs):
    """Query async OpenAI API for chat completion.

    Retry with back-off until they respond.
    """
    completion = await oa_chat_completion(is_async=True, **kwargs)
    return completion
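
# Example (sketch): firing several chat requests concurrently. `prompts` is a
# hypothetical list of strings; assumes the same openai 0.x SDK as above.
#
#     import asyncio
#
#     async def _gather(prompts):
#         return await asyncio.gather(
#             *(
#                 oa_chat_completion_async(
#                     model="gpt-3.5-turbo",
#                     messages=[{"role": "user", "content": p}],
#                 )
#                 for p in prompts
#             )
#         )
#
#     responses = asyncio.run(_gather(["2+2?", "3+3?"]))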
@register_model("openai", "openai-completions", "gooseai")
class OpenaiCompletionsLM(LM):
    REQ_CHUNK_SIZE = 20
@@ -304,3 +342,74 @@ class OpenaiCompletionsLM(LM):
            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods
@register_model("openai", "openai-chat-completions", "gooseai")
class OpenaiChatCompletionsLM(OpenaiCompletionsLM):
def __init__(
self, engine: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1
) -> None:
super().__init__(engine, truncate, batch_size)
    def greedy_until(self, requests) -> List[str]:
        if not requests:
            return []
        res = []
        requests = [req.args for req in requests]

        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        re_ord = utils.Reorderer(requests, _collate)
        def sameuntil_chunks(xs, size):
            # group consecutive requests sharing the same `until` so each
            # chunk can use a single stop-sequence setting
            ret = []
            lastuntil = xs[0][1]
            for x in xs:
                if len(ret) >= size or x[1] != lastuntil:
                    yield ret, lastuntil
                    ret = []
                    lastuntil = x[1]
                ret.append(x)

            if ret:
                yield ret, lastuntil
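        # Worked example (sketch) of the grouping above, with size=2:
        #   [("a", ["\n"]), ("b", ["\n"]), ("c", ["\n"]), ("d", ["###"])]
        # yields ([("a", ...), ("b", ...)], ["\n"]), then ([("c", ...)], ["\n"]),
        # then ([("d", ...)], ["###"]).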
        # todo: more intelligent batching for heterogeneous `until`
        for chunk, request_args in tqdm(
            list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
        ):
            until = request_args.get("until", ["<|endoftext|>"])

            # the chat endpoint takes a single conversation per request, so
            # issue one call per context instead of a batched `prompt`; it
            # also does not accept `logprobs`, so that argument is dropped
            for context, args_ in chunk:
                context_enc = self.tok_encode(context)
                # leave room for generation within the model's context window
                inp = context_enc[-(self.max_length - self.max_gen_toks) :]

                response = oa_chat_completion(
                    model=self.engine,
                    messages=[{"role": "user", "content": self.tok_decode(inp)}],
                    max_tokens=self.max_gen_toks,
                    temperature=0.0,
                    stop=until,
                )
                s = response["choices"][0]["message"]["content"]

                until_ = args_.get("until", ["<|endoftext|>"])
                for term in until_:
                    if len(term) > 0:
                        s = s.split(term)[0]

                # partial caching
                self.cache_hook.add_partial(
                    "greedy_until", (context, {"until": until_}), s
                )

                res.append(s)

        return re_ord.get_original(res)
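
For reference, a minimal usage sketch (not part of this commit): it assumes the harness's `lm_eval.evaluator.simple_evaluate` entry point, the `hellaswag` task, and an OpenAI key exported the way the existing `gpt3` model code expects.

import lm_eval.evaluator

# hypothetical smoke test for the newly registered model name
results = lm_eval.evaluator.simple_evaluate(
    model="openai-chat-completions",
    model_args="engine=gpt-3.5-turbo",
    tasks=["hellaswag"],
    num_fewshot=0,
    limit=8,  # keep the API bill small
)
print(results["results"])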