Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
885403b4
Commit
885403b4
authored
Jun 28, 2023
by
Bruce MacDonald
Browse files
ollama run command
parent
77eddba5
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
69 additions
and
18 deletions
+69
-18
ollama/cmd/cli.py
ollama/cmd/cli.py
+16
-5
ollama/model.py
ollama/model.py
+25
-13
poetry.lock
poetry.lock
+27
-0
pyproject.toml
pyproject.toml
+1
-0
No files found.
ollama/cmd/cli.py
View file @
885403b4
...
@@ -29,9 +29,13 @@ def main():
...
@@ -29,9 +29,13 @@ def main():
add_parser
.
set_defaults
(
fn
=
add
)
add_parser
.
set_defaults
(
fn
=
add
)
pull_parser
=
subparsers
.
add_parser
(
"pull"
)
pull_parser
=
subparsers
.
add_parser
(
"pull"
)
pull_parser
.
add_argument
(
"
remote
"
)
pull_parser
.
add_argument
(
"
model
"
)
pull_parser
.
set_defaults
(
fn
=
pull
)
pull_parser
.
set_defaults
(
fn
=
pull
)
pull_parser
=
subparsers
.
add_parser
(
"run"
)
pull_parser
.
add_argument
(
"model"
)
pull_parser
.
set_defaults
(
fn
=
run
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
args
=
vars
(
args
)
args
=
vars
(
args
)
...
@@ -52,8 +56,8 @@ def list_models(*args, **kwargs):
...
@@ -52,8 +56,8 @@ def list_models(*args, **kwargs):
def
generate
(
*
args
,
**
kwargs
):
def
generate
(
*
args
,
**
kwargs
):
if
prompt
:
=
kwargs
.
get
(
'
prompt
'
):
if
prompt
:
=
kwargs
.
get
(
"
prompt
"
):
print
(
'
>>>
'
,
prompt
,
flush
=
True
)
print
(
"
>>>
"
,
prompt
,
flush
=
True
)
generate_oneshot
(
*
args
,
**
kwargs
)
generate_oneshot
(
*
args
,
**
kwargs
)
return
return
...
@@ -79,7 +83,7 @@ def generate_oneshot(*args, **kwargs):
...
@@ -79,7 +83,7 @@ def generate_oneshot(*args, **kwargs):
def
generate_interactive
(
*
args
,
**
kwargs
):
def
generate_interactive
(
*
args
,
**
kwargs
):
while
True
:
while
True
:
print
(
'
>>>
'
,
end
=
''
,
flush
=
True
)
print
(
"
>>>
"
,
end
=
""
,
flush
=
True
)
line
=
next
(
sys
.
stdin
)
line
=
next
(
sys
.
stdin
)
if
not
line
:
if
not
line
:
return
return
...
@@ -90,7 +94,7 @@ def generate_interactive(*args, **kwargs):
...
@@ -90,7 +94,7 @@ def generate_interactive(*args, **kwargs):
def
generate_batch
(
*
args
,
**
kwargs
):
def
generate_batch
(
*
args
,
**
kwargs
):
for
line
in
sys
.
stdin
:
for
line
in
sys
.
stdin
:
print
(
'
>>>
'
,
line
,
end
=
''
,
flush
=
True
)
print
(
"
>>>
"
,
line
,
end
=
""
,
flush
=
True
)
kwargs
.
update
({
"prompt"
:
line
})
kwargs
.
update
({
"prompt"
:
line
})
generate_oneshot
(
*
args
,
**
kwargs
)
generate_oneshot
(
*
args
,
**
kwargs
)
...
@@ -101,3 +105,10 @@ def add(model, models_home):
...
@@ -101,3 +105,10 @@ def add(model, models_home):
def
pull
(
*
args
,
**
kwargs
):
def
pull
(
*
args
,
**
kwargs
):
model
.
pull
(
*
args
,
**
kwargs
)
model
.
pull
(
*
args
,
**
kwargs
)
def
run
(
*
args
,
**
kwargs
):
name
=
model
.
pull
(
*
args
,
**
kwargs
)
kwargs
.
update
({
"model"
:
name
})
print
(
f
"Running
{
name
}
..."
)
generate
(
*
args
,
**
kwargs
)
ollama/model.py
View file @
885403b4
import
os
import
os
import
requests
import
requests
import
validators
from
urllib.parse
import
urlsplit
,
urlunsplit
from
urllib.parse
import
urlsplit
,
urlunsplit
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
@@ -12,33 +13,36 @@ def models(models_home=".", *args, **kwargs):
...
@@ -12,33 +13,36 @@ def models(models_home=".", *args, **kwargs):
yield
base
,
os
.
path
.
join
(
root
,
file
)
yield
base
,
os
.
path
.
join
(
root
,
file
)
def
pull
(
remote
,
models_home
=
"."
,
*
args
,
**
kwargs
):
def
pull
(
model
,
models_home
=
"."
,
*
args
,
**
kwargs
):
if
not
(
remote
.
startswith
(
"http://"
)
or
remote
.
startswith
(
"https://"
)):
url
=
model
remote
=
f
"https://
{
remote
}
"
if
not
(
url
.
startswith
(
"http://"
)
or
url
.
startswith
(
"https://"
)):
url
=
f
"https://
{
url
}
"
parts
=
urlsplit
(
remote
)
parts
=
urlsplit
(
url
)
path_parts
=
parts
.
path
.
split
(
"/tree/"
)
path_parts
=
parts
.
path
.
split
(
"/tree/"
)
if
len
(
path_parts
)
==
1
:
if
len
(
path_parts
)
==
1
:
mode
l
=
path_parts
[
0
]
ur
l
=
path_parts
[
0
]
branch
=
"main"
branch
=
"main"
else
:
else
:
mode
l
,
branch
=
path_parts
ur
l
,
branch
=
path_parts
model
=
mode
l
.
strip
(
"/"
)
url
=
ur
l
.
strip
(
"/"
)
# Reconstruct the URL
# Reconstruct the URL
new_url
=
urlunsplit
(
new_url
=
urlunsplit
(
(
(
"https"
,
"https"
,
parts
.
netloc
,
parts
.
netloc
,
f
"/api/models/
{
mode
l
}
/tree/
{
branch
}
"
,
f
"/api/models/
{
ur
l
}
/tree/
{
branch
}
"
,
parts
.
query
,
parts
.
query
,
parts
.
fragment
,
parts
.
fragment
,
)
)
)
)
print
(
f
"Pulling
{
parts
.
netloc
}
/
{
model
}
..."
)
if
not
validators
.
url
(
new_url
):
# this may just be a local model location
return
model
response
=
requests
.
get
(
new_url
)
response
=
requests
.
get
(
new_url
)
response
.
raise_for_status
()
# Raises stored HTTPError, if one occurred
response
.
raise_for_status
()
# Raises stored HTTPError, if one occurred
...
@@ -47,22 +51,28 @@ def pull(remote, models_home=".", *args, **kwargs):
...
@@ -47,22 +51,28 @@ def pull(remote, models_home=".", *args, **kwargs):
# get the last bin file we find, this is probably the most up to date
# get the last bin file we find, this is probably the most up to date
download_url
=
None
download_url
=
None
file_size
=
0
for
file_info
in
json_response
:
for
file_info
in
json_response
:
if
file_info
.
get
(
"type"
)
==
"file"
and
file_info
.
get
(
"path"
).
endswith
(
".bin"
):
if
file_info
.
get
(
"type"
)
==
"file"
and
file_info
.
get
(
"path"
).
endswith
(
".bin"
):
f_path
=
file_info
.
get
(
"path"
)
f_path
=
file_info
.
get
(
"path"
)
download_url
=
f
"https://huggingface.co/
{
model
}
/resolve/
{
branch
}
/
{
f_path
}
"
download_url
=
f
"https://huggingface.co/
{
url
}
/resolve/
{
branch
}
/
{
f_path
}
"
file_size
=
file_info
.
get
(
"size"
)
if
download_url
is
None
:
if
download_url
is
None
:
raise
Exception
(
"No model found"
)
raise
Exception
(
"No model found"
)
local_filename
=
os
.
path
.
join
(
models_home
,
os
.
path
.
basename
(
mode
l
))
+
".bin"
local_filename
=
os
.
path
.
join
(
models_home
,
os
.
path
.
basename
(
ur
l
))
+
".bin"
# Check if file already exists
# Check if file already exists
first_byte
=
0
if
os
.
path
.
exists
(
local_filename
):
if
os
.
path
.
exists
(
local_filename
):
# TODO: check if the file is the same SHA
# TODO: check if the file is the same SHA
first_byte
=
os
.
path
.
getsize
(
local_filename
)
first_byte
=
os
.
path
.
getsize
(
local_filename
)
else
:
first_byte
=
0
if
first_byte
>=
file_size
:
return
local_filename
print
(
f
"Pulling
{
parts
.
netloc
}
/
{
model
}
..."
)
# If file size is non-zero, resume download
# If file size is non-zero, resume download
if
first_byte
!=
0
:
if
first_byte
!=
0
:
...
@@ -87,3 +97,5 @@ def pull(remote, models_home=".", *args, **kwargs):
...
@@ -87,3 +97,5 @@ def pull(remote, models_home=".", *args, **kwargs):
for
data
in
response
.
iter_content
(
chunk_size
=
1024
):
for
data
in
response
.
iter_content
(
chunk_size
=
1024
):
size
=
file
.
write
(
data
)
size
=
file
.
write
(
data
)
bar
.
update
(
size
)
bar
.
update
(
size
)
return
local_filename
poetry.lock
View file @
885403b4
...
@@ -271,6 +271,17 @@ files = [
...
@@ -271,6 +271,17 @@ files = [
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
]
[[package]]
name = "decorator"
version = "5.1.1"
description = "Decorators for Humans"
optional = false
python-versions = ">=3.5"
files = [
{file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"},
{file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
]
[[package]]
[[package]]
name = "diskcache"
name = "diskcache"
version = "5.6.1"
version = "5.6.1"
...
@@ -659,6 +670,22 @@ secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.
...
@@ -659,6 +670,22 @@ secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
zstd = ["zstandard (>=0.18.0)"]
zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "validators"
version = "0.20.0"
description = "Python Data Validation for Humans™."
optional = false
python-versions = ">=3.4"
files = [
{file = "validators-0.20.0.tar.gz", hash = "sha256:24148ce4e64100a2d5e267233e23e7afeb55316b47d30faae7eb6e7292bc226a"},
]
[package.dependencies]
decorator = ">=3.4.0"
[package.extras]
test = ["flake8 (>=2.4.0)", "isort (>=4.2.2)", "pytest (>=2.2.3)"]
[[package]]
[[package]]
name = "yarl"
name = "yarl"
version = "1.9.2"
version = "1.9.2"
...
...
pyproject.toml
View file @
885403b4
...
@@ -16,6 +16,7 @@ aiohttp = {version = "^3.8.4", optional = true}
...
@@ -16,6 +16,7 @@ aiohttp = {version = "^3.8.4", optional = true}
aiohttp-cors
=
{
version
=
"^0.7.0"
,
optional
=
true
}
aiohttp-cors
=
{
version
=
"^0.7.0"
,
optional
=
true
}
requests
=
"^2.31.0"
requests
=
"^2.31.0"
tqdm
=
"^4.65.0"
tqdm
=
"^4.65.0"
validators
=
"^0.20.0"
[tool.poetry.extras]
[tool.poetry.extras]
server
=
[
"aiohttp"
,
"aiohttp_cors"
]
server
=
[
"aiohttp"
,
"aiohttp_cors"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment