"examples/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "41082eb0681f8ac0f8b6a00d154d9c258ec8e49a"
Unverified Commit 5610405e authored by Michael Yang's avatar Michael Yang Committed by GitHub
Browse files

Merge pull request #11 from jmorganca/interactive-generate

interactive generate
parents 536ff4d8 db550820
import os import os
import sys
import json import json
from pathlib import Path from pathlib import Path
from argparse import ArgumentParser from argparse import ArgumentParser
...@@ -20,7 +21,7 @@ def main(): ...@@ -20,7 +21,7 @@ def main():
generate_parser = subparsers.add_parser("generate") generate_parser = subparsers.add_parser("generate")
generate_parser.add_argument("model") generate_parser.add_argument("model")
generate_parser.add_argument("prompt") generate_parser.add_argument("prompt", nargs="?")
generate_parser.set_defaults(fn=generate) generate_parser.set_defaults(fn=generate)
add_parser = subparsers.add_parser("add") add_parser = subparsers.add_parser("add")
...@@ -37,6 +38,8 @@ def main(): ...@@ -37,6 +38,8 @@ def main():
try: try:
fn = args.pop("fn") fn = args.pop("fn")
fn(**args) fn(**args)
except KeyboardInterrupt:
pass
except KeyError: except KeyError:
parser.print_help() parser.print_help()
except Exception as e: except Exception as e:
...@@ -49,12 +52,37 @@ def list_models(*args, **kwargs): ...@@ -49,12 +52,37 @@ def list_models(*args, **kwargs):
def generate(*args, **kwargs): def generate(*args, **kwargs):
if prompt := kwargs.get('prompt'):
print('>>>', prompt, flush=True)
print(flush=True)
generate_oneshot(*args, **kwargs)
print(flush=True)
return
return generate_interactive(*args, **kwargs)
def generate_oneshot(*args, **kwargs):
for output in engine.generate(*args, **kwargs): for output in engine.generate(*args, **kwargs):
output = json.loads(output) output = json.loads(output)
choices = output.get("choices", []) choices = output.get("choices", [])
if len(choices) > 0: if len(choices) > 0:
print(choices[0].get("text", ""), end="") print(choices[0].get("text", ""), end="", flush=True)
print()
def generate_interactive(*args, **kwargs):
print('>>> ', end='', flush=True)
for line in sys.stdin:
if not sys.stdin.isatty():
print(line, end='')
print(flush=True)
kwargs.update({'prompt': line})
generate_oneshot(*args, **kwargs)
print(flush=True)
print('>>> ', end='', flush=True)
def add(model, models_home): def add(model, models_home):
......
...@@ -27,7 +27,7 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs): ...@@ -27,7 +27,7 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
kwargs.update({"max_tokens": 16384}) kwargs.update({"max_tokens": 16384})
if "stop" not in kwargs: if "stop" not in kwargs:
kwargs.update({"stop": ["Q:", "\n"]}) kwargs.update({"stop": ["Q:"]})
if "stream" not in kwargs: if "stream" not in kwargs:
kwargs.update({"stream": True}) kwargs.update({"stream": True})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment