Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
orangecat
ollama
Commits
b1fd7fef
Unverified
Commit
b1fd7fef
authored
Dec 11, 2024
by
Blake Mizerany
Committed by
GitHub
Dec 11, 2024
Browse files
server: more support for mixed-case model names (#8017)
Fixes #7944
parent
36d111e7
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
123 additions
and
38 deletions
+123
-38
cmd/cmd.go
cmd/cmd.go
+1
-1
server/images.go
server/images.go
+4
-0
server/modelpath.go
server/modelpath.go
+11
-4
server/modelpath_test.go
server/modelpath_test.go
+0
-8
server/routes.go
server/routes.go
+90
-22
server/routes_generate_test.go
server/routes_generate_test.go
+1
-1
server/routes_test.go
server/routes_test.go
+14
-0
types/model/name.go
types/model/name.go
+2
-2
No files found.
cmd/cmd.go
View file @
b1fd7fef
...
...
@@ -601,7 +601,7 @@ func ListHandler(cmd *cobra.Command, args []string) error {
var
data
[][]
string
for
_
,
m
:=
range
models
.
Models
{
if
len
(
args
)
==
0
||
strings
.
HasPrefix
(
m
.
Name
,
args
[
0
])
{
if
len
(
args
)
==
0
||
strings
.
HasPrefix
(
strings
.
ToLower
(
m
.
Name
)
,
strings
.
ToLower
(
args
[
0
])
)
{
data
=
append
(
data
,
[]
string
{
m
.
Name
,
m
.
Digest
[
:
12
],
format
.
HumanBytes
(
m
.
Size
),
format
.
HumanTime
(
m
.
ModifiedAt
,
"Never"
)})
}
}
...
...
server/images.go
View file @
b1fd7fef
...
...
@@ -376,6 +376,10 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
switch
command
{
case
"model"
,
"adapter"
:
if
name
:=
model
.
ParseName
(
c
.
Args
);
name
.
IsValid
()
&&
command
==
"model"
{
name
,
err
:=
getExistingName
(
name
)
if
err
!=
nil
{
return
err
}
baseLayers
,
err
=
parseFromModel
(
ctx
,
name
,
fn
)
if
err
!=
nil
{
return
err
...
...
server/modelpath.go
View file @
b1fd7fef
...
...
@@ -3,6 +3,7 @@ package server
import
(
"errors"
"fmt"
"io/fs"
"net/url"
"os"
"path/filepath"
...
...
@@ -10,6 +11,7 @@ import (
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
type
ModelPath
struct
{
...
...
@@ -93,11 +95,16 @@ func (mp ModelPath) GetShortTagname() string {
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
func
(
mp
ModelPath
)
GetManifestPath
()
(
string
,
error
)
{
if
p
:=
filepath
.
Join
(
mp
.
Registry
,
mp
.
Namespace
,
mp
.
Repository
,
mp
.
Tag
);
filepath
.
IsLocal
(
p
)
{
return
filepath
.
Join
(
envconfig
.
Models
(),
"manifests"
,
p
),
nil
name
:=
model
.
Name
{
Host
:
mp
.
Registry
,
Namespace
:
mp
.
Namespace
,
Model
:
mp
.
Repository
,
Tag
:
mp
.
Tag
,
}
return
""
,
errModelPathInvalid
if
!
name
.
IsValid
()
{
return
""
,
fs
.
ErrNotExist
}
return
filepath
.
Join
(
envconfig
.
Models
(),
"manifests"
,
name
.
Filepath
()),
nil
}
func
(
mp
ModelPath
)
BaseURL
()
*
url
.
URL
{
...
...
server/modelpath_test.go
View file @
b1fd7fef
package
server
import
(
"errors"
"os"
"path/filepath"
"testing"
...
...
@@ -155,10 +154,3 @@ func TestParseModelPath(t *testing.T) {
})
}
}
func
TestInsecureModelpath
(
t
*
testing
.
T
)
{
mp
:=
ParseModelPath
(
"../../..:something"
)
if
_
,
err
:=
mp
.
GetManifestPath
();
!
errors
.
Is
(
err
,
errModelPathInvalid
)
{
t
.
Errorf
(
"expected error: %v"
,
err
)
}
}
server/routes.go
View file @
b1fd7fef
...
...
@@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"math"
"net"
...
...
@@ -120,10 +121,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
model
,
err
:=
GetModel
(
req
.
Model
)
name
:=
model
.
ParseName
(
req
.
Model
)
if
!
name
.
IsValid
()
{
// Ideally this is "invalid model name" but we're keeping with
// what the API currently returns until we can change it.
c
.
JSON
(
http
.
StatusNotFound
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"model '%s' not found"
,
req
.
Model
)})
return
}
// We cannot currently consolidate this into GetModel because all we'll
// induce infinite recursion given the current code structure.
name
,
err
:=
getExistingName
(
name
)
if
err
!=
nil
{
c
.
JSON
(
http
.
StatusNotFound
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"model '%s' not found"
,
req
.
Model
)})
return
}
model
,
err
:=
GetModel
(
name
.
String
())
if
err
!=
nil
{
switch
{
case
os
.
Is
NotExist
(
err
)
:
case
errors
.
Is
(
err
,
fs
.
Err
NotExist
)
:
c
.
JSON
(
http
.
StatusNotFound
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"model '%s' not found"
,
req
.
Model
)})
case
err
.
Error
()
==
"invalid model name"
:
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
err
.
Error
()})
...
...
@@ -157,7 +174,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
caps
=
append
(
caps
,
CapabilityInsert
)
}
r
,
m
,
opts
,
err
:=
s
.
scheduleRunner
(
c
.
Request
.
Context
(),
req
.
Model
,
caps
,
req
.
Options
,
req
.
KeepAlive
)
r
,
m
,
opts
,
err
:=
s
.
scheduleRunner
(
c
.
Request
.
Context
(),
name
.
String
()
,
caps
,
req
.
Options
,
req
.
KeepAlive
)
if
errors
.
Is
(
err
,
errCapabilityCompletion
)
{
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"%q does not support generate"
,
req
.
Model
)})
return
...
...
@@ -386,7 +403,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
}
}
r
,
m
,
opts
,
err
:=
s
.
scheduleRunner
(
c
.
Request
.
Context
(),
req
.
Model
,
[]
Capability
{},
req
.
Options
,
req
.
KeepAlive
)
name
,
err
:=
getExistingName
(
model
.
ParseName
(
req
.
Model
))
if
err
!=
nil
{
c
.
JSON
(
http
.
StatusNotFound
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"model '%s' not found"
,
req
.
Model
)})
return
}
r
,
m
,
opts
,
err
:=
s
.
scheduleRunner
(
c
.
Request
.
Context
(),
name
.
String
(),
[]
Capability
{},
req
.
Options
,
req
.
KeepAlive
)
if
err
!=
nil
{
handleScheduleError
(
c
,
req
.
Model
,
err
)
return
...
...
@@ -489,7 +512,13 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
r
,
_
,
_
,
err
:=
s
.
scheduleRunner
(
c
.
Request
.
Context
(),
req
.
Model
,
[]
Capability
{},
req
.
Options
,
req
.
KeepAlive
)
name
:=
model
.
ParseName
(
req
.
Model
)
if
!
name
.
IsValid
()
{
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"model is required"
})
return
}
r
,
_
,
_
,
err
:=
s
.
scheduleRunner
(
c
.
Request
.
Context
(),
name
.
String
(),
[]
Capability
{},
req
.
Options
,
req
.
KeepAlive
)
if
err
!=
nil
{
handleScheduleError
(
c
,
req
.
Model
,
err
)
return
...
...
@@ -582,11 +611,11 @@ func (s *Server) PushHandler(c *gin.Context) {
return
}
var
m
odel
string
var
m
name
string
if
req
.
Model
!=
""
{
m
odel
=
req
.
Model
m
name
=
req
.
Model
}
else
if
req
.
Name
!=
""
{
m
odel
=
req
.
Name
m
name
=
req
.
Name
}
else
{
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"model is required"
})
return
...
...
@@ -606,7 +635,13 @@ func (s *Server) PushHandler(c *gin.Context) {
ctx
,
cancel
:=
context
.
WithCancel
(
c
.
Request
.
Context
())
defer
cancel
()
if
err
:=
PushModel
(
ctx
,
model
,
regOpts
,
fn
);
err
!=
nil
{
name
,
err
:=
getExistingName
(
model
.
ParseName
(
mname
))
if
err
!=
nil
{
ch
<-
gin
.
H
{
"error"
:
err
.
Error
()}
return
}
if
err
:=
PushModel
(
ctx
,
name
.
DisplayShortest
(),
regOpts
,
fn
);
err
!=
nil
{
ch
<-
gin
.
H
{
"error"
:
err
.
Error
()}
}
}()
...
...
@@ -619,17 +654,29 @@ func (s *Server) PushHandler(c *gin.Context) {
streamResponse
(
c
,
ch
)
}
// getExistingName returns the original, on disk name if the input name is a
// case-insensitive match, otherwise it returns the input name.
// getExistingName searches the models directory for the longest prefix match of
// the input name and returns the input name with all existing parts replaced
// with each part found. If no parts are found, the input name is returned as
// is.
func
getExistingName
(
n
model
.
Name
)
(
model
.
Name
,
error
)
{
var
zero
model
.
Name
existing
,
err
:=
Manifests
(
true
)
if
err
!=
nil
{
return
zero
,
err
}
var
set
model
.
Name
// tracks parts already canonicalized
for
e
:=
range
existing
{
if
n
.
EqualFold
(
e
)
{
return
e
,
nil
if
set
.
Host
==
""
&&
strings
.
EqualFold
(
e
.
Host
,
n
.
Host
)
{
n
.
Host
=
e
.
Host
}
if
set
.
Namespace
==
""
&&
strings
.
EqualFold
(
e
.
Namespace
,
n
.
Namespace
)
{
n
.
Namespace
=
e
.
Namespace
}
if
set
.
Model
==
""
&&
strings
.
EqualFold
(
e
.
Model
,
n
.
Model
)
{
n
.
Model
=
e
.
Model
}
if
set
.
Tag
==
""
&&
strings
.
EqualFold
(
e
.
Tag
,
n
.
Tag
)
{
n
.
Tag
=
e
.
Tag
}
}
return
n
,
nil
...
...
@@ -658,7 +705,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
}
if
r
.
Path
==
""
&&
r
.
Modelfile
==
""
{
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"path or
m
odelfile are required"
})
c
.
AbortWithStatusJSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"path or
M
odelfile are required"
})
return
}
...
...
@@ -722,6 +769,12 @@ func (s *Server) DeleteHandler(c *gin.Context) {
return
}
n
,
err
:=
getExistingName
(
n
)
if
err
!=
nil
{
c
.
JSON
(
http
.
StatusNotFound
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"model '%s' not found"
,
cmp
.
Or
(
r
.
Model
,
r
.
Name
))})
return
}
m
,
err
:=
ParseNamedManifest
(
n
)
if
err
!=
nil
{
switch
{
...
...
@@ -782,7 +835,16 @@ func (s *Server) ShowHandler(c *gin.Context) {
}
func
GetModelInfo
(
req
api
.
ShowRequest
)
(
*
api
.
ShowResponse
,
error
)
{
m
,
err
:=
GetModel
(
req
.
Model
)
name
:=
model
.
ParseName
(
req
.
Model
)
if
!
name
.
IsValid
()
{
return
nil
,
errModelPathInvalid
}
name
,
err
:=
getExistingName
(
name
)
if
err
!=
nil
{
return
nil
,
err
}
m
,
err
:=
GetModel
(
name
.
String
())
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -805,12 +867,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
msgs
[
i
]
=
api
.
Message
{
Role
:
msg
.
Role
,
Content
:
msg
.
Content
}
}
n
:=
model
.
ParseName
(
req
.
Model
)
if
!
n
.
IsValid
()
{
return
nil
,
errors
.
New
(
"invalid model name"
)
}
manifest
,
err
:=
ParseNamedManifest
(
n
)
manifest
,
err
:=
ParseNamedManifest
(
name
)
if
err
!=
nil
{
return
nil
,
err
}
...
...
@@ -1431,7 +1488,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
caps
=
append
(
caps
,
CapabilityTools
)
}
r
,
m
,
opts
,
err
:=
s
.
scheduleRunner
(
c
.
Request
.
Context
(),
req
.
Model
,
caps
,
req
.
Options
,
req
.
KeepAlive
)
name
:=
model
.
ParseName
(
req
.
Model
)
if
!
name
.
IsValid
()
{
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"model is required"
})
return
}
name
,
err
:=
getExistingName
(
name
)
if
err
!=
nil
{
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
"model is required"
})
return
}
r
,
m
,
opts
,
err
:=
s
.
scheduleRunner
(
c
.
Request
.
Context
(),
name
.
String
(),
caps
,
req
.
Options
,
req
.
KeepAlive
)
if
errors
.
Is
(
err
,
errCapabilityCompletion
)
{
c
.
JSON
(
http
.
StatusBadRequest
,
gin
.
H
{
"error"
:
fmt
.
Sprintf
(
"%q does not support chat"
,
req
.
Model
)})
return
...
...
server/routes_generate_test.go
View file @
b1fd7fef
...
...
@@ -719,7 +719,7 @@ func TestGenerate(t *testing.T) {
t
.
Errorf
(
"expected status 400, got %d"
,
w
.
Code
)
}
if
diff
:=
cmp
.
Diff
(
w
.
Body
.
String
(),
`{"error":"test does not support insert"}`
);
diff
!=
""
{
if
diff
:=
cmp
.
Diff
(
w
.
Body
.
String
(),
`{"error":"
registry.ollama.ai/library/test:la
test does not support insert"}`
);
diff
!=
""
{
t
.
Errorf
(
"mismatch (-got +want):
\n
%s"
,
diff
)
}
})
...
...
server/routes_test.go
View file @
b1fd7fef
...
...
@@ -514,6 +514,8 @@ func TestManifestCaseSensitivity(t *testing.T) {
wantStableName
:=
name
()
t
.
Logf
(
"stable name: %s"
,
wantStableName
)
// checkManifestList tests that there is strictly one manifest in the
// models directory, and that the manifest is for the model under test.
checkManifestList
:=
func
()
{
...
...
@@ -601,6 +603,18 @@ func TestManifestCaseSensitivity(t *testing.T) {
Destination
:
name
(),
}))
checkManifestList
()
t
.
Logf
(
"pushing"
)
rr
:=
createRequest
(
t
,
s
.
PushHandler
,
api
.
PushRequest
{
Model
:
name
(),
Insecure
:
true
,
Username
:
"alice"
,
Password
:
"x"
,
})
checkOK
(
rr
)
if
!
strings
.
Contains
(
rr
.
Body
.
String
(),
`"status":"success"`
)
{
t
.
Errorf
(
"got = %q, want success"
,
rr
.
Body
.
String
())
}
}
func
TestShow
(
t
*
testing
.
T
)
{
...
...
types/model/name.go
View file @
b1fd7fef
...
...
@@ -223,12 +223,12 @@ func (n Name) String() string {
func
(
n
Name
)
DisplayShortest
()
string
{
var
sb
strings
.
Builder
if
n
.
Host
!=
defaultHost
{
if
!
strings
.
EqualFold
(
n
.
Host
,
defaultHost
)
{
sb
.
WriteString
(
n
.
Host
)
sb
.
WriteByte
(
'/'
)
sb
.
WriteString
(
n
.
Namespace
)
sb
.
WriteByte
(
'/'
)
}
else
if
n
.
Namespace
!=
defaultNamespace
{
}
else
if
!
strings
.
EqualFold
(
n
.
Namespace
,
defaultNamespace
)
{
sb
.
WriteString
(
n
.
Namespace
)
sb
.
WriteByte
(
'/'
)
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment