Unverified Commit 1bc6c933 authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

openai: better error messages; fix greedy matching (#2327)



* better error message; fix greedy matching

* Update lm_eval/models/openai_completions.py
Co-authored-by: default avatarHailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update lm_eval/models/openai_completions.py
Co-authored-by: default avatarHailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* pre-commit

---------
Co-authored-by: default avatarHailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 00f5537a
......@@ -69,11 +69,11 @@ class LocalCompletionsAPI(TemplateAPI):
for choice, ctxlen in zip(out["choices"], ctxlens):
assert ctxlen > 0, "Context length must be greater than 0"
logprobs = sum(choice["logprobs"]["token_logprobs"][ctxlen:-1])
tokens = choice["logprobs"]["token_logprobs"][ctxlen:-1]
tokens_logprobs = choice["logprobs"]["token_logprobs"][ctxlen:-1]
top_logprobs = choice["logprobs"]["top_logprobs"][ctxlen:-1]
is_greedy = True
for tok, top in zip(tokens, top_logprobs):
if tok != max(top, key=top.get):
for tok, top in zip(tokens_logprobs, top_logprobs):
if tok != max(top.values()):
is_greedy = False
break
res.append((logprobs, is_greedy))
......@@ -190,14 +190,18 @@ class OpenAICompletionsAPI(LocalCompletionsAPI):
key = os.environ.get("OPENAI_API_KEY", None)
if key is None:
raise ValueError(
"API key not found. Please set the OPENAI_API_KEY environment variable."
"API key not found. Please set the `OPENAI_API_KEY` environment variable."
)
return key
def loglikelihood(self, requests, **kwargs):
assert (
self.model != "gpt-3.5-turbo"
), "Loglikelihood is not supported for gpt-3.5-turbo"
self.model
in [
"babbage-002",
"davinci-002",
]
), f"Prompt loglikelihoods are only supported by OpenAI's API for {['babbage-002', 'davinci-002']}."
return super().loglikelihood(requests, **kwargs)
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
......@@ -226,6 +230,11 @@ class OpenAIChatCompletion(LocalChatCompletion):
key = os.environ.get("OPENAI_API_KEY", None)
if key is None:
raise ValueError(
"API key not found. Please set the OPENAI_API_KEY environment variable."
"API key not found. Please set the `OPENAI_API_KEY` environment variable."
)
return key
def loglikelihood(self, requests, **kwargs):
raise NotImplementedError(
"Loglikelihood (and therefore `multiple_choice`-type tasks) is not supported for chat completions as OpenAI does not provide prompt logprobs. See https://github.com/EleutherAI/lm-evaluation-harness/issues/942#issuecomment-1777836312 or https://github.com/EleutherAI/lm-evaluation-harness/issues/1196 for more background on this limitation."
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment