Commit 5a5442ff authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

make OAI completions work with new generation kwargs format

parent d1c5abef
......@@ -194,7 +194,7 @@ class OpenaiCompletionsLM(LM):
yield ret, lastuntil
# todo: more intelligent batching for heterogeneous `until`
for chunk, until in tqdm(
for chunk, request_args in tqdm(
list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
):
inps = []
......@@ -203,6 +203,13 @@ class OpenaiCompletionsLM(LM):
inp = context_enc[-(self.max_length - self.max_gen_toks) :]
inps.append(inp)
try:
until = request_args["until"][
0
] # TODO: does this handle a list of stop seqs correctly?
except KeyError:
until = "<|endoftext|>"
response = oa_completion(
engine=self.engine,
prompt=inps,
......@@ -212,14 +219,19 @@ class OpenaiCompletionsLM(LM):
stop=until,
)
for resp, (context, until_) in zip(response.choices, chunk):
for resp, (context, args_) in zip(response.choices, chunk):
s = resp["text"]
until_ = args_.get("until", [])
for term in until_:
s = s.split(term)[0]
if len(term) > 0:
s = s.split(term)[0]
# partial caching
self.cache_hook.add_partial("greedy_until", (context, until_), s)
self.cache_hook.add_partial(
"greedy_until", (context, {"until": until_}), s
)
res.append(s)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment