Commit 85a57006 authored by Michael Yang's avatar Michael Yang
Browse files

check if name exists before create/pull/copy

parent c5e892cb
...@@ -420,13 +420,14 @@ func (s *Server) PullModelHandler(c *gin.Context) { ...@@ -420,13 +420,14 @@ func (s *Server) PullModelHandler(c *gin.Context) {
return return
} }
var model string name := model.ParseName(cmp.Or(req.Model, req.Name))
if req.Model != "" { if !name.IsValid() {
model = req.Model c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
} else if req.Name != "" { return
model = req.Name }
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) if err := checkNameExists(name); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
...@@ -444,7 +445,7 @@ func (s *Server) PullModelHandler(c *gin.Context) { ...@@ -444,7 +445,7 @@ func (s *Server) PullModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
if err := PullModel(ctx, model, regOpts, fn); err != nil { if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
...@@ -506,6 +507,21 @@ func (s *Server) PushModelHandler(c *gin.Context) { ...@@ -506,6 +507,21 @@ func (s *Server) PushModelHandler(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
func checkNameExists(name model.Name) error {
names, err := Manifests()
if err != nil {
return err
}
for n := range names {
if strings.EqualFold(n.Filepath(), name.Filepath()) && n != name {
return fmt.Errorf("a model with that name already exists")
}
}
return nil
}
func (s *Server) CreateModelHandler(c *gin.Context) { func (s *Server) CreateModelHandler(c *gin.Context) {
var req api.CreateRequest var req api.CreateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
...@@ -522,6 +538,11 @@ func (s *Server) CreateModelHandler(c *gin.Context) { ...@@ -522,6 +538,11 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
return return
} }
if err := checkNameExists(name); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Path == "" && req.Modelfile == "" { if req.Path == "" && req.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
return return
...@@ -770,6 +791,11 @@ func (s *Server) CopyModelHandler(c *gin.Context) { ...@@ -770,6 +791,11 @@ func (s *Server) CopyModelHandler(c *gin.Context) {
return return
} }
if err := checkNameExists(dst); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) { if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)}) c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
} else if err != nil { } else if err != nil {
......
...@@ -21,35 +21,35 @@ import ( ...@@ -21,35 +21,35 @@ import (
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
func Test_Routes(t *testing.T) { func createTestFile(t *testing.T, name string) string {
type testCase struct { t.Helper()
Name string
Method string
Path string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, resp *http.Response)
}
createTestFile := func(t *testing.T, name string) string { f, err := os.CreateTemp(t.TempDir(), name)
t.Helper() assert.Nil(t, err)
defer f.Close()
f, err := os.CreateTemp(t.TempDir(), name) err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
assert.Nil(t, err) assert.Nil(t, err)
defer f.Close()
err = binary.Write(f, binary.LittleEndian, []byte("GGUF")) err = binary.Write(f, binary.LittleEndian, uint32(3))
assert.Nil(t, err) assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint32(3)) err = binary.Write(f, binary.LittleEndian, uint64(0))
assert.Nil(t, err) assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0)) err = binary.Write(f, binary.LittleEndian, uint64(0))
assert.Nil(t, err) assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0)) return f.Name()
assert.Nil(t, err) }
return f.Name() func Test_Routes(t *testing.T) {
type testCase struct {
Name string
Method string
Path string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, resp *http.Response)
} }
createTestModel := func(t *testing.T, name string) { createTestModel := func(t *testing.T, name string) {
...@@ -237,3 +237,82 @@ func Test_Routes(t *testing.T) { ...@@ -237,3 +237,82 @@ func Test_Routes(t *testing.T) {
}) })
} }
} }
func TestCase(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
cases := []string{
"mistral",
"llama3:latest",
"library/phi3:q4_0",
"registry.ollama.ai/library/gemma:q5_K_M",
// TODO: host:port currently fails on windows (#4107)
// "localhost:5000/alice/bob:latest",
}
var s Server
for _, tt := range cases {
t.Run(tt, func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: tt,
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200 got %d", w.Code)
}
expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"})
if err != nil {
t.Fatal(err)
}
t.Run("create", func(t *testing.T) {
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: strings.ToUpper(tt),
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 500 got %d", w.Code)
}
if !bytes.Equal(w.Body.Bytes(), expect) {
t.Fatalf("expected error %s got %s", expect, w.Body.String())
}
})
t.Run("pull", func(t *testing.T) {
w := createRequest(t, s.PullModelHandler, api.PullRequest{
Name: strings.ToUpper(tt),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 500 got %d", w.Code)
}
if !bytes.Equal(w.Body.Bytes(), expect) {
t.Fatalf("expected error %s got %s", expect, w.Body.String())
}
})
t.Run("copy", func(t *testing.T) {
w := createRequest(t, s.CopyModelHandler, api.CopyRequest{
Source: tt,
Destination: strings.ToUpper(tt),
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status 500 got %d", w.Code)
}
if !bytes.Equal(w.Body.Bytes(), expect) {
t.Fatalf("expected error %s got %s", expect, w.Body.String())
}
})
})
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment