Commit f3d2bf90 authored by Baber's avatar Baber
Browse files

add generation to apimodels

parent 2ef15732
......@@ -647,3 +647,18 @@ class TemplateAPI(TemplateLM):
# cache this loglikelihood_rolling request
self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
return loglikelihoods
def simple_async_generate(
self, requests: Union[List[List[str]], List[List[dict]]], gen_kwargs: dict
):
results = itertools.chain.from_iterable(
asyncio.run(
self.get_batched_requests(
requests,
cache_keys=[None] * len(requests),
generate=True,
gen_kwargs=gen_kwargs,
)
)
)
return list(results)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment