Commit a30915bd authored by Michael Yang's avatar Michael Yang
Browse files

add capabilities

parent 58e3fff3
...@@ -34,6 +34,10 @@ import ( ...@@ -34,6 +34,10 @@ import (
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
type Capability string
const CapabilityCompletion = Capability("completion")
type registryOptions struct { type registryOptions struct {
Insecure bool Insecure bool
Username string Username string
...@@ -58,8 +62,20 @@ type Model struct { ...@@ -58,8 +62,20 @@ type Model struct {
Template *template.Template Template *template.Template
} }
func (m *Model) IsEmbedding() bool { func (m *Model) Has(caps ...Capability) bool {
return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") for _, cap := range caps {
switch cap {
case CapabilityCompletion:
if slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") {
return false
}
default:
slog.Error("unknown capability", "capability", cap)
return false
}
}
return true
} }
func (m *Model) String() string { func (m *Model) String() string {
......
...@@ -122,8 +122,8 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -122,8 +122,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
if model.IsEmbedding() { if !model.Has(CapabilityCompletion) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
return return
} }
...@@ -1308,8 +1308,8 @@ func (s *Server) ChatHandler(c *gin.Context) { ...@@ -1308,8 +1308,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
if model.IsEmbedding() { if !model.Has(CapabilityCompletion) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
return return
} }
......
...@@ -61,8 +61,8 @@ func TestNamed(t *testing.T) { ...@@ -61,8 +61,8 @@ func TestNamed(t *testing.T) {
func TestParse(t *testing.T) { func TestParse(t *testing.T) {
cases := []struct { cases := []struct {
template string template string
capabilities []string vars []string
}{ }{
{"{{ .Prompt }}", []string{"prompt"}}, {"{{ .Prompt }}", []string{"prompt"}},
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}}, {"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}},
...@@ -81,8 +81,8 @@ func TestParse(t *testing.T) { ...@@ -81,8 +81,8 @@ func TestParse(t *testing.T) {
} }
vars := tmpl.Vars() vars := tmpl.Vars()
if !slices.Equal(tt.capabilities, vars) { if !slices.Equal(tt.vars, vars) {
t.Errorf("expected %v, got %v", tt.capabilities, vars) t.Errorf("expected %v, got %v", tt.vars, vars)
} }
}) })
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment