Commit a11cddbf authored by Bruce MacDonald's avatar Bruce MacDonald
Browse files

remove models home param

parent 54a94566
import os import os
import sys import sys
from pathlib import Path
from argparse import ArgumentParser from argparse import ArgumentParser
from yaspin import yaspin from yaspin import yaspin
...@@ -10,12 +9,9 @@ from ollama.cmd import server ...@@ -10,12 +9,9 @@ from ollama.cmd import server
def main(): def main():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("--models-home", default=Path.home() / ".ollama" / "models")
# create models home if it doesn't exist # create models home if it doesn't exist
models_home = parser.parse_known_args()[0].models_home os.makedirs(model.models_home, exist_ok=True)
if not models_home.exists():
os.makedirs(models_home)
subparsers = parser.add_subparsers() subparsers = parser.add_subparsers()
......
...@@ -11,7 +11,7 @@ def set_parser(parser): ...@@ -11,7 +11,7 @@ def set_parser(parser):
parser.set_defaults(fn=serve) parser.set_defaults(fn=serve)
def serve(models_home=".", *args, **kwargs): def serve(*args, **kwargs):
app = web.Application() app = web.Application()
cors = aiohttp_cors.setup( cors = aiohttp_cors.setup(
...@@ -39,7 +39,6 @@ def serve(models_home=".", *args, **kwargs): ...@@ -39,7 +39,6 @@ def serve(models_home=".", *args, **kwargs):
app.update( app.update(
{ {
"llms": {}, "llms": {},
"models_home": models_home,
} }
) )
...@@ -54,7 +53,6 @@ async def load(request): ...@@ -54,7 +53,6 @@ async def load(request):
kwargs = { kwargs = {
"llms": request.app.get("llms"), "llms": request.app.get("llms"),
"models_home": request.app.get("models_home"),
} }
engine.load(model, **kwargs) engine.load(model, **kwargs)
...@@ -86,7 +84,6 @@ async def generate(request): ...@@ -86,7 +84,6 @@ async def generate(request):
kwargs = { kwargs = {
"llms": request.app.get("llms"), "llms": request.app.get("llms"),
"models_home": request.app.get("models_home"),
} }
for output in engine.generate(model, prompt, **kwargs): for output in engine.generate(model, prompt, **kwargs):
......
...@@ -18,8 +18,8 @@ def suppress_stderr(): ...@@ -18,8 +18,8 @@ def suppress_stderr():
os.dup2(stderr, sys.stderr.fileno()) os.dup2(stderr, sys.stderr.fileno())
def generate(model, prompt, models_home=".", llms={}, *args, **kwargs): def generate(model, prompt, llms={}, *args, **kwargs):
llm = load(model, models_home=models_home, llms=llms) llm = load(model, llms=llms)
prompt = ollama.prompt.template(model, prompt) prompt = ollama.prompt.template(model, prompt)
if "max_tokens" not in kwargs: if "max_tokens" not in kwargs:
...@@ -35,10 +35,10 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs): ...@@ -35,10 +35,10 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
yield output yield output
def load(model, models_home=".", llms={}): def load(model, llms={}):
llm = llms.get(model, None) llm = llms.get(model, None)
if not llm: if not llm:
stored_model_path = path.join(models_home, model) + ".bin" stored_model_path = path.join(ollama.model.models_home, model) + ".bin"
if path.exists(stored_model_path): if path.exists(stored_model_path):
model_path = stored_model_path model_path = stored_model_path
else: else:
......
import requests import requests
import validators import validators
from pathlib import Path
from os import path, walk from os import path, walk
from urllib.parse import urlsplit, urlunsplit from urllib.parse import urlsplit, urlunsplit
from tqdm import tqdm from tqdm import tqdm
models_endpoint_url = 'https://ollama.ai/api/models' models_endpoint_url = 'https://ollama.ai/api/models'
models_home = path.join(Path.home(), '.ollama', 'models')
def models(models_home='.', *args, **kwargs): def models(*args, **kwargs):
for _, _, files in walk(models_home): for _, _, files in walk(models_home):
for file in files: for file in files:
base, ext = path.splitext(file) base, ext = path.splitext(file)
...@@ -27,7 +29,7 @@ def get_url_from_directory(model): ...@@ -27,7 +29,7 @@ def get_url_from_directory(model):
return model return model
def download_from_repo(url, file_name, models_home='.'): def download_from_repo(url, file_name):
parts = urlsplit(url) parts = urlsplit(url)
path_parts = parts.path.split('/tree/') path_parts = parts.path.split('/tree/')
...@@ -55,7 +57,7 @@ def download_from_repo(url, file_name, models_home='.'): ...@@ -55,7 +57,7 @@ def download_from_repo(url, file_name, models_home='.'):
json_response = response.json() json_response = response.json()
download_url, file_size = find_bin_file(json_response, location, branch) download_url, file_size = find_bin_file(json_response, location, branch)
return download_file(download_url, models_home, file_name, file_size) return download_file(download_url, file_name, file_size)
def find_bin_file(json_response, location, branch): def find_bin_file(json_response, location, branch):
...@@ -75,7 +77,7 @@ def find_bin_file(json_response, location, branch): ...@@ -75,7 +77,7 @@ def find_bin_file(json_response, location, branch):
return download_url, file_size return download_url, file_size
def download_file(download_url, models_home, file_name, file_size): def download_file(download_url, file_name, file_size):
local_filename = path.join(models_home, file_name) + '.bin' local_filename = path.join(models_home, file_name) + '.bin'
first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0 first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
...@@ -108,7 +110,7 @@ def download_file(download_url, models_home, file_name, file_size): ...@@ -108,7 +110,7 @@ def download_file(download_url, models_home, file_name, file_size):
return local_filename return local_filename
def pull(model, models_home='.', *args, **kwargs): def pull(model, *args, **kwargs):
if path.exists(model): if path.exists(model):
# a file on the filesystem is being specified # a file on the filesystem is being specified
return model return model
...@@ -128,6 +130,6 @@ def pull(model, models_home='.', *args, **kwargs): ...@@ -128,6 +130,6 @@ def pull(model, models_home='.', *args, **kwargs):
return model return model
raise Exception(f'Unknown model {model}') raise Exception(f'Unknown model {model}')
local_filename = download_from_repo(url, file_name, models_home) local_filename = download_from_repo(url, file_name)
return local_filename return local_filename
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment