Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
4c4c730a
Unverified
Commit
4c4c730a
authored
Jan 27, 2024
by
mraiser
Committed by
GitHub
Jan 27, 2024
Browse files
Merge branch 'ollama:main' into main
parents
6eb3cddc
e02ecfb6
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
203 additions
and
52 deletions
+203
-52
parser/parser_test.go
parser/parser_test.go
+35
-0
scripts/build_docker.sh
scripts/build_docker.sh
+10
-0
server/download.go
server/download.go
+74
-43
server/images.go
server/images.go
+47
-5
server/routes.go
server/routes.go
+37
-4
No files found.
parser/parser_test.go
View file @
4c4c730a
...
...
@@ -61,3 +61,38 @@ PARAMETER param1
assert
.
ErrorContains
(
t
,
err
,
"missing value for [param1]"
)
}
func
Test_Parser_Messages
(
t
*
testing
.
T
)
{
input
:=
`
FROM foo
MESSAGE system You are a Parser. Always Parse things.
MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things!
`
reader
:=
strings
.
NewReader
(
input
)
commands
,
err
:=
Parse
(
reader
)
assert
.
Nil
(
t
,
err
)
expectedCommands
:=
[]
Command
{
{
Name
:
"model"
,
Args
:
"foo"
},
{
Name
:
"message"
,
Args
:
"system: You are a Parser. Always Parse things."
},
{
Name
:
"message"
,
Args
:
"user: Hey there!"
},
{
Name
:
"message"
,
Args
:
"assistant: Hello, I want to parse all the things!"
},
}
assert
.
Equal
(
t
,
expectedCommands
,
commands
)
}
func
Test_Parser_Messages_BadRole
(
t
*
testing
.
T
)
{
input
:=
`
FROM foo
MESSAGE badguy I'm a bad guy!
`
reader
:=
strings
.
NewReader
(
input
)
_
,
err
:=
Parse
(
reader
)
assert
.
ErrorContains
(
t
,
err
,
"role must be one of
\"
system
\"
,
\"
user
\"
, or
\"
assistant
\"
"
)
}
scripts/build_docker.sh
View file @
4c4c730a
...
...
@@ -13,3 +13,13 @@ docker build \
-f
Dockerfile
\
-t
ollama/ollama:
$VERSION
\
.
docker build
\
--load
\
--platform
=
linux/amd64
\
--build-arg
=
VERSION
\
--build-arg
=
GOFLAGS
\
--target
runtime-rocm
\
-f
Dockerfile
\
-t
ollama/ollama:
$VERSION
-rocm
\
.
server/download.go
View file @
4c4c730a
...
...
@@ -25,6 +25,11 @@ import (
"github.com/jmorganca/ollama/format"
)
const
maxRetries
=
6
var
errMaxRetriesExceeded
=
errors
.
New
(
"max retries exceeded"
)
var
errPartStalled
=
errors
.
New
(
"part stalled"
)
var
blobDownloadManager
sync
.
Map
type
blobDownload
struct
{
...
...
@@ -44,10 +49,11 @@ type blobDownload struct {
}
type
blobDownloadPart
struct
{
N
int
Offset
int64
Size
int64
Completed
int64
N
int
Offset
int64
Size
int64
Completed
int64
lastUpdated
time
.
Time
*
blobDownload
`json:"-"`
}
...
...
@@ -72,6 +78,13 @@ func (p *blobDownloadPart) StopsAt() int64 {
return
p
.
Offset
+
p
.
Size
}
func
(
p
*
blobDownloadPart
)
Write
(
b
[]
byte
)
(
n
int
,
err
error
)
{
n
=
len
(
b
)
p
.
blobDownload
.
Completed
.
Add
(
int64
(
n
))
p
.
lastUpdated
=
time
.
Now
()
return
n
,
nil
}
func
(
b
*
blobDownload
)
Prepare
(
ctx
context
.
Context
,
requestURL
*
url
.
URL
,
opts
*
RegistryOptions
)
error
{
partFilePaths
,
err
:=
filepath
.
Glob
(
b
.
Name
+
"-partial-*"
)
if
err
!=
nil
{
...
...
@@ -157,6 +170,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
case
errors
.
Is
(
err
,
context
.
Canceled
),
errors
.
Is
(
err
,
syscall
.
ENOSPC
)
:
// return immediately if the context is canceled or the device is out of space
return
err
case
errors
.
Is
(
err
,
errPartStalled
)
:
try
--
continue
case
err
!=
nil
:
sleep
:=
time
.
Second
*
time
.
Duration
(
math
.
Pow
(
2
,
float64
(
try
)))
slog
.
Info
(
fmt
.
Sprintf
(
"%s part %d attempt %d failed: %v, retrying in %s"
,
b
.
Digest
[
7
:
19
],
part
.
N
,
try
,
err
,
sleep
))
...
...
@@ -195,28 +211,54 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
}
func
(
b
*
blobDownload
)
downloadChunk
(
ctx
context
.
Context
,
requestURL
*
url
.
URL
,
w
io
.
Writer
,
part
*
blobDownloadPart
,
opts
*
RegistryOptions
)
error
{
headers
:=
make
(
http
.
Header
)
headers
.
Set
(
"Range"
,
fmt
.
Sprintf
(
"bytes=%d-%d"
,
part
.
StartsAt
(),
part
.
StopsAt
()
-
1
))
resp
,
err
:=
makeRequestWithRetry
(
ctx
,
http
.
MethodGet
,
requestURL
,
headers
,
nil
,
opts
)
if
err
!=
nil
{
return
err
}
defer
resp
.
Body
.
Close
()
g
,
ctx
:=
errgroup
.
WithContext
(
ctx
)
g
.
Go
(
func
()
error
{
headers
:=
make
(
http
.
Header
)
headers
.
Set
(
"Range"
,
fmt
.
Sprintf
(
"bytes=%d-%d"
,
part
.
StartsAt
(),
part
.
StopsAt
()
-
1
))
resp
,
err
:=
makeRequestWithRetry
(
ctx
,
http
.
MethodGet
,
requestURL
,
headers
,
nil
,
opts
)
if
err
!=
nil
{
return
err
}
defer
resp
.
Body
.
Close
()
n
,
err
:=
io
.
Copy
(
w
,
io
.
TeeReader
(
resp
.
Body
,
b
))
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
context
.
Canceled
)
&&
!
errors
.
Is
(
err
,
io
.
ErrUnexpectedEOF
)
{
// rollback progress
b
.
Completed
.
Add
(
-
n
)
return
err
}
n
,
err
:=
io
.
Copy
(
w
,
io
.
TeeReader
(
resp
.
Body
,
part
))
if
err
!=
nil
&&
!
errors
.
Is
(
err
,
context
.
Canceled
)
&&
!
errors
.
Is
(
err
,
io
.
ErrUnexpectedEOF
)
{
// rollback progress
b
.
Completed
.
Add
(
-
n
)
return
err
}
part
.
Completed
+=
n
if
err
:=
b
.
writePart
(
part
.
Name
(),
part
);
err
!=
nil
{
part
.
Completed
+=
n
if
err
:=
b
.
writePart
(
part
.
Name
(),
part
);
err
!=
nil
{
return
err
}
// return nil or context.Canceled or UnexpectedEOF (resumable)
return
err
}
})
g
.
Go
(
func
()
error
{
ticker
:=
time
.
NewTicker
(
time
.
Second
)
for
{
select
{
case
<-
ticker
.
C
:
if
part
.
Completed
>=
part
.
Size
{
return
nil
}
if
!
part
.
lastUpdated
.
IsZero
()
&&
time
.
Since
(
part
.
lastUpdated
)
>
5
*
time
.
Second
{
slog
.
Info
(
fmt
.
Sprintf
(
"%s part %d stalled; retrying"
,
b
.
Digest
[
7
:
19
],
part
.
N
))
// reset last updated
part
.
lastUpdated
=
time
.
Time
{}
return
errPartStalled
}
case
<-
ctx
.
Done
()
:
return
ctx
.
Err
()
}
}
})
// return nil or context.Canceled or UnexpectedEOF (resumable)
return
err
return
g
.
Wait
()
}
func
(
b
*
blobDownload
)
newPart
(
offset
,
size
int64
)
error
{
...
...
@@ -255,12 +297,6 @@ func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error
return
json
.
NewEncoder
(
partFile
)
.
Encode
(
part
)
}
func
(
b
*
blobDownload
)
Write
(
p
[]
byte
)
(
n
int
,
err
error
)
{
n
=
len
(
p
)
b
.
Completed
.
Add
(
int64
(
n
))
return
n
,
nil
}
func
(
b
*
blobDownload
)
acquire
()
{
b
.
references
.
Add
(
1
)
}
...
...
@@ -279,20 +315,19 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
for
{
select
{
case
<-
ticker
.
C
:
fn
(
api
.
ProgressResponse
{
Status
:
fmt
.
Sprintf
(
"pulling %s"
,
b
.
Digest
[
7
:
19
]),
Digest
:
b
.
Digest
,
Total
:
b
.
Total
,
Completed
:
b
.
Completed
.
Load
(),
})
if
b
.
done
||
b
.
err
!=
nil
{
return
b
.
err
}
case
<-
ctx
.
Done
()
:
return
ctx
.
Err
()
}
fn
(
api
.
ProgressResponse
{
Status
:
fmt
.
Sprintf
(
"pulling %s"
,
b
.
Digest
[
7
:
19
]),
Digest
:
b
.
Digest
,
Total
:
b
.
Total
,
Completed
:
b
.
Completed
.
Load
(),
})
if
b
.
done
||
b
.
err
!=
nil
{
return
b
.
err
}
}
}
...
...
@@ -303,10 +338,6 @@ type downloadOpts struct {
fn
func
(
api
.
ProgressResponse
)
}
const
maxRetries
=
6
var
errMaxRetriesExceeded
=
errors
.
New
(
"max retries exceeded"
)
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func
downloadBlob
(
ctx
context
.
Context
,
opts
downloadOpts
)
error
{
fp
,
err
:=
GetBlobsPath
(
opts
.
digest
)
...
...
server/images.go
View file @
4c4c730a
...
...
@@ -41,7 +41,7 @@ type Model struct {
Config
ConfigV2
ShortName
string
ModelPath
string
Original
Model
string
Parent
Model
string
AdapterPaths
[]
string
ProjectorPaths
[]
string
Template
string
...
...
@@ -50,6 +50,12 @@ type Model struct {
Digest
string
Size
int64
Options
map
[
string
]
interface
{}
Messages
[]
Message
}
// Message is a single role/content pair. GetModel decodes the
// "application/vnd.ollama.image.messages" layer into a slice of these.
type Message struct {
	// Role is serialized under the JSON key "role".
	Role string `json:"role"`

	// Content is serialized under the JSON key "content".
	Content string `json:"content"`
}
type
PromptVars
struct
{
...
...
@@ -333,7 +339,7 @@ func GetModel(name string) (*Model, error) {
switch
layer
.
MediaType
{
case
"application/vnd.ollama.image.model"
:
model
.
ModelPath
=
filename
model
.
Original
Model
=
layer
.
From
model
.
Parent
Model
=
layer
.
From
case
"application/vnd.ollama.image.embed"
:
// Deprecated in versions > 0.1.2
// TODO: remove this warning in a future version
...
...
@@ -374,6 +380,16 @@ func GetModel(name string) (*Model, error) {
if
err
=
json
.
NewDecoder
(
params
)
.
Decode
(
&
model
.
Options
);
err
!=
nil
{
return
nil
,
err
}
case
"application/vnd.ollama.image.messages"
:
msgs
,
err
:=
os
.
Open
(
filename
)
if
err
!=
nil
{
return
nil
,
err
}
defer
msgs
.
Close
()
if
err
=
json
.
NewDecoder
(
msgs
)
.
Decode
(
&
model
.
Messages
);
err
!=
nil
{
return
nil
,
err
}
case
"application/vnd.ollama.image.license"
:
bts
,
err
:=
os
.
ReadFile
(
filename
)
if
err
!=
nil
{
...
...
@@ -428,12 +444,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
}
var
layers
Layers
messages
:=
[]
string
{}
params
:=
make
(
map
[
string
][]
string
)
fromParams
:=
make
(
map
[
string
]
any
)
for
_
,
c
:=
range
commands
{
slog
.
Info
(
fmt
.
Sprintf
(
"[%s] - %s"
,
c
.
Name
,
c
.
Args
))
mediatype
:=
fmt
.
Sprintf
(
"application/vnd.ollama.image.%s"
,
c
.
Name
)
switch
c
.
Name
{
...
...
@@ -607,11 +623,37 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
}
layers
.
Replace
(
layer
)
case
"message"
:
messages
=
append
(
messages
,
c
.
Args
)
default
:
params
[
c
.
Name
]
=
append
(
params
[
c
.
Name
],
c
.
Args
)
}
}
if
len
(
messages
)
>
0
{
fn
(
api
.
ProgressResponse
{
Status
:
"creating parameters layer"
})
msgs
:=
make
([]
api
.
Message
,
0
)
for
_
,
m
:=
range
messages
{
// todo: handle images
msg
:=
strings
.
SplitN
(
m
,
": "
,
2
)
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
msg
[
0
],
Content
:
msg
[
1
]})
}
var
b
bytes
.
Buffer
if
err
:=
json
.
NewEncoder
(
&
b
)
.
Encode
(
msgs
);
err
!=
nil
{
return
err
}
layer
,
err
:=
NewLayer
(
&
b
,
"application/vnd.ollama.image.messages"
)
if
err
!=
nil
{
return
err
}
layers
.
Replace
(
layer
)
}
if
len
(
params
)
>
0
{
fn
(
api
.
ProgressResponse
{
Status
:
"creating parameters layer"
})
...
...
@@ -908,8 +950,8 @@ func ShowModelfile(model *Model) (string, error) {
mt
.
Model
=
model
mt
.
From
=
model
.
ModelPath
if
model
.
Original
Model
!=
""
{
mt
.
From
=
model
.
Original
Model
if
model
.
Parent
Model
!=
""
{
mt
.
From
=
model
.
Parent
Model
}
modelFile
:=
`# Modelfile generated by "ollama show"
...
...
server/routes.go
View file @
4c4c730a
...
...
@@ -186,7 +186,13 @@ func GenerateHandler(c *gin.Context) {
return
}
sessionDuration
:=
defaultSessionDuration
var
sessionDuration
time
.
Duration
if
req
.
KeepAlive
==
nil
{
sessionDuration
=
defaultSessionDuration
}
else
{
sessionDuration
=
req
.
KeepAlive
.
Duration
}
if
err
:=
load
(
c
,
model
,
opts
,
sessionDuration
);
err
!=
nil
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
...
...
@@ -378,7 +384,14 @@ func EmbeddingHandler(c *gin.Context) {
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
sessionDuration
:=
defaultSessionDuration
var
sessionDuration
time
.
Duration
if
req
.
KeepAlive
==
nil
{
sessionDuration
=
defaultSessionDuration
}
else
{
sessionDuration
=
req
.
KeepAlive
.
Duration
}
if
err
:=
load
(
c
,
model
,
opts
,
sessionDuration
);
err
!=
nil
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
...
...
@@ -659,6 +672,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
modelDetails
:=
api
.
ModelDetails
{
ParentModel
:
model
.
ParentModel
,
Format
:
model
.
Config
.
ModelFormat
,
Family
:
model
.
Config
.
ModelFamily
,
Families
:
model
.
Config
.
ModelFamilies
,
...
...
@@ -674,11 +688,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
model
.
Template
=
req
.
Template
}
msgs
:=
make
([]
api
.
Message
,
0
)
for
_
,
msg
:=
range
model
.
Messages
{
msgs
=
append
(
msgs
,
api
.
Message
{
Role
:
msg
.
Role
,
Content
:
msg
.
Content
})
}
resp
:=
&
api
.
ShowResponse
{
License
:
strings
.
Join
(
model
.
License
,
"
\n
"
),
System
:
model
.
System
,
Template
:
model
.
Template
,
Details
:
modelDetails
,
Messages
:
msgs
,
}
var
params
[]
string
...
...
@@ -1067,7 +1087,14 @@ func ChatHandler(c *gin.Context) {
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
}
sessionDuration
:=
defaultSessionDuration
var
sessionDuration
time
.
Duration
if
req
.
KeepAlive
==
nil
{
sessionDuration
=
defaultSessionDuration
}
else
{
sessionDuration
=
req
.
KeepAlive
.
Duration
}
if
err
:=
load
(
c
,
model
,
opts
,
sessionDuration
);
err
!=
nil
{
c
.
JSON
(
http
.
StatusInternalServerError
,
gin
.
H
{
"error"
:
err
.
Error
()})
return
...
...
@@ -1075,7 +1102,13 @@ func ChatHandler(c *gin.Context) {
// an empty request loads the model
if
len
(
req
.
Messages
)
==
0
{
c
.
JSON
(
http
.
StatusOK
,
api
.
ChatResponse
{
CreatedAt
:
time
.
Now
()
.
UTC
(),
Model
:
req
.
Model
,
Done
:
true
,
Message
:
api
.
Message
{
Role
:
"assistant"
}})
resp
:=
api
.
ChatResponse
{
CreatedAt
:
time
.
Now
()
.
UTC
(),
Model
:
req
.
Model
,
Done
:
true
,
Message
:
api
.
Message
{
Role
:
"assistant"
},
}
c
.
JSON
(
http
.
StatusOK
,
resp
)
return
}
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment