Unverified commit f1b64f68 authored by Lintang Sutawika, committed by GitHub

Merge pull request #1008 from EleutherAI/openai_completions

[Refactor] Openai completions
parents c3d97e40 5b42436b
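This merge moves the OpenAI bindings to the v1 client interface (see the `openai>=1.3.5` pin and the switch to `client.completions.create` in the hunks below) and adds a dedicated chat-completions model. A minimal sketch of the client style the refactor targets, for orientation only (not part of the diff; model names are placeholders):

    import openai

    # the v1 client reads OPENAI_API_KEY from the environment
    client = openai.OpenAI()

    # completions endpoint, as used by OpenaiCompletionsLM and the updated test mock
    completion = client.completions.create(
        model="davinci-002", prompt="Hello", max_tokens=5, temperature=0.0
    )

    # chat completions endpoint, as used by the new OpenaiChatCompletionsLM
    chat = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello"}],
        max_tokens=5,
    )

    print(completion.choices[0].text, chat.choices[0].message.content)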
import os
import time
from typing import List, Tuple
import copy
from collections import defaultdict
from tqdm import tqdm
from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
@@ -51,7 +55,7 @@ please install these via `pip install lm-eval[openai]` or `pip install -e .[open
backoff_time = 3
while True:
try:
return openai.Completion.create(**kwargs)
return openai.Completions.create(**kwargs)
except openai.error.OpenAIError:
import traceback
@@ -60,7 +64,7 @@ please install these via `pip install lm-eval[openai]` or `pip install -e .[open
backoff_time *= 1.5
@register_model("openai", "openai-completions", "gooseai")
@register_model("gooseai")
class OpenaiCompletionsLM(LM):
REQ_CHUNK_SIZE = 20
@@ -304,3 +308,211 @@ class OpenaiCompletionsLM(LM):
string_nll = sum(string_nll)
loglikelihoods.append(string_nll)
return loglikelihoods
def oa_chat_completion(client, **kwargs):
"""Query OpenAI API for chat completion.
Retries with exponential back-off until the API responds.
"""
try:
import openai, tiktoken # noqa: E401
except ModuleNotFoundError:
raise Exception(
"attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
)
async def _get_completions(**kwargs):
chat_completions = await client.chat.completions.create(**kwargs)
return chat_completions
backoff_time = 3
while True:
try:
return client.chat.completions.create(**kwargs)
except openai.OpenAIError:
import traceback
traceback.print_exc()
time.sleep(backoff_time)
backoff_time *= 1.5
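# Illustrative call to the helper above (not part of the diff): oa_chat_completion
# forwards its keyword arguments to client.chat.completions.create and retries on
# openai.OpenAIError with exponential back-off (3s, 4.5s, 6.75s, ...), e.g.
#   client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment
#   resp = oa_chat_completion(client=client, model="gpt-3.5-turbo",
#                             messages=[{"role": "user", "content": "Hello"}])
#   print(resp.choices[0].message.content)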
@register_model("openai-chat-completions")
class OpenaiChatCompletionsLM(LM):
def __init__(
self, model: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1
) -> None:
"""
:param model: str
OpenAI API model (e.g. gpt-3.5-turbo)
:param truncate: bool
Truncate input if too long (if False and input is too long, throw error)
"""
super().__init__()
try:
import openai, tiktoken # noqa: E401
except ModuleNotFoundError:
raise Exception(
"attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
)
self.model = model
self.frequency_penalty = 0
self.logit_bias = None
self.n = 1
self.presence_penalty = 0
self.temperature = 1
self.top_p = 1
self.tokenizer = tiktoken.encoding_for_model(self.model)
self.vocab_size = self.tokenizer.n_vocab
self.truncate = truncate
self.end_of_text_token_id = self.tokenizer.eot_token
# Read from environment variable OPENAI_API_KEY
self.client = openai.OpenAI() # openai.AsyncOpenAI()
@property
def eot_token_id(self):
return self.end_of_text_token_id
@property
def max_length(self) -> int:
# Conservative default; gpt-3.5-turbo accepts a longer context, but 2048 tokens is kept here
return 2048
@property
def max_gen_toks(self) -> int:
return 256
@property
def batch_size(self):
# Not used by this chat-completions implementation
raise NotImplementedError()
@property
def device(self):
# Not used by this chat-completions implementation
raise NotImplementedError()
def tok_encode(self, string: str) -> List[int]:
return self.tokenizer.encode(string)
def tok_decode(self, tokens: List[int]) -> str:
return self.tokenizer.decode(tokens)
def _encode_pair(
self, context: str, continuation: str
) -> Tuple[List[int], List[int]]:
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tok_encode(context + continuation)
context_enc = self.tok_encode(context)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc
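# Worked example for _encode_pair above: trailing whitespace is moved from the
# context onto the continuation before tokenizing, so
#   _encode_pair("Question: 2+2 = ", "4")
# and
#   _encode_pair("Question: 2+2 =", " 4")
# produce identical (context_enc, continuation_enc) pairs; the space is always
# encoded together with the continuation tokens.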
def generate_until(self, requests) -> List[str]:
res = defaultdict(list)
re_ords = {}
def _collate(x):
toks = self.tok_encode(x[0])
return -len(toks), x[0]
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
for key, reqs in grouper.get_grouped().items():
# within each set of reqs for given kwargs, we reorder by token length, descending.
re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
def sameuntil_chunks(xs, size):
ret = []
lastuntil = xs[0][1]
for x in xs:
if len(ret) >= size or x[1] != lastuntil:
yield ret, lastuntil
ret = []
lastuntil = x[1]
ret.append(x)
if ret:
yield ret, lastuntil
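# For example, with size=2 and xs = [(a, u1), (b, u1), (c, u1), (d, u2)], the
# helper above yields ([(a, u1), (b, u1)], u1), then ([(c, u1)], u1), then
# ([(d, u2)], u2): consecutive requests are chunked together only while they
# share the same stop/until value, and no chunk exceeds `size` elements.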
pbar = tqdm(total=len(requests), disable=(self.rank != 0))
for key, re_ord in re_ords.items():
# n must be 1 because the messages in a
# chat completion request are not a batch;
# they are treated as a single conversation.
chunks = utils.chunks(re_ord.get_reordered(), n=1)
for chunk in chunks:
contexts, all_gen_kwargs = zip(*chunk)
inps = [{"role": "user", "content": context} for context in contexts]
gen_kwargs = all_gen_kwargs[0]
until = None
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {kwargs}"
)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
response = oa_chat_completion(
client=self.client,
messages=inps,
model=self.model,
frequency_penalty=self.frequency_penalty,
# logit_bias=self.logit_bias,
max_tokens=max_gen_toks,
n=self.n,
presence_penalty=self.presence_penalty,
temperature=self.temperature,
top_p=self.top_p,
)
for resp, (context, args_) in zip(response.choices, chunk):
s = resp.message.content
if until is not None:
for term in until:
if len(term) > 0:
s = s.split(term)[0]
res[key].append(s)
self.cache_hook.add_partial(
"generate_until", (context, {"until": until}), s
)
pbar.update(1)
# reorder this group of results back to original unsorted form
res[key] = re_ord.get_original(res[key])
pbar.close()
return grouper.get_original(res)
def loglikelihood(self, requests):
raise NotImplementedError("No support for logits.")
def loglikelihood_rolling(self, requests):
raise NotImplementedError("No support for logits.")
@@ -70,7 +70,7 @@ promptsource = [
]
gptq = ["auto-gptq[triton] @ git+https://github.com/PanQiWei/AutoGPTQ"]
anthropic = ["anthropic"]
openai = ["openai", "tiktoken"]
openai = ["openai>=1.3.5", "tiktoken"]
vllm = ["vllm"]
all = [
"lm_eval[dev]",
......
import hashlib
import json
import openai
import os
import pickle
import pytest
@@ -8,6 +7,10 @@ import unittest.mock as mock
import lm_eval.models as models
from openai import OpenAI
client = OpenAI()
LOGLIKELIHOOD_TEST_CASES = [
("The quick brown fox jumps over the lazy", " dog"),
@@ -172,7 +175,7 @@ def openai_mock_completion(**kwargs):
if os.path.exists(fname):
with open(fname, "rb") as fh:
return pickle.load(fh)
ret = openai.Completion.create(**kwargs)
ret = client.completions.create(**kwargs)
ret.api_key = ""
with open(fname, "wb") as fh:
pickle.dump(ret, fh)
......
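The mocked completion above follows a record-and-replay pattern: if a pickled response for the given kwargs already exists it is returned, otherwise the real `client.completions.create` is called once and the result is pickled for later runs. A minimal sketch of the same idea (the function name and cache-key construction are assumptions for illustration):

    import hashlib
    import json
    import os
    import pickle

    from openai import OpenAI

    client = OpenAI()

    def cached_completion(cache_dir, **kwargs):
        # key the cache file on the exact request arguments (assumed scheme)
        key = hashlib.sha256(json.dumps(kwargs, sort_keys=True).encode()).hexdigest()
        fname = os.path.join(cache_dir, f"{key}.pkl")
        if os.path.exists(fname):
            with open(fname, "rb") as fh:
                return pickle.load(fh)
        ret = client.completions.create(**kwargs)
        with open(fname, "wb") as fh:
            pickle.dump(ret, fh)
        return ret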