Unverified Commit 8f5b2295 authored by Baber Abbasi, committed by GitHub

openai nits (#1139)

* fixed syntactic nits

* fix temperature and seed

* fix logprobs

* fixup merge
parent f7c67f0e
 import os
 import time
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 import copy
 from collections import defaultdict
@@ -11,7 +11,7 @@ from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 
-def get_result(response: dict, ctxlen: int) -> Tuple[float, bool]:
+def get_result(response, ctxlen: int) -> Tuple[float, bool]:
     """Process results from OpenAI API response.
 
     :param response: dict
@@ -25,12 +25,12 @@ def get_result(response: dict, ctxlen: int) -> Tuple[float, bool]:
         whether argmax matches given continuation exactly
     """
     is_greedy = True
-    logprobs = response["logprobs"]["token_logprobs"]
+    logprobs = response.logprobs.token_logprobs
     continuation_logprobs = sum(logprobs[ctxlen:])
 
-    for i in range(ctxlen, len(response["logprobs"]["tokens"])):
-        token = response["logprobs"]["tokens"][i]
-        top_tokens = response["logprobs"]["top_logprobs"][i]
+    for i in range(ctxlen, len(response.logprobs.token_logprobs)):
+        token = response.logprobs.token_logprobs[i]
+        top_tokens = response.logprobs.top_logprobs[i]
         top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
         if top_token != token:
             is_greedy = False
@@ -67,12 +67,16 @@ please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`
 @register_model("openai-completions")
 class OpenaiCompletionsLM(LM):
     REQ_CHUNK_SIZE = 20
+    _DEFAULT_MAX_LENGTH = 2048
 
     def __init__(
         self,
         model: str = "text-davinci-003",
         truncate: bool = False,
+        max_gen_toks: int = 256,
         batch_size: int = 1,
+        seed: int = 1234,
+        max_length: Optional[int] = None,
     ) -> None:
         """
@@ -82,6 +86,7 @@ class OpenaiCompletionsLM(LM):
             Truncate input if too long (if False and input is too long, throw error)
         """
         super().__init__()
+        self.seed = seed
         try:
             import openai, tiktoken  # noqa: E401
         except ModuleNotFoundError:
@@ -89,14 +94,16 @@ class OpenaiCompletionsLM(LM):
                 "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
 please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
             )
-        self.model= model
+        self.model = model
         self.tokenizer = tiktoken.encoding_for_model(self.model)
         self.vocab_size = self.tokenizer.n_vocab
         self.truncate = truncate
         self.end_of_text_token_id = self.tokenizer.eot_token
+        self._max_gen_toks = max_gen_toks
+        self._max_length = max_length
 
         # Read from environment variable OPENAI_API_SECRET_KEY
-        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
+        openai.api_key = os.environ["OPENAI_API_KEY"]
 
     @property
     def eot_token_id(self):
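One behavioral note on the hunk above: the API key is now read from `OPENAI_API_KEY` (the name OpenAI's own tooling uses) instead of `OPENAI_API_SECRET_KEY`. A hedged sketch of a defensive lookup during migration; the fallback chain is an assumption for illustration, not part of the commit:

```python
import os

# Hypothetical migration shim: prefer the new variable, fall back to the old one.
api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("OPENAI_API_SECRET_KEY")
if api_key is None:
    raise RuntimeError("Set OPENAI_API_KEY before using the openai-completions model.")
```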
@@ -104,12 +111,14 @@ class OpenaiCompletionsLM(LM):
 
     @property
     def max_length(self) -> int:
-        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
-        return 2048
+        if self._max_length:
+            return self._max_length
+        else:
+            return self._DEFAULT_MAX_LENGTH
 
     @property
     def max_gen_toks(self) -> int:
-        return 256
+        return self._max_gen_toks
 
     @property
     def batch_size(self):
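Together, the `__init__` and property hunks replace the hard-coded 2048/256 limits with constructor arguments. A sketch of how the new knobs might be set (assumes `OPENAI_API_KEY` is exported; the argument values are illustrative only):

```python
lm = OpenaiCompletionsLM(
    model="text-davinci-003",
    max_length=4096,    # overrides _DEFAULT_MAX_LENGTH (2048)
    max_gen_toks=128,   # previously fixed at 256
    seed=1234,          # forwarded to every oa_completion call
)
assert lm.max_length == 4096 and lm.max_gen_toks == 128
```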
@@ -187,12 +196,13 @@ class OpenaiCompletionsLM(LM):
             ctxlens.append(ctxlen)
 
         response = oa_completion(
-            engine=self.engine,
+            model=self.model,
             prompt=inps,
             echo=True,
             max_tokens=0,
             temperature=0.0,
             logprobs=10,
+            seed=self.seed,
         )
 
         for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
@@ -242,21 +252,22 @@ class OpenaiCompletionsLM(LM):
                 inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                 inps.append(inp)
 
-            until = request_args.get("until", ["<|endoftext|>"])
+            until = request_args.pop("until", ["<|endoftext|>"])
+            request_args.pop("do_sample", None)
+            request_args["temperature"] = request_args.get("temperature", 0)
 
             response = oa_completion(
                 model=self.model,
                 prompt=inps,
                 max_tokens=self.max_gen_toks,
-                temperature=0.0,
-                logprobs=10,
                 stop=until,
+                seed=self.seed,
+                **request_args,
             )
 
             for resp, (context, args_) in zip(response.choices, chunk):
-                s = getattr(resp, 'text')
-                until_ = args_.get("until", ["<|endoftext|>"])
+                s = getattr(resp, "text")
+                until_ = until
 
                 for term in until_:
                     if len(term) > 0:
...
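The generation hunk stops pinning `temperature=0.0` and instead forwards caller-supplied `request_args`, after stripping keys the OpenAI completions endpoint does not accept. A small standalone sketch of that filtering, with hypothetical request args:

```python
request_args = {"until": ["\n\n"], "do_sample": False, "top_p": 0.95}

until = request_args.pop("until", ["<|endoftext|>"])
request_args.pop("do_sample", None)  # HF-style flag; not an OpenAI parameter
request_args["temperature"] = request_args.get("temperature", 0)

print(until)         # ['\n\n']
print(request_args)  # {'top_p': 0.95, 'temperature': 0} -> forwarded via **request_args
```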