"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "a137e4f494ee2fec4e5582464a6874682992580d"
Unverified Commit b1fd7fef authored by Blake Mizerany's avatar Blake Mizerany Committed by GitHub
Browse files

server: more support for mixed-case model names (#8017)

Fixes #7944
parent 36d111e7
...@@ -601,7 +601,7 @@ func ListHandler(cmd *cobra.Command, args []string) error { ...@@ -601,7 +601,7 @@ func ListHandler(cmd *cobra.Command, args []string) error {
var data [][]string var data [][]string
for _, m := range models.Models { for _, m := range models.Models {
if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) { if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) {
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")}) data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")})
} }
} }
......
...@@ -376,6 +376,10 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio ...@@ -376,6 +376,10 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
switch command { switch command {
case "model", "adapter": case "model", "adapter":
if name := model.ParseName(c.Args); name.IsValid() && command == "model" { if name := model.ParseName(c.Args); name.IsValid() && command == "model" {
name, err := getExistingName(name)
if err != nil {
return err
}
baseLayers, err = parseFromModel(ctx, name, fn) baseLayers, err = parseFromModel(ctx, name, fn)
if err != nil { if err != nil {
return err return err
......
...@@ -3,6 +3,7 @@ package server ...@@ -3,6 +3,7 @@ package server
import ( import (
"errors" "errors"
"fmt" "fmt"
"io/fs"
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
...@@ -10,6 +11,7 @@ import ( ...@@ -10,6 +11,7 @@ import (
"strings" "strings"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
) )
type ModelPath struct { type ModelPath struct {
...@@ -93,11 +95,16 @@ func (mp ModelPath) GetShortTagname() string { ...@@ -93,11 +95,16 @@ func (mp ModelPath) GetShortTagname() string {
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist. // GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
func (mp ModelPath) GetManifestPath() (string, error) { func (mp ModelPath) GetManifestPath() (string, error) {
if p := filepath.Join(mp.Registry, mp.Namespace, mp.Repository, mp.Tag); filepath.IsLocal(p) { name := model.Name{
return filepath.Join(envconfig.Models(), "manifests", p), nil Host: mp.Registry,
Namespace: mp.Namespace,
Model: mp.Repository,
Tag: mp.Tag,
} }
if !name.IsValid() {
return "", errModelPathInvalid return "", fs.ErrNotExist
}
return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil
} }
func (mp ModelPath) BaseURL() *url.URL { func (mp ModelPath) BaseURL() *url.URL {
......
package server package server
import ( import (
"errors"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
...@@ -155,10 +154,3 @@ func TestParseModelPath(t *testing.T) { ...@@ -155,10 +154,3 @@ func TestParseModelPath(t *testing.T) {
}) })
} }
} }
func TestInsecureModelpath(t *testing.T) {
mp := ParseModelPath("../../..:something")
if _, err := mp.GetManifestPath(); !errors.Is(err, errModelPathInvalid) {
t.Errorf("expected error: %v", err)
}
}
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"log/slog" "log/slog"
"math" "math"
"net" "net"
...@@ -120,10 +121,26 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -120,10 +121,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
model, err := GetModel(req.Model) name := model.ParseName(req.Model)
if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with
// what the API currently returns until we can change it.
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}
// We cannot currently consolidate this into GetModel because all we'll
// induce infinite recursion given the current code structure.
name, err := getExistingName(name)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}
model, err := GetModel(name.String())
if err != nil { if err != nil {
switch { switch {
case os.IsNotExist(err): case errors.Is(err, fs.ErrNotExist):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)}) c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case err.Error() == "invalid model name": case err.Error() == "invalid model name":
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
...@@ -157,7 +174,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -157,7 +174,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
caps = append(caps, CapabilityInsert) caps = append(caps, CapabilityInsert)
} }
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) { if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
return return
...@@ -386,7 +403,13 @@ func (s *Server) EmbedHandler(c *gin.Context) { ...@@ -386,7 +403,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
} }
} }
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) name, err := getExistingName(model.ParseName(req.Model))
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
if err != nil { if err != nil {
handleScheduleError(c, req.Model, err) handleScheduleError(c, req.Model, err)
return return
...@@ -489,7 +512,13 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { ...@@ -489,7 +512,13 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return return
} }
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) name := model.ParseName(req.Model)
if !name.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
if err != nil { if err != nil {
handleScheduleError(c, req.Model, err) handleScheduleError(c, req.Model, err)
return return
...@@ -582,11 +611,11 @@ func (s *Server) PushHandler(c *gin.Context) { ...@@ -582,11 +611,11 @@ func (s *Server) PushHandler(c *gin.Context) {
return return
} }
var model string var mname string
if req.Model != "" { if req.Model != "" {
model = req.Model mname = req.Model
} else if req.Name != "" { } else if req.Name != "" {
model = req.Name mname = req.Name
} else { } else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return return
...@@ -606,7 +635,13 @@ func (s *Server) PushHandler(c *gin.Context) { ...@@ -606,7 +635,13 @@ func (s *Server) PushHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
if err := PushModel(ctx, model, regOpts, fn); err != nil { name, err := getExistingName(model.ParseName(mname))
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
if err := PushModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
...@@ -619,17 +654,29 @@ func (s *Server) PushHandler(c *gin.Context) { ...@@ -619,17 +654,29 @@ func (s *Server) PushHandler(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
// getExistingName returns the original, on disk name if the input name is a // getExistingName searches the models directory for the longest prefix match of
// case-insensitive match, otherwise it returns the input name. // the input name and returns the input name with all existing parts replaced
// with each part found. If no parts are found, the input name is returned as
// is.
func getExistingName(n model.Name) (model.Name, error) { func getExistingName(n model.Name) (model.Name, error) {
var zero model.Name var zero model.Name
existing, err := Manifests(true) existing, err := Manifests(true)
if err != nil { if err != nil {
return zero, err return zero, err
} }
var set model.Name // tracks parts already canonicalized
for e := range existing { for e := range existing {
if n.EqualFold(e) { if set.Host == "" && strings.EqualFold(e.Host, n.Host) {
return e, nil n.Host = e.Host
}
if set.Namespace == "" && strings.EqualFold(e.Namespace, n.Namespace) {
n.Namespace = e.Namespace
}
if set.Model == "" && strings.EqualFold(e.Model, n.Model) {
n.Model = e.Model
}
if set.Tag == "" && strings.EqualFold(e.Tag, n.Tag) {
n.Tag = e.Tag
} }
} }
return n, nil return n, nil
...@@ -658,7 +705,7 @@ func (s *Server) CreateHandler(c *gin.Context) { ...@@ -658,7 +705,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
} }
if r.Path == "" && r.Modelfile == "" { if r.Path == "" && r.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or Modelfile are required"})
return return
} }
...@@ -722,6 +769,12 @@ func (s *Server) DeleteHandler(c *gin.Context) { ...@@ -722,6 +769,12 @@ func (s *Server) DeleteHandler(c *gin.Context) {
return return
} }
n, err := getExistingName(n)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
return
}
m, err := ParseNamedManifest(n) m, err := ParseNamedManifest(n)
if err != nil { if err != nil {
switch { switch {
...@@ -782,7 +835,16 @@ func (s *Server) ShowHandler(c *gin.Context) { ...@@ -782,7 +835,16 @@ func (s *Server) ShowHandler(c *gin.Context) {
} }
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
m, err := GetModel(req.Model) name := model.ParseName(req.Model)
if !name.IsValid() {
return nil, errModelPathInvalid
}
name, err := getExistingName(name)
if err != nil {
return nil, err
}
m, err := GetModel(name.String())
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -805,12 +867,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { ...@@ -805,12 +867,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content} msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
} }
n := model.ParseName(req.Model) manifest, err := ParseNamedManifest(name)
if !n.IsValid() {
return nil, errors.New("invalid model name")
}
manifest, err := ParseNamedManifest(n)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -1431,7 +1488,18 @@ func (s *Server) ChatHandler(c *gin.Context) { ...@@ -1431,7 +1488,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
caps = append(caps, CapabilityTools) caps = append(caps, CapabilityTools)
} }
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) name := model.ParseName(req.Model)
if !name.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
name, err := getExistingName(name)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) { if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
return return
......
...@@ -719,7 +719,7 @@ func TestGenerate(t *testing.T) { ...@@ -719,7 +719,7 @@ func TestGenerate(t *testing.T) {
t.Errorf("expected status 400, got %d", w.Code) t.Errorf("expected status 400, got %d", w.Code)
} }
if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" { if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support insert"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
}) })
......
...@@ -514,6 +514,8 @@ func TestManifestCaseSensitivity(t *testing.T) { ...@@ -514,6 +514,8 @@ func TestManifestCaseSensitivity(t *testing.T) {
wantStableName := name() wantStableName := name()
t.Logf("stable name: %s", wantStableName)
// checkManifestList tests that there is strictly one manifest in the // checkManifestList tests that there is strictly one manifest in the
// models directory, and that the manifest is for the model under test. // models directory, and that the manifest is for the model under test.
checkManifestList := func() { checkManifestList := func() {
...@@ -601,6 +603,18 @@ func TestManifestCaseSensitivity(t *testing.T) { ...@@ -601,6 +603,18 @@ func TestManifestCaseSensitivity(t *testing.T) {
Destination: name(), Destination: name(),
})) }))
checkManifestList() checkManifestList()
t.Logf("pushing")
rr := createRequest(t, s.PushHandler, api.PushRequest{
Model: name(),
Insecure: true,
Username: "alice",
Password: "x",
})
checkOK(rr)
if !strings.Contains(rr.Body.String(), `"status":"success"`) {
t.Errorf("got = %q, want success", rr.Body.String())
}
} }
func TestShow(t *testing.T) { func TestShow(t *testing.T) {
......
...@@ -223,12 +223,12 @@ func (n Name) String() string { ...@@ -223,12 +223,12 @@ func (n Name) String() string {
func (n Name) DisplayShortest() string { func (n Name) DisplayShortest() string {
var sb strings.Builder var sb strings.Builder
if n.Host != defaultHost { if !strings.EqualFold(n.Host, defaultHost) {
sb.WriteString(n.Host) sb.WriteString(n.Host)
sb.WriteByte('/') sb.WriteByte('/')
sb.WriteString(n.Namespace) sb.WriteString(n.Namespace)
sb.WriteByte('/') sb.WriteByte('/')
} else if n.Namespace != defaultNamespace { } else if !strings.EqualFold(n.Namespace, defaultNamespace) {
sb.WriteString(n.Namespace) sb.WriteString(n.Namespace)
sb.WriteByte('/') sb.WriteByte('/')
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment