"vscode:/vscode.git/clone" did not exist on "ee7b19e74e760a43b5179319521c8cc4f55ef72f"
Unverified Commit 1bc6c933 authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

openai: better error messages; fix greedy matching (#2327)



* better error message; fix greedy matching

* Update lm_eval/models/openai_completions.py
Co-authored-by: default avatarHailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* Update lm_eval/models/openai_completions.py
Co-authored-by: default avatarHailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>

* pre-commit

---------
Co-authored-by: default avatarHailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 00f5537a
...@@ -69,11 +69,11 @@ class LocalCompletionsAPI(TemplateAPI): ...@@ -69,11 +69,11 @@ class LocalCompletionsAPI(TemplateAPI):
for choice, ctxlen in zip(out["choices"], ctxlens): for choice, ctxlen in zip(out["choices"], ctxlens):
assert ctxlen > 0, "Context length must be greater than 0" assert ctxlen > 0, "Context length must be greater than 0"
logprobs = sum(choice["logprobs"]["token_logprobs"][ctxlen:-1]) logprobs = sum(choice["logprobs"]["token_logprobs"][ctxlen:-1])
tokens = choice["logprobs"]["token_logprobs"][ctxlen:-1] tokens_logprobs = choice["logprobs"]["token_logprobs"][ctxlen:-1]
top_logprobs = choice["logprobs"]["top_logprobs"][ctxlen:-1] top_logprobs = choice["logprobs"]["top_logprobs"][ctxlen:-1]
is_greedy = True is_greedy = True
for tok, top in zip(tokens, top_logprobs): for tok, top in zip(tokens_logprobs, top_logprobs):
if tok != max(top, key=top.get): if tok != max(top.values()):
is_greedy = False is_greedy = False
break break
res.append((logprobs, is_greedy)) res.append((logprobs, is_greedy))
...@@ -190,14 +190,18 @@ class OpenAICompletionsAPI(LocalCompletionsAPI): ...@@ -190,14 +190,18 @@ class OpenAICompletionsAPI(LocalCompletionsAPI):
key = os.environ.get("OPENAI_API_KEY", None) key = os.environ.get("OPENAI_API_KEY", None)
if key is None: if key is None:
raise ValueError( raise ValueError(
"API key not found. Please set the OPENAI_API_KEY environment variable." "API key not found. Please set the `OPENAI_API_KEY` environment variable."
) )
return key return key
def loglikelihood(self, requests, **kwargs): def loglikelihood(self, requests, **kwargs):
assert ( assert (
self.model != "gpt-3.5-turbo" self.model
), "Loglikelihood is not supported for gpt-3.5-turbo" in [
"babbage-002",
"davinci-002",
]
), f"Prompt loglikelihoods are only supported by OpenAI's API for {['babbage-002', 'davinci-002']}."
return super().loglikelihood(requests, **kwargs) return super().loglikelihood(requests, **kwargs)
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]: def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
...@@ -226,6 +230,11 @@ class OpenAIChatCompletion(LocalChatCompletion): ...@@ -226,6 +230,11 @@ class OpenAIChatCompletion(LocalChatCompletion):
key = os.environ.get("OPENAI_API_KEY", None) key = os.environ.get("OPENAI_API_KEY", None)
if key is None: if key is None:
raise ValueError( raise ValueError(
"API key not found. Please set the OPENAI_API_KEY environment variable." "API key not found. Please set the `OPENAI_API_KEY` environment variable."
) )
return key return key
def loglikelihood(self, requests, **kwargs):
raise NotImplementedError(
"Loglikelihood (and therefore `multiple_choice`-type tasks) is not supported for chat completions as OpenAI does not provide prompt logprobs. See https://github.com/EleutherAI/lm-evaluation-harness/issues/942#issuecomment-1777836312 or https://github.com/EleutherAI/lm-evaluation-harness/issues/1196 for more background on this limitation."
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment