Unverified Commit 8f5b2295 authored by Baber Abbasi, committed by GitHub

openai nits (#1139)

* fixed syntactic nits

* fix temperature and seed

* fix logprobs

* fixup merge
parent f7c67f0e
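
This commit updates the OpenAI completions backend for the openai>=1.0 client, which returns typed response objects instead of plain dicts, and threads a seed plus configurable length settings through every request. The gist of the dict-to-attribute migration, as a hedged before/after sketch (here `choice` stands for one element of `response.choices`; this snippet is illustrative, not part of the diff):

# openai<1.0: a completion choice is a plain dict
logprob = choice["logprobs"]["token_logprobs"][i]
# openai>=1.0: a choice is a typed object; the same data is attribute access
logprob = choice.logprobs.token_logprobs[i]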
 import os
 import time
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 import copy
 from collections import defaultdict
@@ -11,7 +11,7 @@ from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
-def get_result(response: dict, ctxlen: int) -> Tuple[float, bool]:
+def get_result(response, ctxlen: int) -> Tuple[float, bool]:
     """Process results from OpenAI API response.
     :param response: dict
@@ -25,12 +25,12 @@ def get_result(response: dict, ctxlen: int) -> Tuple[float, bool]:
         whether argmax matches given continuation exactly
     """
     is_greedy = True
-    logprobs = response["logprobs"]["token_logprobs"]
+    logprobs = response.logprobs.token_logprobs
     continuation_logprobs = sum(logprobs[ctxlen:])
-    for i in range(ctxlen, len(response["logprobs"]["tokens"])):
-        token = response["logprobs"]["tokens"][i]
-        top_tokens = response["logprobs"]["top_logprobs"][i]
+    for i in range(ctxlen, len(response.logprobs.token_logprobs)):
+        token = response.logprobs.tokens[i]
+        top_tokens = response.logprobs.top_logprobs[i]
         top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
         if top_token != token:
             is_greedy = False
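
get_result sums the log-probabilities of the continuation tokens (indices ctxlen onward) and reports whether each one was also the model's argmax at its position. A toy check with a stand-in response object; SimpleNamespace, the token values, and the assumption that the elided tail of the function returns (continuation_logprobs, is_greedy) are all illustrative, not part of the commit:

from types import SimpleNamespace

# One context token ("The") followed by two continuation tokens.
fake = SimpleNamespace(
    logprobs=SimpleNamespace(
        tokens=["The", " cat", " sat"],
        token_logprobs=[-0.5, -1.2, -0.3],
        top_logprobs=[
            {"The": -0.5},
            {" cat": -1.2},                # " cat" is the argmax here
            {" dog": -0.1, " sat": -0.3},  # " dog" outranks " sat"
        ],
    )
)
logprob, is_greedy = get_result(fake, ctxlen=1)
# logprob == -1.5 (sum of -1.2 and -0.3); is_greedy == False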
@@ -67,12 +67,16 @@ please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`
 @register_model("openai-completions")
 class OpenaiCompletionsLM(LM):
     REQ_CHUNK_SIZE = 20
+    _DEFAULT_MAX_LENGTH = 2048
     def __init__(
         self,
         model: str = "text-davinci-003",
         truncate: bool = False,
         max_gen_toks: int = 256,
         batch_size: int = 1,
+        seed: int = 1234,
+        max_length: Optional[int] = None,
     ) -> None:
         """
@@ -82,6 +86,7 @@ class OpenaiCompletionsLM(LM):
             Truncate input if too long (if False and input is too long, throw error)
         """
         super().__init__()
+        self.seed = seed
         try:
             import openai, tiktoken  # noqa: E401
         except ModuleNotFoundError:
@@ -89,14 +94,16 @@ class OpenaiCompletionsLM(LM):
                 "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
 please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
             )
-        self.model= model
+        self.model = model
         self.tokenizer = tiktoken.encoding_for_model(self.model)
         self.vocab_size = self.tokenizer.n_vocab
         self.truncate = truncate
         self.end_of_text_token_id = self.tokenizer.eot_token
+        self._max_gen_toks = max_gen_toks
+        self._max_length = max_length
-        # Read from environment variable OPENAI_API_SECRET_KEY
-        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
+        # Read from environment variable OPENAI_API_KEY
+        openai.api_key = os.environ["OPENAI_API_KEY"]
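
For context, the tokenizer setup leans on tiktoken's per-model registry; a quick standalone sanity check of the attributes used above (model name illustrative):

import tiktoken

enc = tiktoken.encoding_for_model("text-davinci-003")
print(enc.n_vocab)    # vocabulary size, stored as self.vocab_size
print(enc.eot_token)  # id of <|endoftext|>, stored as self.end_of_text_token_id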
     @property
     def eot_token_id(self):
@@ -104,12 +111,14 @@ class OpenaiCompletionsLM(LM):
     @property
     def max_length(self) -> int:
-        # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
-        return 2048
+        if self._max_length:
+            return self._max_length
+        else:
+            return self._DEFAULT_MAX_LENGTH
     @property
     def max_gen_toks(self) -> int:
-        return 256
+        return self._max_gen_toks
     @property
     def batch_size(self):
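
Net effect of these property changes: the context window and generation budget become per-instance settings rather than hard-coded constants. A hypothetical instantiation (all argument values are illustrative):

lm = OpenaiCompletionsLM(
    model="text-davinci-003",
    max_length=4096,   # overrides _DEFAULT_MAX_LENGTH (2048)
    max_gen_toks=128,  # surfaced through the max_gen_toks property
    seed=42,           # forwarded with every completion request
)
assert lm.max_length == 4096 and lm.max_gen_toks == 128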
@@ -187,12 +196,13 @@ class OpenaiCompletionsLM(LM):
                 ctxlens.append(ctxlen)
             response = oa_completion(
-                engine=self.engine,
+                model=self.model,
                 prompt=inps,
                 echo=True,
                 max_tokens=0,
                 temperature=0.0,
                 logprobs=10,
+                seed=self.seed,
             )
             for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
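
This request uses the completions API's scoring trick: with echo=True and max_tokens=0 the model generates nothing and instead returns per-token logprobs for the prompt itself, which get_result later splits at ctxlen. A minimal standalone sketch; the model name is a placeholder and parameter support varies by endpoint:

import os
import openai

openai.api_key = os.environ["OPENAI_API_KEY"]
response = openai.completions.create(
    model="gpt-3.5-turbo-instruct",  # placeholder completions model
    prompt="The cat sat on the mat",
    echo=True,      # return the prompt tokens with their logprobs
    max_tokens=0,   # generate nothing; only score the prompt
    logprobs=10,    # include the 10 most likely alternatives per position
    temperature=0.0,
    seed=1234,
)
print(response.choices[0].logprobs.token_logprobs)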
@@ -242,21 +252,22 @@ class OpenaiCompletionsLM(LM):
                 inp = context_enc[-(self.max_length - self.max_gen_toks) :]
                 inps.append(inp)
-            until = request_args.get("until", ["<|endoftext|>"])
+            until = request_args.pop("until", ["<|endoftext|>"])
+            request_args.pop("do_sample", None)
+            request_args["temperature"] = request_args.get("temperature", 0)
             response = oa_completion(
                 model=self.model,
                 prompt=inps,
                 max_tokens=self.max_gen_toks,
-                temperature=0.0,
-                logprobs=10,
                 stop=until,
+                seed=self.seed,
                 **request_args,
             )
             for resp, (context, args_) in zip(response.choices, chunk):
                 s = getattr(resp, "text")
-                until_ = args_.get("until", ["<|endoftext|>"])
+                until_ = until
                 for term in until_:
                     if len(term) > 0:
......
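
To make the new request_args handling concrete, here is a hedged walk-through with a hypothetical task-supplied dict (keys and values are illustrative):

request_args = {"until": ["\n\n"], "do_sample": False}

until = request_args.pop("until", ["<|endoftext|>"])  # becomes stop=["\n\n"]
request_args.pop("do_sample", None)                   # not an API parameter; dropped
request_args["temperature"] = request_args.get("temperature", 0)  # default to greedy

# request_args is now {"temperature": 0}; splatting it into oa_completion is safe,
# and a task that does pass "temperature" overrides the 0 default.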