Commit 01c31aac authored by Bruce MacDonald's avatar Bruce MacDonald
Browse files

consistency between generate and add naming

parent 8fc8a007
...@@ -79,14 +79,18 @@ def generate_oneshot(*args, **kwargs): ...@@ -79,14 +79,18 @@ def generate_oneshot(*args, **kwargs):
spinner = yaspin() spinner = yaspin()
spinner.start() spinner.start()
spinner_running = True spinner_running = True
for output in engine.generate(*args, **kwargs): try:
choices = output.get("choices", []) for output in engine.generate(*args, **kwargs):
if len(choices) > 0: choices = output.get("choices", [])
if spinner_running: if len(choices) > 0:
spinner.stop() if spinner_running:
spinner_running = False spinner.stop()
print("\r", end="") # move cursor back to beginning of line again spinner_running = False
print(choices[0].get("text", ""), end="", flush=True) print("\r", end="") # move cursor back to beginning of line again
print(choices[0].get("text", ""), end="", flush=True)
except Exception:
spinner.stop()
raise
# end with a new line # end with a new line
print(flush=True) print(flush=True)
......
import os from os import path, dup, dup2, devnull
import json
import sys import sys
from contextlib import contextmanager from contextlib import contextmanager
from llama_cpp import Llama as LLM from llama_cpp import Llama as LLM
...@@ -10,12 +9,12 @@ import ollama.prompt ...@@ -10,12 +9,12 @@ import ollama.prompt
@contextmanager @contextmanager
def suppress_stderr(): def suppress_stderr():
stderr = os.dup(sys.stderr.fileno()) stderr = dup(sys.stderr.fileno())
with open(os.devnull, "w") as devnull: with open(devnull, "w") as devnull:
os.dup2(devnull.fileno(), sys.stderr.fileno()) dup2(devnull.fileno(), sys.stderr.fileno())
yield yield
os.dup2(stderr, sys.stderr.fileno()) dup2(stderr, sys.stderr.fileno())
def generate(model, prompt, models_home=".", llms={}, *args, **kwargs): def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
...@@ -38,12 +37,15 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs): ...@@ -38,12 +37,15 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
def load(model, models_home=".", llms={}): def load(model, models_home=".", llms={}):
llm = llms.get(model, None) llm = llms.get(model, None)
if not llm: if not llm:
stored_model_path = os.path.join(models_home, model, ".bin") stored_model_path = path.join(models_home, model) + ".bin"
if os.path.exists(stored_model_path): if path.exists(stored_model_path):
model_path = stored_model_path model_path = stored_model_path
else: else:
# try loading this as a path to a model, rather than a model name # try loading this as a path to a model, rather than a model name
model_path = os.path.abspath(model) model_path = path.abspath(model)
if not path.exists(model_path):
raise Exception(f"Model not found: {model}")
try: try:
# suppress LLM's output # suppress LLM's output
......
import os
import requests import requests
import validators import validators
from os import path, walk
from urllib.parse import urlsplit, urlunsplit from urllib.parse import urlsplit, urlunsplit
from tqdm import tqdm from tqdm import tqdm
...@@ -9,9 +9,9 @@ models_endpoint_url = 'https://ollama.ai/api/models' ...@@ -9,9 +9,9 @@ models_endpoint_url = 'https://ollama.ai/api/models'
def models(models_home='.', *args, **kwargs): def models(models_home='.', *args, **kwargs):
for _, _, files in os.walk(models_home): for _, _, files in walk(models_home):
for file in files: for file in files:
base, ext = os.path.splitext(file) base, ext = path.splitext(file)
if ext == '.bin': if ext == '.bin':
yield base yield base
...@@ -27,7 +27,7 @@ def get_url_from_directory(model): ...@@ -27,7 +27,7 @@ def get_url_from_directory(model):
return model return model
def download_from_repo(url, models_home='.'): def download_from_repo(url, file_name, models_home='.'):
parts = urlsplit(url) parts = urlsplit(url)
path_parts = parts.path.split('/tree/') path_parts = parts.path.split('/tree/')
...@@ -38,6 +38,8 @@ def download_from_repo(url, models_home='.'): ...@@ -38,6 +38,8 @@ def download_from_repo(url, models_home='.'):
location, branch = path_parts location, branch = path_parts
location = location.strip('/') location = location.strip('/')
if file_name == '':
file_name = path.basename(location)
download_url = urlunsplit( download_url = urlunsplit(
( (
...@@ -53,7 +55,7 @@ def download_from_repo(url, models_home='.'): ...@@ -53,7 +55,7 @@ def download_from_repo(url, models_home='.'):
json_response = response.json() json_response = response.json()
download_url, file_size = find_bin_file(json_response, location, branch) download_url, file_size = find_bin_file(json_response, location, branch)
return download_file(download_url, models_home, location, file_size) return download_file(download_url, models_home, file_name, file_size)
def find_bin_file(json_response, location, branch): def find_bin_file(json_response, location, branch):
...@@ -73,17 +75,15 @@ def find_bin_file(json_response, location, branch): ...@@ -73,17 +75,15 @@ def find_bin_file(json_response, location, branch):
return download_url, file_size return download_url, file_size
def download_file(download_url, models_home, location, file_size): def download_file(download_url, models_home, file_name, file_size):
local_filename = os.path.join(models_home, os.path.basename(location)) + '.bin' local_filename = path.join(models_home, file_name) + '.bin'
first_byte = ( first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
os.path.getsize(local_filename) if os.path.exists(local_filename) else 0
)
if first_byte >= file_size: if first_byte >= file_size:
return local_filename return local_filename
print(f'Pulling {os.path.basename(location)}...') print(f'Pulling {file_name}...')
header = {'Range': f'bytes={first_byte}-'} if first_byte != 0 else {} header = {'Range': f'bytes={first_byte}-'} if first_byte != 0 else {}
...@@ -109,13 +109,15 @@ def download_file(download_url, models_home, location, file_size): ...@@ -109,13 +109,15 @@ def download_file(download_url, models_home, location, file_size):
def pull(model, models_home='.', *args, **kwargs): def pull(model, models_home='.', *args, **kwargs):
if os.path.exists(model): if path.exists(model):
# a file on the filesystem is being specified # a file on the filesystem is being specified
return model return model
# check the remote model location and see if it needs to be downloaded # check the remote model location and see if it needs to be downloaded
url = model url = model
file_name = ""
if not validators.url(url) and not url.startswith('huggingface.co'): if not validators.url(url) and not url.startswith('huggingface.co'):
url = get_url_from_directory(model) url = get_url_from_directory(model)
file_name = model
if not (url.startswith('http://') or url.startswith('https://')): if not (url.startswith('http://') or url.startswith('https://')):
url = f'https://{url}' url = f'https://{url}'
...@@ -126,6 +128,6 @@ def pull(model, models_home='.', *args, **kwargs): ...@@ -126,6 +128,6 @@ def pull(model, models_home='.', *args, **kwargs):
return model return model
raise Exception(f'Unknown model {model}') raise Exception(f'Unknown model {model}')
local_filename = download_from_repo(url, models_home) local_filename = download_from_repo(url, file_name, models_home)
return local_filename return local_filename
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment