Unverified Commit dddb58a3 authored by Michael Yang's avatar Michael Yang Committed by GitHub
Browse files

Merge pull request #5051 from ollama/mxyng/capabilities

add model capabilities
parents 400056e1 da8e2a04
...@@ -28,11 +28,16 @@ import ( ...@@ -28,11 +28,16 @@ import (
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
// Capability identifies an optional feature a model may support.
type Capability string

// CapabilityCompletion indicates the model supports text completion
// (generate/chat), as opposed to embedding-only models.
const CapabilityCompletion = Capability("completion")
type registryOptions struct { type registryOptions struct {
Insecure bool Insecure bool
Username string Username string
...@@ -48,16 +53,43 @@ type Model struct { ...@@ -48,16 +53,43 @@ type Model struct {
ParentModel string ParentModel string
AdapterPaths []string AdapterPaths []string
ProjectorPaths []string ProjectorPaths []string
Template string
System string System string
License []string License []string
Digest string Digest string
Options map[string]interface{} Options map[string]interface{}
Messages []Message Messages []Message
Template *template.Template
} }
func (m *Model) IsEmbedding() bool { func (m *Model) Has(caps ...Capability) bool {
return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") for _, cap := range caps {
switch cap {
case CapabilityCompletion:
f, err := os.Open(m.ModelPath)
if err != nil {
slog.Error("couldn't open model file", "error", err)
continue
}
defer f.Close()
// TODO(mxyng): decode the GGML into model to avoid doing this multiple times
ggml, _, err := llm.DecodeGGML(f, 0)
if err != nil {
slog.Error("couldn't decode ggml", "error", err)
continue
}
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
return false
}
default:
slog.Error("unknown capability", "capability", cap)
return false
}
}
return true
} }
func (m *Model) String() string { func (m *Model) String() string {
...@@ -82,10 +114,10 @@ func (m *Model) String() string { ...@@ -82,10 +114,10 @@ func (m *Model) String() string {
}) })
} }
if m.Template != "" { if m.Template != nil {
modelfile.Commands = append(modelfile.Commands, parser.Command{ modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: "template", Name: "template",
Args: m.Template, Args: m.Template.String(),
}) })
} }
...@@ -135,13 +167,6 @@ type Message struct { ...@@ -135,13 +167,6 @@ type Message struct {
Content string `json:"content"` Content string `json:"content"`
} }
type ManifestV2 struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config *Layer `json:"config"`
Layers []*Layer `json:"layers"`
}
type ConfigV2 struct { type ConfigV2 struct {
ModelFormat string `json:"model_format"` ModelFormat string `json:"model_format"`
ModelFamily string `json:"model_family"` ModelFamily string `json:"model_family"`
...@@ -160,7 +185,7 @@ type RootFS struct { ...@@ -160,7 +185,7 @@ type RootFS struct {
DiffIDs []string `json:"diff_ids"` DiffIDs []string `json:"diff_ids"`
} }
func GetManifest(mp ModelPath) (*ManifestV2, string, error) { func GetManifest(mp ModelPath) (*Manifest, string, error) {
fp, err := mp.GetManifestPath() fp, err := mp.GetManifestPath()
if err != nil { if err != nil {
return nil, "", err return nil, "", err
...@@ -170,7 +195,7 @@ func GetManifest(mp ModelPath) (*ManifestV2, string, error) { ...@@ -170,7 +195,7 @@ func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
return nil, "", err return nil, "", err
} }
var manifest *ManifestV2 var manifest *Manifest
bts, err := os.ReadFile(fp) bts, err := os.ReadFile(fp)
if err != nil { if err != nil {
...@@ -198,8 +223,7 @@ func GetModel(name string) (*Model, error) { ...@@ -198,8 +223,7 @@ func GetModel(name string) (*Model, error) {
Name: mp.GetFullTagname(), Name: mp.GetFullTagname(),
ShortName: mp.GetShortTagname(), ShortName: mp.GetShortTagname(),
Digest: digest, Digest: digest,
Template: "{{ .Prompt }}", Template: template.DefaultTemplate,
License: []string{},
} }
filename, err := GetBlobsPath(manifest.Config.Digest) filename, err := GetBlobsPath(manifest.Config.Digest)
...@@ -235,27 +259,24 @@ func GetModel(name string) (*Model, error) { ...@@ -235,27 +259,24 @@ func GetModel(name string) (*Model, error) {
model.AdapterPaths = append(model.AdapterPaths, filename) model.AdapterPaths = append(model.AdapterPaths, filename)
case "application/vnd.ollama.image.projector": case "application/vnd.ollama.image.projector":
model.ProjectorPaths = append(model.ProjectorPaths, filename) model.ProjectorPaths = append(model.ProjectorPaths, filename)
case "application/vnd.ollama.image.template": case "application/vnd.ollama.image.prompt",
"application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename) bts, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }
model.Template = string(bts) model.Template, err = template.Parse(string(bts))
case "application/vnd.ollama.image.system":
bts, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }
case "application/vnd.ollama.image.system":
model.System = string(bts)
case "application/vnd.ollama.image.prompt":
bts, err := os.ReadFile(filename) bts, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }
model.Template = string(bts) model.System = string(bts)
case "application/vnd.ollama.image.params": case "application/vnd.ollama.image.params":
params, err := os.Open(filename) params, err := os.Open(filename)
if err != nil { if err != nil {
...@@ -822,7 +843,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu ...@@ -822,7 +843,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error { func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name) mp := ParseModelPath(name)
var manifest *ManifestV2 var manifest *Manifest
var err error var err error
var noprune string var noprune string
...@@ -929,7 +950,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu ...@@ -929,7 +950,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return nil return nil
} }
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) { func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag) requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
headers := make(http.Header) headers := make(http.Header)
...@@ -940,7 +961,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio ...@@ -940,7 +961,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
} }
defer resp.Body.Close() defer resp.Body.Close()
var m *ManifestV2 var m *Manifest
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil { if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
return nil, err return nil, err
} }
......
...@@ -14,7 +14,10 @@ import ( ...@@ -14,7 +14,10 @@ import (
) )
type Manifest struct { type Manifest struct {
ManifestV2 SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config *Layer `json:"config"`
Layers []*Layer `json:"layers"`
filepath string filepath string
fi os.FileInfo fi os.FileInfo
...@@ -66,7 +69,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) { ...@@ -66,7 +69,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
p := filepath.Join(manifests, n.Filepath()) p := filepath.Join(manifests, n.Filepath())
var m ManifestV2 var m Manifest
f, err := os.Open(p) f, err := os.Open(p)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -83,12 +86,11 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) { ...@@ -83,12 +86,11 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return nil, err return nil, err
} }
return &Manifest{ m.filepath = p
ManifestV2: m, m.fi = fi
filepath: p, m.digest = fmt.Sprintf("%x", sha256sum.Sum(nil))
fi: fi,
digest: fmt.Sprintf("%x", sha256sum.Sum(nil)), return &m, nil
}, nil
} }
func WriteManifest(name model.Name, config *Layer, layers []*Layer) error { func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
...@@ -108,7 +110,7 @@ func WriteManifest(name model.Name, config *Layer, layers []*Layer) error { ...@@ -108,7 +110,7 @@ func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
} }
defer f.Close() defer f.Close()
m := ManifestV2{ m := Manifest{
SchemaVersion: 2, SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json", MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: config, Config: config,
......
...@@ -25,7 +25,7 @@ func createManifest(t *testing.T, path, name string) { ...@@ -25,7 +25,7 @@ func createManifest(t *testing.T, path, name string) {
} }
defer f.Close() defer f.Close()
if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil { if err := json.NewEncoder(f).Encode(Manifest{}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }
......
...@@ -15,7 +15,7 @@ import ( ...@@ -15,7 +15,7 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert" "github.com/ollama/ollama/convert"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/templates" "github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
...@@ -256,7 +256,7 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap ...@@ -256,7 +256,7 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) { func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
for _, layer := range layers { for _, layer := range layers {
if s := layer.GGML.KV().ChatTemplate(); s != "" { if s := layer.GGML.KV().ChatTemplate(); s != "" {
if t, err := templates.NamedTemplate(s); err != nil { if t, err := template.Named(s); err != nil {
slog.Debug("template detection", "error", err) slog.Debug("template detection", "error", err)
} else { } else {
tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template") tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
......
...@@ -4,10 +4,11 @@ import ( ...@@ -4,10 +4,11 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"strings" "strings"
"text/template"
"text/template/parse" "text/template/parse"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
) )
// isResponseNode checks if the node contains .Response // isResponseNode checks if the node contains .Response
...@@ -53,13 +54,8 @@ func formatTemplateForResponse(tmpl *template.Template, generate bool) { ...@@ -53,13 +54,8 @@ func formatTemplateForResponse(tmpl *template.Template, generate bool) {
// Prompt renders a prompt from a template. If generate is set to true, // Prompt renders a prompt from a template. If generate is set to true,
// the response and parts of the template following it are not rendered // the response and parts of the template following it are not rendered
func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) { func Prompt(tmpl *template.Template, system, prompt, response string, generate bool) (string, error) {
parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl) formatTemplateForResponse(tmpl, generate)
if err != nil {
return "", err
}
formatTemplateForResponse(parsed, generate)
vars := map[string]any{ vars := map[string]any{
"System": system, "System": system,
...@@ -68,14 +64,14 @@ func Prompt(tmpl, system, prompt, response string, generate bool) (string, error ...@@ -68,14 +64,14 @@ func Prompt(tmpl, system, prompt, response string, generate bool) (string, error
} }
var sb strings.Builder var sb strings.Builder
if err := parsed.Execute(&sb, vars); err != nil { if err := tmpl.Execute(&sb, vars); err != nil {
return "", err return "", err
} }
return sb.String(), nil return sb.String(), nil
} }
func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) { func countTokens(tmpl *template.Template, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
rendered, err := Prompt(tmpl, system, prompt, response, false) rendered, err := Prompt(tmpl, system, prompt, response, false)
if err != nil { if err != nil {
return 0, err return 0, err
...@@ -91,7 +87,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc ...@@ -91,7 +87,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc
} }
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size // ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) { func ChatPrompt(tmpl *template.Template, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
type prompt struct { type prompt struct {
System string System string
Prompt string Prompt string
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
) )
func TestPrompt(t *testing.T) { func TestPrompt(t *testing.T) {
...@@ -61,7 +62,12 @@ func TestPrompt(t *testing.T) { ...@@ -61,7 +62,12 @@ func TestPrompt(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate) tmpl, err := template.Parse(tc.template)
if err != nil {
t.Fatal(err)
}
got, err := Prompt(tmpl, tc.system, tc.prompt, tc.response, tc.generate)
if err != nil { if err != nil {
t.Errorf("error = %v", err) t.Errorf("error = %v", err)
} }
...@@ -192,7 +198,12 @@ func TestChatPrompt(t *testing.T) { ...@@ -192,7 +198,12 @@ func TestChatPrompt(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode) tmpl, err := template.Parse(tc.template)
if err != nil {
t.Fatal(err)
}
got, err := ChatPrompt(tmpl, tc.messages, tc.window, encode)
if err != nil { if err != nil {
t.Errorf("error = %v", err) t.Errorf("error = %v", err)
} }
......
...@@ -31,6 +31,7 @@ import ( ...@@ -31,6 +31,7 @@ import (
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai" "github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
...@@ -121,8 +122,8 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -121,8 +122,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
if model.IsEmbedding() { if !model.Has(CapabilityCompletion) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
return return
} }
...@@ -161,6 +162,12 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -161,6 +162,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
tmpl, err := template.Parse(req.Template)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
checkpointLoaded := time.Now() checkpointLoaded := time.Now()
var prompt string var prompt string
...@@ -169,7 +176,11 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -169,7 +176,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
prompt = req.Prompt prompt = req.Prompt
case req.Prompt != "": case req.Prompt != "":
if req.Template == "" { if req.Template == "" {
req.Template = model.Template model.Template, err = template.Parse(req.Template)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} }
if req.System == "" { if req.System == "" {
...@@ -187,7 +198,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -187,7 +198,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
sb.WriteString(req.Prompt) sb.WriteString(req.Prompt)
p, err := Prompt(req.Template, req.System, sb.String(), "", true) p, err := Prompt(tmpl, req.System, sb.String(), "", true)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
...@@ -242,7 +253,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -242,7 +253,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart) resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw { if !req.Raw {
p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false) p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
...@@ -680,7 +691,10 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { ...@@ -680,7 +691,10 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
} }
if req.Template != "" { if req.Template != "" {
m.Template = req.Template m.Template, err = template.Parse(req.Template)
if err != nil {
return nil, err
}
} }
msgs := make([]api.Message, 0) msgs := make([]api.Message, 0)
...@@ -701,7 +715,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { ...@@ -701,7 +715,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
resp := &api.ShowResponse{ resp := &api.ShowResponse{
License: strings.Join(m.License, "\n"), License: strings.Join(m.License, "\n"),
System: m.System, System: m.System,
Template: m.Template, Template: m.Template.String(),
Details: modelDetails, Details: modelDetails,
Messages: msgs, Messages: msgs,
ModifiedAt: manifest.fi.ModTime(), ModifiedAt: manifest.fi.ModTime(),
...@@ -1248,7 +1262,7 @@ func (s *Server) ProcessHandler(c *gin.Context) { ...@@ -1248,7 +1262,7 @@ func (s *Server) ProcessHandler(c *gin.Context) {
} }
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) { func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) {
encode := func(s string) ([]int, error) { encode := func(s string) ([]int, error) {
return runner.llama.Tokenize(ctx, s) return runner.llama.Tokenize(ctx, s)
} }
...@@ -1296,8 +1310,8 @@ func (s *Server) ChatHandler(c *gin.Context) { ...@@ -1296,8 +1310,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
if model.IsEmbedding() { if !model.Has(CapabilityCompletion) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
return return
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment