Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
01c31aac
Commit
01c31aac
authored
Jun 29, 2023
by
Bruce MacDonald
Browse files
consistency between generate and add naming
parent
8fc8a007
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
38 additions
and
30 deletions
+38
-30
ollama/cmd/cli.py
ollama/cmd/cli.py
+12
-8
ollama/engine.py
ollama/engine.py
+11
-9
ollama/model.py
ollama/model.py
+15
-13
No files found.
ollama/cmd/cli.py
View file @
01c31aac
...
@@ -79,14 +79,18 @@ def generate_oneshot(*args, **kwargs):
...
@@ -79,14 +79,18 @@ def generate_oneshot(*args, **kwargs):
spinner
=
yaspin
()
spinner
=
yaspin
()
spinner
.
start
()
spinner
.
start
()
spinner_running
=
True
spinner_running
=
True
for
output
in
engine
.
generate
(
*
args
,
**
kwargs
):
try
:
choices
=
output
.
get
(
"choices"
,
[])
for
output
in
engine
.
generate
(
*
args
,
**
kwargs
):
if
len
(
choices
)
>
0
:
choices
=
output
.
get
(
"choices"
,
[])
if
spinner_running
:
if
len
(
choices
)
>
0
:
spinner
.
stop
()
if
spinner_running
:
spinner_running
=
False
spinner
.
stop
()
print
(
"
\r
"
,
end
=
""
)
# move cursor back to beginning of line again
spinner_running
=
False
print
(
choices
[
0
].
get
(
"text"
,
""
),
end
=
""
,
flush
=
True
)
print
(
"
\r
"
,
end
=
""
)
# move cursor back to beginning of line again
print
(
choices
[
0
].
get
(
"text"
,
""
),
end
=
""
,
flush
=
True
)
except
Exception
:
spinner
.
stop
()
raise
# end with a new line
# end with a new line
print
(
flush
=
True
)
print
(
flush
=
True
)
...
...
ollama/engine.py
View file @
01c31aac
import
os
from
os
import
path
,
dup
,
dup2
,
devnull
import
json
import
sys
import
sys
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
llama_cpp
import
Llama
as
LLM
from
llama_cpp
import
Llama
as
LLM
...
@@ -10,12 +9,12 @@ import ollama.prompt
...
@@ -10,12 +9,12 @@ import ollama.prompt
@
contextmanager
@
contextmanager
def
suppress_stderr
():
def
suppress_stderr
():
stderr
=
os
.
dup
(
sys
.
stderr
.
fileno
())
stderr
=
dup
(
sys
.
stderr
.
fileno
())
with
open
(
os
.
devnull
,
"w"
)
as
devnull
:
with
open
(
devnull
,
"w"
)
as
devnull
:
os
.
dup2
(
devnull
.
fileno
(),
sys
.
stderr
.
fileno
())
dup2
(
devnull
.
fileno
(),
sys
.
stderr
.
fileno
())
yield
yield
os
.
dup2
(
stderr
,
sys
.
stderr
.
fileno
())
dup2
(
stderr
,
sys
.
stderr
.
fileno
())
def
generate
(
model
,
prompt
,
models_home
=
"."
,
llms
=
{},
*
args
,
**
kwargs
):
def
generate
(
model
,
prompt
,
models_home
=
"."
,
llms
=
{},
*
args
,
**
kwargs
):
...
@@ -38,12 +37,15 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
...
@@ -38,12 +37,15 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
def
load
(
model
,
models_home
=
"."
,
llms
=
{}):
def
load
(
model
,
models_home
=
"."
,
llms
=
{}):
llm
=
llms
.
get
(
model
,
None
)
llm
=
llms
.
get
(
model
,
None
)
if
not
llm
:
if
not
llm
:
stored_model_path
=
os
.
path
.
join
(
models_home
,
model
,
".bin"
)
stored_model_path
=
path
.
join
(
models_home
,
model
)
+
".bin"
if
os
.
path
.
exists
(
stored_model_path
):
if
path
.
exists
(
stored_model_path
):
model_path
=
stored_model_path
model_path
=
stored_model_path
else
:
else
:
# try loading this as a path to a model, rather than a model name
# try loading this as a path to a model, rather than a model name
model_path
=
os
.
path
.
abspath
(
model
)
model_path
=
path
.
abspath
(
model
)
if
not
path
.
exists
(
model_path
):
raise
Exception
(
f
"Model not found:
{
model
}
"
)
try
:
try
:
# suppress LLM's output
# suppress LLM's output
...
...
ollama/model.py
View file @
01c31aac
import
os
import
requests
import
requests
import
validators
import
validators
from
os
import
path
,
walk
from
urllib.parse
import
urlsplit
,
urlunsplit
from
urllib.parse
import
urlsplit
,
urlunsplit
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
@@ -9,9 +9,9 @@ models_endpoint_url = 'https://ollama.ai/api/models'
...
@@ -9,9 +9,9 @@ models_endpoint_url = 'https://ollama.ai/api/models'
def
models
(
models_home
=
'.'
,
*
args
,
**
kwargs
):
def
models
(
models_home
=
'.'
,
*
args
,
**
kwargs
):
for
_
,
_
,
files
in
os
.
walk
(
models_home
):
for
_
,
_
,
files
in
walk
(
models_home
):
for
file
in
files
:
for
file
in
files
:
base
,
ext
=
os
.
path
.
splitext
(
file
)
base
,
ext
=
path
.
splitext
(
file
)
if
ext
==
'.bin'
:
if
ext
==
'.bin'
:
yield
base
yield
base
...
@@ -27,7 +27,7 @@ def get_url_from_directory(model):
...
@@ -27,7 +27,7 @@ def get_url_from_directory(model):
return
model
return
model
def
download_from_repo
(
url
,
models_home
=
'.'
):
def
download_from_repo
(
url
,
file_name
,
models_home
=
'.'
):
parts
=
urlsplit
(
url
)
parts
=
urlsplit
(
url
)
path_parts
=
parts
.
path
.
split
(
'/tree/'
)
path_parts
=
parts
.
path
.
split
(
'/tree/'
)
...
@@ -38,6 +38,8 @@ def download_from_repo(url, models_home='.'):
...
@@ -38,6 +38,8 @@ def download_from_repo(url, models_home='.'):
location
,
branch
=
path_parts
location
,
branch
=
path_parts
location
=
location
.
strip
(
'/'
)
location
=
location
.
strip
(
'/'
)
if
file_name
==
''
:
file_name
=
path
.
basename
(
location
)
download_url
=
urlunsplit
(
download_url
=
urlunsplit
(
(
(
...
@@ -53,7 +55,7 @@ def download_from_repo(url, models_home='.'):
...
@@ -53,7 +55,7 @@ def download_from_repo(url, models_home='.'):
json_response
=
response
.
json
()
json_response
=
response
.
json
()
download_url
,
file_size
=
find_bin_file
(
json_response
,
location
,
branch
)
download_url
,
file_size
=
find_bin_file
(
json_response
,
location
,
branch
)
return
download_file
(
download_url
,
models_home
,
location
,
file_size
)
return
download_file
(
download_url
,
models_home
,
file_name
,
file_size
)
def
find_bin_file
(
json_response
,
location
,
branch
):
def
find_bin_file
(
json_response
,
location
,
branch
):
...
@@ -73,17 +75,15 @@ def find_bin_file(json_response, location, branch):
...
@@ -73,17 +75,15 @@ def find_bin_file(json_response, location, branch):
return
download_url
,
file_size
return
download_url
,
file_size
def
download_file
(
download_url
,
models_home
,
location
,
file_size
):
def
download_file
(
download_url
,
models_home
,
file_name
,
file_size
):
local_filename
=
os
.
path
.
join
(
models_home
,
os
.
path
.
basename
(
location
)
)
+
'.bin'
local_filename
=
path
.
join
(
models_home
,
file_name
)
+
'.bin'
first_byte
=
(
first_byte
=
path
.
getsize
(
local_filename
)
if
path
.
exists
(
local_filename
)
else
0
os
.
path
.
getsize
(
local_filename
)
if
os
.
path
.
exists
(
local_filename
)
else
0
)
if
first_byte
>=
file_size
:
if
first_byte
>=
file_size
:
return
local_filename
return
local_filename
print
(
f
'Pulling
{
os
.
path
.
basename
(
location
)
}
...'
)
print
(
f
'Pulling
{
file_name
}
...'
)
header
=
{
'Range'
:
f
'bytes=
{
first_byte
}
-'
}
if
first_byte
!=
0
else
{}
header
=
{
'Range'
:
f
'bytes=
{
first_byte
}
-'
}
if
first_byte
!=
0
else
{}
...
@@ -109,13 +109,15 @@ def download_file(download_url, models_home, location, file_size):
...
@@ -109,13 +109,15 @@ def download_file(download_url, models_home, location, file_size):
def
pull
(
model
,
models_home
=
'.'
,
*
args
,
**
kwargs
):
def
pull
(
model
,
models_home
=
'.'
,
*
args
,
**
kwargs
):
if
os
.
path
.
exists
(
model
):
if
path
.
exists
(
model
):
# a file on the filesystem is being specified
# a file on the filesystem is being specified
return
model
return
model
# check the remote model location and see if it needs to be downloaded
# check the remote model location and see if it needs to be downloaded
url
=
model
url
=
model
file_name
=
""
if
not
validators
.
url
(
url
)
and
not
url
.
startswith
(
'huggingface.co'
):
if
not
validators
.
url
(
url
)
and
not
url
.
startswith
(
'huggingface.co'
):
url
=
get_url_from_directory
(
model
)
url
=
get_url_from_directory
(
model
)
file_name
=
model
if
not
(
url
.
startswith
(
'http://'
)
or
url
.
startswith
(
'https://'
)):
if
not
(
url
.
startswith
(
'http://'
)
or
url
.
startswith
(
'https://'
)):
url
=
f
'https://
{
url
}
'
url
=
f
'https://
{
url
}
'
...
@@ -126,6 +128,6 @@ def pull(model, models_home='.', *args, **kwargs):
...
@@ -126,6 +128,6 @@ def pull(model, models_home='.', *args, **kwargs):
return
model
return
model
raise
Exception
(
f
'Unknown model
{
model
}
'
)
raise
Exception
(
f
'Unknown model
{
model
}
'
)
local_filename
=
download_from_repo
(
url
,
models_home
)
local_filename
=
download_from_repo
(
url
,
file_name
,
models_home
)
return
local_filename
return
local_filename
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment