Commit e63ad693 authored by Michael Yang's avatar Michael Yang
Browse files

batch mode

parent 91542980
......@@ -54,15 +54,18 @@ def list_models(*args, **kwargs):
def generate(*args, **kwargs):
if prompt := kwargs.get('prompt'):
print('>>>', prompt, flush=True)
print(flush=True)
generate_oneshot(*args, **kwargs)
print(flush=True)
return
return generate_interactive(*args, **kwargs)
if sys.stdin.isatty():
return generate_interactive(*args, **kwargs)
return generate_batch(*args, **kwargs)
def generate_oneshot(*args, **kwargs):
print(flush=True)
for output in engine.generate(*args, **kwargs):
output = json.loads(output)
choices = output.get("choices", [])
......@@ -70,20 +73,26 @@ def generate_oneshot(*args, **kwargs):
print(choices[0].get("text", ""), end="", flush=True)
# end with a new line
print()
print(flush=True)
print(flush=True)
def generate_interactive(*args, **kwargs):
print('>>> ', end='', flush=True)
for line in sys.stdin:
if not sys.stdin.isatty():
print(line, end='')
while True:
print('>>> ', end='', flush=True)
line = next(sys.stdin)
if not line:
return
print(flush=True)
kwargs.update({'prompt': line})
kwargs.update({"prompt": line})
generate_oneshot(*args, **kwargs)
def generate_batch(*args, **kwargs):
for line in sys.stdin:
print('>>> ', line, end='', flush=True)
kwargs.update({"prompt": line})
generate_oneshot(*args, **kwargs)
print(flush=True)
print('>>> ', end='', flush=True)
def add(model, models_home):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment