"examples/vscode:/vscode.git/clone" did not exist on "fd6cec589afecf6b2de42817f2c3b6e3fe6b7de3"
Commit 0e1538e9 authored by DaveOkpare

Fix OpenaiChatCompletionsLM and remove is_async flag

parent 5ebc7c85
@@ -60,7 +60,7 @@ please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`
             backoff_time *= 1.5


-def oa_chat_completion(is_async: bool = False, **kwargs):
+def oa_chat_completion(**kwargs):
     """Query OpenAI API for chat completion.

     Retry with back-off until they respond
@@ -76,10 +76,7 @@ please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`
     backoff_time = 3
     while True:
         try:
-            if is_async:
-                return openai.ChatCompletion.acreate(**kwargs)
-            else:
-                return openai.ChatCompletion.create(**kwargs)
+            return openai.ChatCompletion.create(**kwargs)
         except openai.error.OpenAIError:
             import traceback
@@ -88,16 +85,6 @@ please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`
             backoff_time *= 1.5


-async def oa_chat_completion_async(**kwargs):
-    """Query async OpenAI API for chat completion.
-
-    Retry with back-off until they respond
-    """
-    completion = await oa_chat_completion(is_async=True, **kwargs)
-
-    return completion
-
-
 @register_model("openai", "openai-completions", "gooseai")
 class OpenaiCompletionsLM(LM):
     REQ_CHUNK_SIZE = 20
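For readability, here is how the helper should read once the is_async flag and the async wrapper are gone, pieced together from the hunks above. This is a sketch, not the file itself: the time.sleep(backoff_time) call and the exact indentation are assumptions inferred from the visible back-off bookkeeping.

    # Sketch of oa_chat_completion after this commit (reconstructed from the diff).
    import time

    import openai


    def oa_chat_completion(**kwargs):
        """Query OpenAI API for chat completion.

        Retry with back-off until they respond.
        """
        backoff_time = 3
        while True:
            try:
                # Synchronous call only; the is_async branch and the
                # ChatCompletion.acreate() path are removed by this commit.
                return openai.ChatCompletion.create(**kwargs)
            except openai.error.OpenAIError:
                import traceback

                traceback.print_exc()
                # Assumption: sleep before retrying, then grow the back-off.
                time.sleep(backoff_time)
                backoff_time *= 1.5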
@@ -344,12 +331,78 @@ class OpenaiCompletionsLM(LM):
         return loglikelihoods


-@register_model("openai", "openai-chat-completions", "gooseai")
-class OpenaiChatCompletionsLM(OpenaiCompletionsLM):
+@register_model("openai-chat-completions")
+class OpenaiChatCompletionsLM(LM):
+    REQ_CHUNK_SIZE = 20
+
     def __init__(
         self, engine: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1
     ) -> None:
-        super().__init__(engine, truncate, batch_size)
+        """
+        :param engine: str
+            OpenAI API engine (e.g. gpt-3.5-turbo)
+        :param truncate: bool
+            Truncate input if too long (if False and input is too long, throw error)
+        """
+        super().__init__()
+        try:
+            import openai, tiktoken  # noqa: E401
+        except ModuleNotFoundError:
+            raise Exception(
+                "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
+please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
+            )
+        self.engine = engine
+        self.tokenizer = tiktoken.encoding_for_model(self.engine)
+        self.vocab_size = self.tokenizer.n_vocab
+        self.truncate = truncate
+        self.end_of_text_token_id = self.tokenizer.eot_token
+
+        # Read from environment variable OPENAI_API_SECRET_KEY
+        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
+
+    @property
+    def eot_token_id(self):
+        return self.end_of_text_token_id
+
+    @property
+    def max_length(self) -> int:
+        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
+        return 2048
+
+    @property
+    def max_gen_toks(self) -> int:
+        return 256
+
+    @property
+    def batch_size(self):
+        # Isn't used because we override _loglikelihood_tokens
+        raise NotImplementedError()
+
+    @property
+    def device(self):
+        # Isn't used because we override _loglikelihood_tokens
+        raise NotImplementedError()
+
+    def tok_encode(self, string: str) -> List[int]:
+        return self.tokenizer.encode(string)
+
+    def tok_decode(self, tokens: List[int]) -> str:
+        return self.tokenizer.decode(tokens)
+
+    def _encode_pair(
+        self, context: str, continuation: str
+    ) -> Tuple[List[int], List[int]]:
+        n_spaces = len(context) - len(context.rstrip())
+        if n_spaces > 0:
+            continuation = context[-n_spaces:] + continuation
+            context = context[:-n_spaces]
+        whole_enc = self.tok_encode(context + continuation)
+        context_enc = self.tok_encode(context)
+        context_enc_len = len(context_enc)
+        continuation_enc = whole_enc[context_enc_len:]
+        return context_enc, continuation_enc

     def greedy_until(self, requests) -> List[str]:
         if not requests:
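As a side note on the new constructor and _encode_pair above, here is a small illustration of the tiktoken plumbing. The tiktoken calls are the real package API; the example strings are made up and the concrete token ids depend on the tiktoken vocabulary, so none are claimed here.

    # Illustration of the tokenizer setup and the trailing-space handling.
    import tiktoken

    enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
    print(enc.n_vocab)     # stored as self.vocab_size
    print(enc.eot_token)   # stored as self.end_of_text_token_id

    # _encode_pair moves trailing spaces of the context onto the continuation,
    # so the continuation tokens can be recovered by slicing the joint encoding.
    context, continuation = "Question: 2+2= ", "4"   # made-up example
    n_spaces = len(context) - len(context.rstrip())
    if n_spaces > 0:
        continuation = context[-n_spaces:] + continuation   # " 4"
        context = context[:-n_spaces]                        # "Question: 2+2="
    whole_enc = enc.encode(context + continuation)
    context_enc = enc.encode(context)
    continuation_enc = whole_enc[len(context_enc):]          # tokens for " 4"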
@@ -378,17 +431,17 @@ class OpenaiChatCompletionsLM(OpenaiCompletionsLM):
         # todo: more intelligent batching for heterogeneous `until`
         for chunk, request_args in tqdm(
             list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
         ):
             inps = []
             for context, _ in chunk:
                 context_enc = self.tok_encode(context)
-                inp = context_enc[-(self.max_length - self.max_gen_toks) :]
+                inp = context_enc[-(self.max_length - self.max_gen_toks):]
                 inps.append({"role": "user", "content": inp})

             until = request_args.get("until", ["<|endoftext|>"])

-            response = oa_completion(
+            response = oa_chat_completion(
                 engine=self.engine,
                 prompt=inps,
                 max_tokens=self.max_gen_toks,
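The slice above keeps only the newest context tokens so that generation has room; with the class defaults that is at most 2048 - 256 = 1792 tokens. A tiny sketch with made-up numbers:

    # Truncation arithmetic used above (defaults taken from the class).
    max_length = 2048      # self.max_length
    max_gen_toks = 256     # self.max_gen_toks
    context_enc = list(range(3000))                    # pretend-encoded context
    inp = context_enc[-(max_length - max_gen_toks):]   # keep the last 1792 tokens
    assert len(inp) == 1792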
@@ -413,3 +466,9 @@ class OpenaiChatCompletionsLM(OpenaiCompletionsLM):
                 res.append(s)

         return re_ord.get_original(res)
+
+    def loglikelihood(self, requests):
+        raise NotImplementedError("No support for logits.")
+
+    def loglikelihood_rolling(self, requests):
+        raise NotImplementedError("No support for logits.")