"docs/vscode:/vscode.git/clone" did not exist on "2ca227f8144f1bc2328a572017904f45afd4812d"
Commit 306e84fb authored by lintangsutawika's avatar lintangsutawika
Browse files

readd old completion api

parent 30296795
...@@ -15,6 +15,305 @@ from openai import OpenAI ...@@ -15,6 +15,305 @@ from openai import OpenAI
client = OpenAI() client = OpenAI()
def get_result(response: dict, ctxlen: int) -> Tuple[float, bool]:
"""Process results from OpenAI API response.
:param response: dict
OpenAI API Response
:param ctxlen: int
Length of context (so we can slice them away and only keep the predictions)
:return:
continuation_logprobs: np.array
Log probabilities of continuation tokens
is_greedy: bool
whether argmax matches given continuation exactly
"""
is_greedy = True
logprobs = response["logprobs"]["token_logprobs"]
continuation_logprobs = sum(logprobs[ctxlen:])
for i in range(ctxlen, len(response["logprobs"]["tokens"])):
token = response["logprobs"]["tokens"][i]
top_tokens = response["logprobs"]["top_logprobs"][i]
top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
if top_token != token:
is_greedy = False
break
return continuation_logprobs, is_greedy
def oa_completion(**kwargs):
"""Query OpenAI API for completion.
Retry with back-off until they respond
"""
try:
import openai, tiktoken # noqa: E401
except ModuleNotFoundError:
raise Exception(
"attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
)
backoff_time = 3
while True:
try:
return openai.Completion.create(**kwargs)
except openai.error.OpenAIError:
import traceback
traceback.print_exc()
time.sleep(backoff_time)
backoff_time *= 1.5
@register_model("openai", "openai-completions", "gooseai")
class OpenaiCompletionsLM(LM):
REQ_CHUNK_SIZE = 20
def __init__(
self,
engine: str = "text-davinci-003",
truncate: bool = False,
batch_size: int = 1,
) -> None:
"""
:param engine: str
OpenAI API engine (e.g. davinci)
:param truncate: bool
Truncate input if too long (if False and input is too long, throw error)
"""
super().__init__()
try:
import openai, tiktoken # noqa: E401
except ModuleNotFoundError:
raise Exception(
"attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
)
self.engine = engine
self.tokenizer = tiktoken.encoding_for_model(self.engine)
self.vocab_size = self.tokenizer.n_vocab
self.truncate = truncate
self.end_of_text_token_id = self.tokenizer.eot_token
# Read from environment variable OPENAI_API_SECRET_KEY
openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
@property
def eot_token_id(self):
return self.end_of_text_token_id
@property
def max_length(self) -> int:
# Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
return 2048
@property
def max_gen_toks(self) -> int:
return 256
@property
def batch_size(self):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
@property
def device(self):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
def tok_encode(self, string: str) -> List[int]:
return self.tokenizer.encode(string)
def tok_decode(self, tokens: List[int]) -> str:
return self.tokenizer.decode(tokens)
def _encode_pair(
self, context: str, continuation: str
) -> Tuple[List[int], List[int]]:
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tok_encode(context + continuation)
context_enc = self.tok_encode(context)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc
def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
new_reqs = []
for context, continuation in [req.args for req in requests]:
if context == "":
# end of text as context
context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
continuation
)
else:
context_enc, continuation_enc = self._encode_pair(context, continuation)
new_reqs.append(((context, continuation), context_enc, continuation_enc))
return self._loglikelihood_tokens(new_reqs)
def _loglikelihood_tokens(
self, requests, disable_tqdm: bool = False
) -> List[Tuple[float, bool]]:
res = []
def _collate(x):
# this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
# it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
# we care about, and so we need some kind of backup for when it isn't
toks = x[1] + x[2]
return -len(toks), tuple(toks)
re_ord = utils.Reorderer(requests, _collate)
for chunk in tqdm(
list(utils.chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)),
disable=disable_tqdm,
):
inps = []
ctxlens = []
for cache_key, context_enc, continuation_enc in chunk:
# max_length+1 because the API takes up to 2049 tokens, including the first context token
inp = (context_enc + continuation_enc)[-(self.max_length + 1) :]
# TODO: the logic is much simpler if we just look at the length of continuation tokens
ctxlen = len(context_enc) - max(
0, len(context_enc) + len(continuation_enc) - (self.max_length + 1)
)
inps.append(inp)
ctxlens.append(ctxlen)
response = oa_completion(
engine=self.engine,
prompt=inps,
echo=True,
max_tokens=0,
temperature=0.0,
logprobs=10,
)
for resp, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
response.choices, ctxlens, chunk
):
answer = get_result(resp, ctxlen)
res.append(answer)
# partial caching
if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return re_ord.get_original(res)
def generate_until(self, requests) -> List[str]:
if not requests:
return []
res = []
requests = [req.args for req in requests]
def _collate(x):
toks = self.tok_encode(x[0])
return len(toks), x[0]
re_ord = utils.Reorderer(requests, _collate)
def sameuntil_chunks(xs, size):
ret = []
lastuntil = xs[0][1]
for x in xs:
if len(ret) >= size or x[1] != lastuntil:
yield ret, lastuntil
ret = []
lastuntil = x[1]
ret.append(x)
if ret:
yield ret, lastuntil
# todo: more intelligent batching for heterogeneous `until`
for chunk, request_args in tqdm(
list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
):
inps = []
for context, _ in chunk:
context_enc = self.tok_encode(context)
inp = context_enc[-(self.max_length - self.max_gen_toks) :]
inps.append(inp)
until = request_args.get("until", ["<|endoftext|>"])
response = oa_completion(
engine=self.engine,
prompt=inps,
max_tokens=self.max_gen_toks,
temperature=0.0,
logprobs=10,
stop=until,
)
for resp, (context, args_) in zip(response.choices, chunk):
s = resp["text"]
until_ = args_.get("until", ["<|endoftext|>"])
for term in until_:
if len(term) > 0:
s = s.split(term)[0]
# partial caching
self.cache_hook.add_partial(
"generate_until", (context, {"until": until_}), s
)
res.append(s)
return re_ord.get_original(res)
def _model_call(self, inps):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override generate_until
raise NotImplementedError()
def loglikelihood_rolling(self, requests) -> List[float]:
loglikelihoods = []
for (string,) in tqdm([req.args for req in requests]):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
utils.get_rolling_token_windows(
token_list=self.tok_encode(string),
prefix_token=self.eot_token_id,
max_seq_len=self.max_length,
context_len=1,
),
)
)
# TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
string_nll = self._loglikelihood_tokens(
rolling_token_windows,
disable_tqdm=True,
)
# discard is_greedy
string_nll = [x[0] for x in string_nll]
string_nll = sum(string_nll)
loglikelihoods.append(string_nll)
return loglikelihoods
def oa_chat_completion(**kwargs): def oa_chat_completion(**kwargs):
"""Query OpenAI API for chat completion. """Query OpenAI API for chat completion.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment