Commit b2b270ad authored by Devon Rifkin

Merge branch 'main' into drifkin/array-head-count-simple

parents 20c5fd39 2bb69b40
@@ -136,8 +136,8 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
     slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
         "used", numPast, "remaining", int32(len(prompt))-numPast)
+    slot.Inputs = prompt[:numPast]
     prompt = prompt[numPast:]
-    slot.Inputs = slot.Inputs[:numPast]
     return slot, prompt, nil
 }
...
@@ -3,7 +3,6 @@ package ollamarunner
 import (
     "errors"
     "fmt"
-    "image"
     "testing"
     "time"
@@ -12,10 +11,6 @@ import (
 )

 func TestCountCommon(t *testing.T) {
-    imgA := image.NewRGBA(image.Rect(0, 0, 100, 100))
-    imgB := image.NewRGBA(image.Rect(0, 0, 50, 50))
-    imgC := image.NewRGBA(image.Rect(50, 50, 100, 100))
     tests := []struct {
         name string
         t1   []input.Input
@@ -36,20 +31,20 @@ func TestCountCommon(t *testing.T) {
         },
         {
             name:     "Image Prefix",
-            t1:       []input.Input{{Multimodal: imgA, MultimodalHash: 1}},
-            t2:       []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
+            t1:       []input.Input{{MultimodalHash: 1}},
+            t2:       []input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
             expected: 1,
         },
         {
             name:     "Mixed",
-            t1:       []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
-            t2:       []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
+            t1:       []input.Input{{Token: 1}, {MultimodalHash: 1}},
+            t2:       []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
             expected: 2,
         },
         {
             name:     "Mixed, Same Length",
-            t1:       []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
-            t2:       []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
+            t1:       []input.Input{{Token: 1}, {MultimodalHash: 1}},
+            t2:       []input.Input{{Token: 1}, {MultimodalHash: 2}},
             expected: 1,
         },
         {
...
package ollamarunner
import (
"errors"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
// Tensors can't be used across multiple compute graphs. This is a problem
// if a single embedding is split across batches using views since all of
// the views will have the same source tensor. We also don't want to
// recompute the entire embedding for each batch.
//
// To avoid this, we compute all of the tensors for the embedding on the
// first use and then store the result in system memory. When we need
// additional tensors, we recreate them from the stored data.
// multimodalEntry represents the embeddings of a single object (such
// as an image).
type multimodalEntry struct {
// mm is the original set of tensors created by EncodeMultimodal
mm []input.Multimodal
// data is the computed result of mm. Nil if not yet computed
data [][]float32
}
// multimodalStore maps from an individual tensor (of which there
// may be many in a single multimodal object) to its parent embedding
type multimodalStore map[ml.Tensor]*multimodalEntry
func newMultimodalStore() multimodalStore {
return make(multimodalStore)
}
// addMultimodal stores an embedding for later use in a compute graph
func (m multimodalStore) addMultimodal(embedding []input.Multimodal) {
entry := &multimodalEntry{mm: embedding}
for _, e := range embedding {
if e.Tensor != nil {
m[e.Tensor] = entry
}
}
}
// getMultimodal takes a source set of tensors (which may contain a whole or
// parts of one or more images) and returns the equivalent that can be used in
// the current context
func (m multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []input.Multimodal, reserve bool) ([]input.Multimodal, error) {
out := make([]input.Multimodal, len(in))
for i := range out {
if in[i].Tensor != nil {
var err error
out[i].Tensor, err = m.getTensor(backend, ctx, in[i].Tensor, reserve)
if err != nil {
return nil, err
}
}
out[i].Data = in[i].Data
}
return out, nil
}
func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Tensor, reserve bool) (ml.Tensor, error) {
entry := m[in]
if entry.data == nil {
computeCtx := backend.NewContext()
defer computeCtx.Close()
var tensors []ml.Tensor
for _, t := range entry.mm {
if t.Tensor != nil {
tensors = append(tensors, t.Tensor)
}
}
if len(tensors) == 0 {
return nil, nil
}
computeCtx.Forward(tensors...)
entry.data = make([][]float32, len(entry.mm))
if !reserve {
computeCtx.Compute(tensors...)
for i, t := range entry.mm {
if t.Tensor != nil {
entry.data[i] = t.Tensor.Floats()
}
}
} else {
computeCtx.Reserve()
}
}
for i, t := range entry.mm {
if in == t.Tensor {
if !reserve {
return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...), nil
} else {
return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil
}
}
}
return nil, errors.New("multimodal tensor not found")
}
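For readers following the new store, here is a minimal, hedged usage sketch of the API introduced above. The store calls, EncodeMultimodal, and the input/ml types are taken from this commit; the helper name and the surrounding processor, context, batch, and index wiring are placeholders rather than the actual runner code.

    // Illustrative sketch only - placeholder wiring around the multimodalStore API above.
    func sketchBatchMultimodal(backend ml.Backend, encodeCtx, batchCtx ml.Context,
        processor model.MultimodalProcessor, store multimodalStore,
        imageBytes []byte, index int, batch *input.Batch) error {
        // Encode once and register the resulting tensors with the store.
        embeddings, err := processor.EncodeMultimodal(encodeCtx, imageBytes)
        if err != nil {
            return err
        }
        store.addMultimodal(embeddings)

        // Per batch (a different compute graph), rebuild equivalent tensors from
        // the cached float data instead of reusing views of the original tensors.
        mm, err := store.getMultimodal(backend, batchCtx, embeddings, false)
        if err != nil {
            return err
        }
        batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: index, Multimodal: mm})
        return nil
    }

When getMultimodal is called with reserve set to true (as reserveWorstCaseGraph does below), the store allocates empty tensors of the same shape instead of copying data back.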
 package ollamarunner

 import (
+    "bytes"
     "context"
     "encoding/json"
     "errors"
     "flag"
     "fmt"
     "hash/maphash"
+    "image"
     "log"
     "log/slog"
     "net"
     "net/http"
     "os"
-    "path/filepath"
     "regexp"
     "runtime"
     "strconv"
@@ -21,10 +22,13 @@ import (
     "time"
     "unicode/utf8"

+    "golang.org/x/image/bmp"
     "golang.org/x/sync/semaphore"

     "github.com/ollama/ollama/api"
+    "github.com/ollama/ollama/envconfig"
     "github.com/ollama/ollama/llm"
+    "github.com/ollama/ollama/logutil"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/model"
     "github.com/ollama/ollama/model/input"
@@ -39,6 +43,9 @@ type Sequence struct {
     // multimodal embeddings
     ctxs []ml.Context

+    // mmStore holds multimodal embeddings to mange memory and enable splitting across batches
+    mmStore multimodalStore

     // batch index
     iBatch int
@@ -100,7 +107,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
     startTime := time.Now()

-    inputs, ctxs, err := s.inputs(prompt, images)
+    inputs, ctxs, mmStore, err := s.inputs(prompt, images)
     if err != nil {
         return nil, fmt.Errorf("failed to process inputs: %w", err)
     } else if len(inputs) == 0 {
@@ -155,6 +162,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
     return &Sequence{
         ctxs: ctxs,
+        mmStore: mmStore,
         inputs: inputs,
         numPromptInputs: len(inputs),
         startProcessingTime: startTime,
@@ -173,9 +181,10 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // decoding images
-func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, error) {
+func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) {
     var inputs []input.Input
     var ctxs []ml.Context
+    var mmStore multimodalStore

     var parts []string
     var matches [][]string
@@ -186,6 +195,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
         re := regexp.MustCompile(`\[img-(\d+)\]`)
         parts = re.Split(prompt, -1)
         matches = re.FindAllStringSubmatch(prompt, -1)
+        mmStore = newMultimodalStore()
     } else {
         parts = []string{prompt}
     }
@@ -195,7 +205,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
             // text - tokenize
             tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
             if err != nil {
-                return nil, nil, err
+                return nil, nil, nil, err
             }

             for _, t := range tokens {
@@ -215,7 +225,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
             }

             if imageIndex < 0 {
-                return nil, nil, fmt.Errorf("invalid image index: %d", n)
+                return nil, nil, nil, fmt.Errorf("invalid image index: %d", n)
             }

             ctx := s.model.Backend().NewContext()
@@ -223,13 +233,15 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
             ctxs = append(ctxs, ctx)
             imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
             if err != nil {
-                return nil, nil, err
+                return nil, nil, nil, err
             }

             s.multimodalHash.Reset()
             _, _ = s.multimodalHash.Write(images[imageIndex].Data)
             imageHash := s.multimodalHash.Sum64()

+            mmStore.addMultimodal(imageEmbeddings)

             inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
             postTokenize = true
         }
@@ -239,11 +251,11 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
         var err error
         inputs, err = multimodalProcessor.PostTokenize(inputs)
         if err != nil {
-            return nil, nil, err
+            return nil, nil, nil, err
         }
     }

-    return inputs, ctxs, nil
+    return inputs, ctxs, mmStore, nil
 }

 type Server struct {
@@ -362,6 +374,9 @@ func (s *Server) processBatch() error {
     }
     defer s.mu.Unlock()

+    ctx := s.model.Backend().NewContext()
+    defer ctx.Close()

     var batchInputs []int32
     var batch input.Batch
@@ -432,7 +447,11 @@ func (s *Server) processBatch() error {
             batchInputs = append(batchInputs, inp.Token)
             if inp.Multimodal != nil {
-                batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
+                mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
+                if err != nil {
+                    return err
+                }
+                batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
             }

             batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
@@ -458,9 +477,6 @@ func (s *Server) processBatch() error {
         return nil
     }

-    ctx := s.model.Backend().NewContext()
-    defer ctx.Close()

     modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
     if err != nil {
         return fmt.Errorf("failed to decode batch: %w", err)
@@ -719,12 +735,71 @@ func (s *Server) reserveWorstCaseGraph() error {
     ctx := s.model.Backend().NewContext()
     defer ctx.Close()
var err error
inputs := make([]input.Input, s.batchSize)
mmStore := newMultimodalStore()
// Multimodal strategy:
// - Encode a 2048x2048 image. This assumes that a single image of this
// size is sufficient to trigger the worst case. This is currently true
// because for existing models, only a single image fits in a batch.
// - Add the embedding to a full batch of tokens - this is necessary because
// the model may be looking for non-image data, such as <image> tags.
// - Run PostTokenize to execute any transformations between generated
// embeddings and what the forward pass expects.
// - The result may now be larger than a batch (images may not fit in a
// single batch), so trim based on what will fit and must be grouped together.
// - Fill out the rest of the space with text tokens.
if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); ok {
mmCtx := s.model.Backend().NewContext()
defer mmCtx.Close()
img := image.NewGray(image.Rect(0, 0, 2048, 2048))
var buf bytes.Buffer
bmp.Encode(&buf, img)
if inputs[0].Multimodal, err = multimodalProcessor.EncodeMultimodal(mmCtx, buf.Bytes()); err == nil {
mmStore.addMultimodal(inputs[0].Multimodal)
inputs, err = multimodalProcessor.PostTokenize(inputs)
if err != nil {
return err
}
for i, inp := range inputs {
minBatch := 1 + inp.SameBatch
if minBatch > s.batchSize {
inputs = inputs[i:min(i+minBatch, len(inputs))]
break
} else if i+minBatch > s.batchSize {
inputs = inputs[:i]
break
}
}
if len(inputs) < s.batchSize {
newInputs := make([]input.Input, s.batchSize)
copy(newInputs, inputs)
inputs = newInputs
}
}
}
     var batch input.Batch

-    inputs := make([]int32, s.batchSize)
+    batchInputs := make([]int32, len(inputs))
     batch.Positions = make([]int32, len(inputs))
     batch.Sequences = make([]int, len(inputs))
-    for i := range inputs {
+    for i, inp := range inputs {
+        batchInputs[i] = inp.Token
+        if inp.Multimodal != nil {
+            mm, err := mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, true)
+            if err != nil {
+                return err
+            }
+            batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: i, Multimodal: mm})
+        }

         batch.Positions[i] = int32(i)
     }
@@ -733,11 +808,7 @@ func (s *Server) reserveWorstCaseGraph() error {
         batch.Outputs[i] = int32(i)
     }

-    var err error
-    batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
-    if err != nil {
-        return err
-    }
+    batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))

     cache := s.model.Config().Cache
     if cache != nil {
@@ -752,16 +823,12 @@ func (s *Server) reserveWorstCaseGraph() error {
         return err
     }

-    err = ctx.Forward(t).Reserve()
-    if err != nil {
-        return err
-    }
+    ctx.Forward(t).Reserve()

     return nil
 }

-func (s *Server) loadModel(
-    ctx context.Context,
+func (s *Server) initModel(
     mpath string,
     params ml.BackendParams,
     lpath multiLPath,
@@ -769,21 +836,21 @@ func (s *Server) loadModel(
     kvCacheType string,
     kvSize int,
     multiUserCache bool,
-) {
+) error {
     var err error
-    s.model, err = model.New(ctx, mpath, params)
+    s.model, err = model.New(mpath, params)
     if err != nil {
-        panic(err)
+        return err
     }

     // TODO(jessegross): LoRA loading
     if lpath.String() != "" {
-        panic("loras are not yet implemented")
+        return errors.New("loras are not yet implemented")
     }

     s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
     if err != nil {
-        panic(err)
+        return err
     }

     if !s.cache.enabled && parallel > 1 {
@@ -795,7 +862,30 @@ func (s *Server) loadModel(
     s.seqs = make([]*Sequence, s.parallel)
     s.seqsSem = semaphore.NewWeighted(int64(s.parallel))

-    err = s.reserveWorstCaseGraph()
+    return s.reserveWorstCaseGraph()
}
func (s *Server) load(
ctx context.Context,
mpath string,
params ml.BackendParams,
lpath multiLPath,
parallel int,
kvCacheType string,
kvSize int,
multiUserCache bool,
) {
err := s.initModel(mpath, params, lpath, parallel, kvCacheType, kvSize, multiUserCache)
if err != nil {
panic(err)
}
slog.Debug("memory", "allocated", s.model.Backend().BackendMemory())
err = s.model.Backend().Load(ctx,
func(progress float32) {
s.progress = progress
})
     if err != nil {
         panic(err)
     }
@@ -816,7 +906,7 @@ func Execute(args []string) error {
     kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
     port := fs.Int("port", 8080, "Port to expose the server on")
     threads := fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
-    verbose := fs.Bool("verbose", false, "verbose output (default: disabled)")
+    _ = fs.Bool("verbose", false, "verbose output (default: disabled)")
     _ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
     tensorSplit := fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
     multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
@@ -831,22 +921,7 @@ func Execute(args []string) error {
     if err := fs.Parse(args); err != nil {
         return err
     }

-    level := slog.LevelInfo
-    if *verbose {
-        level = slog.LevelDebug
-    }
-    handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
-        Level:     level,
-        AddSource: true,
-        ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
-            if attr.Key == slog.SourceKey {
-                source := attr.Value.Any().(*slog.Source)
-                source.File = filepath.Base(source.File)
-            }
-            return attr
-        },
-    })
-    slog.SetDefault(slog.New(handler))
+    slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
     slog.Info("starting ollama engine")

     server := &Server{
@@ -854,9 +929,14 @@ func Execute(args []string) error {
         status: llm.ServerStatusLoadingModel,
     }

+    server.cond = sync.NewCond(&server.mu)
+    server.ready.Add(1)

+    ctx, cancel := context.WithCancel(context.Background())
+    defer cancel()

     // TODO(jessegross): Parameters that need to be implemented:
     //  no-mmap
-    //  mlock

     var tensorSplitFloats []float32
     if *tensorSplit != "" {
@@ -869,9 +949,6 @@ func Execute(args []string) error {
     }

     params := ml.BackendParams{
-        Progress: func(progress float32) {
-            server.progress = progress
-        },
         NumThreads:     *threads,
         NumGPULayers:   *numGPULayers,
         MainGPU:        *mainGPU,
@@ -879,14 +956,7 @@ func Execute(args []string) error {
         FlashAttention: *flashAttention,
     }

-    server.ready.Add(1)
-
-    ctx, cancel := context.WithCancel(context.Background())
-    defer cancel()
-
-    go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
-
-    server.cond = sync.NewCond(&server.mu)
+    go server.load(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)

     go server.run(ctx)

     addr := "127.0.0.1:" + strconv.Itoa(*port)
...
@@ -176,7 +176,7 @@ func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSa
         vocabIds[i] = uint32(i)
     }

-    grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, []uint32{uint32(model.Vocabulary().EOS), uint32(model.Vocabulary().EOT)})
+    grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
     if grammar == nil {
         return nil, errors.New("sample: failed to initialize grammar")
     }
...
@@ -27,6 +27,7 @@ function checkEnv() {
     $env:VCToolsRedistDir=(get-item "${MSVC_INSTALL}\VC\Redist\MSVC\*")[0]
 }
 # Locate CUDA versions
+# Note: this assumes every version found will be built
 $cudaList=(get-item "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v*\bin\" -ea 'silentlycontinue')
 if ($cudaList.length -eq 0) {
     $d=(get-command -ea 'silentlycontinue' nvcc).path
@@ -93,6 +94,19 @@ function buildOllama() {
         $hashEnv = @{}
         Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
+        if ("$script:CUDA_DIRS".Contains("v11")) {
+            $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $v11="$_" }}
+            $env:CUDAToolkit_ROOT=$hashEnv[$v11]
+            write-host "Building CUDA v11 backend libraries"
+            # Note: cuda v11 requires msvc 2019 so force the older generator
+            # to avoid 2022 (or newer) from being used as the default
+            & cmake --fresh --preset "CUDA 11" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR
+            if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
+            & cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS
+            if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
+            & cmake --install build --component "CUDA" --strip
+            if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
+        }
         if ("$script:CUDA_DIRS".Contains("v12")) {
             $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12")) { $v12="$_" }}
             $env:CUDAToolkit_ROOT=$hashEnv[$v12]
...
@@ -10,7 +10,9 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \
     --build-arg=GOFLAGS \
     --build-arg=OLLAMA_CUSTOM_CPU_DEFS \
     --build-arg=OLLAMA_SKIP_CUDA_GENERATE \
+    --build-arg=OLLAMA_SKIP_CUDA_11_GENERATE \
     --build-arg=OLLAMA_SKIP_CUDA_12_GENERATE \
+    --build-arg=CUDA_V11_ARCHITECTURES \
     --build-arg=CUDA_V12_ARCHITECTURES \
     --build-arg=OLLAMA_SKIP_ROCM_GENERATE \
     --build-arg=OLLAMA_FAST_BUILD \
...
@@ -295,7 +295,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
     }
     defer bin.Close()

-    f, _, err := ggml.Decode(bin, 1024)
+    f, err := ggml.Decode(bin, -1)
     if err != nil {
         return nil, err
     }
@@ -430,7 +430,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
     fnWrap := func(n uint64) {
         done := doneBytes.Add(n)
         progress := float32(done) / float32(totalBytes)
-        fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantizeType), Digest: "0", Total: layer.Size, Completed: int64(progress * float32(layer.Size))})
+        fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantizeType), Digest: "0000000000000000000", Total: layer.Size, Completed: int64(progress * float32(layer.Size))})
     }
     ftype, err := ggml.ParseFileType(quantizeType)
     if err != nil {
@@ -467,7 +467,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
         return nil, err
     }

-    f, _, err := ggml.Decode(temp, 1024)
+    f, err := ggml.Decode(temp, 1024)
     if err != nil {
         slog.Error(fmt.Sprintf("error decoding ggml: %s\n", err))
         return nil, err
@@ -501,48 +501,27 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
         return nil, errOnlyGGUFSupported
     }

-    stat, err := blob.Stat()
+    f, err := ggml.Decode(blob, -1)
     if err != nil {
         return nil, err
     }

-    var offset int64
-    for offset < stat.Size() {
-        f, n, err := ggml.Decode(blob, 1024)
-        if errors.Is(err, io.EOF) {
-            break
-        } else if err != nil {
-            return nil, err
-        }
-
-        mediatype := "application/vnd.ollama.image.model"
-        if f.KV().Kind() == "adapter" {
-            mediatype = "application/vnd.ollama.image.adapter"
-        } else if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok || f.KV().Kind() == "projector" {
-            mediatype = "application/vnd.ollama.image.projector"
-        }
-
-        var layer Layer
-        if digest != "" && n == stat.Size() && offset == 0 {
-            layer, err = NewLayerFromLayer(digest, mediatype, blob.Name())
-            if err != nil {
-                slog.Debug("could not create new layer from layer", "error", err)
-                return nil, err
-            }
-        }
-
-        // Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size())
-        if layer.Digest == "" {
-            layer, err = NewLayer(io.NewSectionReader(blob, offset, n), mediatype)
-            if err != nil {
-                return nil, err
-            }
-        }
-
-        layers = append(layers, &layerGGML{layer, f})
-        offset = n
-    }
+    mediatype := "application/vnd.ollama.image.model"
+    if f.KV().Kind() == "adapter" {
+        mediatype = "application/vnd.ollama.image.adapter"
+    } else if (f.KV().Uint("block_count") == 0 && f.KV().Uint("vision.block_count") > 0) || f.KV().Kind() == "projector" {
+        // if a model has vision.block_count but not block_count, it is a standalone vision model
+        mediatype = "application/vnd.ollama.image.projector"
+    }
+
+    layer, err := NewLayerFromLayer(digest, mediatype, blob.Name())
+    if err != nil {
+        slog.Debug("could not create new layer from layer", "error", err)
+        return nil, err
+    }
+
+    layers = append(layers, &layerGGML{layer, f})

     return detectChatTemplate(layers)
 }
...
@@ -464,6 +464,10 @@ type downloadOpts struct {
 // downloadBlob downloads a blob from the registry and stores it in the blobs directory
 func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
+    if opts.digest == "" {
+        return false, fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is is empty")
+    }
+
     fp, err := GetBlobsPath(opts.digest)
     if err != nil {
         return false, err
...
@@ -23,9 +23,10 @@ import (
     "github.com/ollama/ollama/api"
     "github.com/ollama/ollama/envconfig"
-    "github.com/ollama/ollama/fs/ggml"
+    "github.com/ollama/ollama/fs/gguf"
     "github.com/ollama/ollama/parser"
     "github.com/ollama/ollama/template"
+    "github.com/ollama/ollama/thinking"
     "github.com/ollama/ollama/types/model"
     "github.com/ollama/ollama/version"
 )
@@ -37,6 +38,7 @@ var (
     errCapabilityInsert    = errors.New("insert")
     errCapabilityVision    = errors.New("vision")
     errCapabilityEmbedding = errors.New("embedding")
+    errCapabilityThinking  = errors.New("thinking")
     errInsecureProtocol    = errors.New("insecure protocol http")
 )
@@ -71,22 +73,18 @@ func (m *Model) Capabilities() []model.Capability {
     capabilities := []model.Capability{}

     // Check for completion capability
-    r, err := os.Open(m.ModelPath)
+    f, err := gguf.Open(m.ModelPath)
     if err == nil {
-        defer r.Close()
-
-        f, _, err := ggml.Decode(r, 1024)
-        if err == nil {
-            if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
-                capabilities = append(capabilities, model.CapabilityEmbedding)
-            } else {
-                capabilities = append(capabilities, model.CapabilityCompletion)
-            }
-            if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok {
-                capabilities = append(capabilities, model.CapabilityVision)
-            }
-        } else {
-            slog.Error("couldn't decode ggml", "error", err)
+        defer f.Close()
+
+        if f.KeyValue("pooling_type").Valid() {
+            capabilities = append(capabilities, model.CapabilityEmbedding)
+        } else {
+            // If no embedding is specified, we assume the model supports completion
+            capabilities = append(capabilities, model.CapabilityCompletion)
+        }
+        if f.KeyValue("vision.block_count").Valid() {
+            capabilities = append(capabilities, model.CapabilityVision)
         }
     } else {
         slog.Error("couldn't open model file", "error", err)
@@ -111,6 +109,12 @@ func (m *Model) Capabilities() []model.Capability {
         capabilities = append(capabilities, model.CapabilityVision)
     }

+    // Check for thinking capability
+    openingTag, closingTag := thinking.InferTags(m.Template.Template)
+    if openingTag != "" && closingTag != "" {
+        capabilities = append(capabilities, model.CapabilityThinking)
+    }

     return capabilities
 }
@@ -127,6 +131,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
         model.CapabilityInsert:    errCapabilityInsert,
         model.CapabilityVision:    errCapabilityVision,
         model.CapabilityEmbedding: errCapabilityEmbedding,
+        model.CapabilityThinking:  errCapabilityThinking,
     }

     for _, cap := range want {
@@ -141,11 +146,19 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
         }
     }

+    var err error
     if len(errs) > 0 {
-        return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
+        err = fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
     }

-    return nil
+    if slices.Contains(errs, errCapabilityThinking) {
+        if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
+            // append a message to the existing error
+            return fmt.Errorf("%w. Pull the model again to get the latest version with full thinking support", err)
+        }
+    }
+
+    return err
 }

 func (m *Model) String() string {
...
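A small, hedged usage sketch for the new thinking capability: Capabilities, CheckCapabilities, and model.CapabilityThinking are defined in the change above, while the helper name and the slices import are illustrative only.

    // Illustrative only - not part of this commit.
    func supportsThinking(m *Model) bool {
        return slices.Contains(m.Capabilities(), model.CapabilityThinking)
    }

A caller that requires the capability would instead use m.CheckCapabilities(model.CapabilityThinking), which for qwen3 and deepseek-r1 models now appends a hint to pull the model again.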
 package server

 import (
-    "bytes"
-    "encoding/binary"
-    "errors"
-    "os"
-    "path/filepath"
     "strings"
     "testing"

+    "github.com/ollama/ollama/fs/ggml"
     "github.com/ollama/ollama/template"
     "github.com/ollama/ollama/types/model"
 )
// Constants for GGUF magic bytes and version
var (
ggufMagic = []byte{0x47, 0x47, 0x55, 0x46} // "GGUF"
ggufVer = uint32(3) // Version 3
)
// Helper function to create mock GGUF data
func createMockGGUFData(architecture string, vision bool) []byte {
var buf bytes.Buffer
// Write GGUF header
buf.Write(ggufMagic)
binary.Write(&buf, binary.LittleEndian, ggufVer)
// Write tensor count (0 for our test)
var numTensors uint64 = 0
binary.Write(&buf, binary.LittleEndian, numTensors)
// Calculate number of metadata entries
numMetaEntries := uint64(1) // architecture entry
if vision {
numMetaEntries++
}
// Add embedding entry if architecture is "bert"
if architecture == "bert" {
numMetaEntries++
}
binary.Write(&buf, binary.LittleEndian, numMetaEntries)
// Write architecture metadata
archKey := "general.architecture"
keyLen := uint64(len(archKey))
binary.Write(&buf, binary.LittleEndian, keyLen)
buf.WriteString(archKey)
// String type (8)
var strType uint32 = 8
binary.Write(&buf, binary.LittleEndian, strType)
// String length
strLen := uint64(len(architecture))
binary.Write(&buf, binary.LittleEndian, strLen)
buf.WriteString(architecture)
if vision {
visionKey := architecture + ".vision.block_count"
keyLen = uint64(len(visionKey))
binary.Write(&buf, binary.LittleEndian, keyLen)
buf.WriteString(visionKey)
// uint32 type (4)
var uint32Type uint32 = 4
binary.Write(&buf, binary.LittleEndian, uint32Type)
// uint32 value (1)
var countVal uint32 = 1
binary.Write(&buf, binary.LittleEndian, countVal)
}
// Write embedding metadata if architecture is "bert"
if architecture == "bert" {
poolKey := architecture + ".pooling_type"
keyLen = uint64(len(poolKey))
binary.Write(&buf, binary.LittleEndian, keyLen)
buf.WriteString(poolKey)
// uint32 type (4)
var uint32Type uint32 = 4
binary.Write(&buf, binary.LittleEndian, uint32Type)
// uint32 value (1)
var poolingVal uint32 = 1
binary.Write(&buf, binary.LittleEndian, poolingVal)
}
return buf.Bytes()
}
 func TestModelCapabilities(t *testing.T) {
-    // Create a temporary directory for test files
-    tempDir := t.TempDir()
-
-    // Create different types of mock model files
-    completionModelPath := filepath.Join(tempDir, "model.bin")
-    visionModelPath := filepath.Join(tempDir, "vision_model.bin")
-    embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
-    // Create a simple model file for tests that don't depend on GGUF content
-    simpleModelPath := filepath.Join(tempDir, "simple_model.bin")
-
-    if err := errors.Join(
-        os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644),
-        os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644),
-        os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644),
-        os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644),
-    ); err != nil {
-        t.Fatalf("Failed to create model files: %v", err)
-    }
+    // Create completion model (llama architecture without vision)
+    completionModelPath, _ := createBinFile(t, ggml.KV{
+        "general.architecture": "llama",
+    }, []*ggml.Tensor{})
+
+    // Create vision model (llama architecture with vision block count)
+    visionModelPath, _ := createBinFile(t, ggml.KV{
+        "general.architecture": "llama",
+        "llama.vision.block_count": uint32(1),
+    }, []*ggml.Tensor{})
+
+    // Create embedding model (bert architecture with pooling type)
+    embeddingModelPath, _ := createBinFile(t, ggml.KV{
+        "general.architecture": "bert",
+        "bert.pooling_type": uint32(1),
+    }, []*ggml.Tensor{})

     toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
     if err != nil {
         t.Fatalf("Failed to parse template: %v", err)
     }
     chatTemplate, err := template.Parse("{{ .prompt }}")
     if err != nil {
         t.Fatalf("Failed to parse template: %v", err)
     }
     toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
     if err != nil {
         t.Fatalf("Failed to parse template: %v", err)
@@ -145,21 +64,13 @@ func TestModelCapabilities(t *testing.T) {
             },
             expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
         },
-        {
-            name: "model with tools and insert capability",
-            model: Model{
-                ModelPath: simpleModelPath,
-                Template:  toolsInsertTemplate,
-            },
-            expectedCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
-        },
         {
             name: "model with tools capability",
             model: Model{
-                ModelPath: simpleModelPath,
+                ModelPath: completionModelPath,
                 Template:  toolsTemplate,
             },
-            expectedCaps: []model.Capability{model.CapabilityTools},
+            expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools},
         },
         {
             name: "model with vision capability",
@@ -224,29 +135,33 @@
 }

 func TestModelCheckCapabilities(t *testing.T) {
-    // Create a temporary directory for test files
-    tempDir := t.TempDir()
-
-    visionModelPath := filepath.Join(tempDir, "vision_model.bin")
-    simpleModelPath := filepath.Join(tempDir, "model.bin")
-    embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
-
-    if err := errors.Join(
-        os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644),
-        os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644),
-        os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644),
-    ); err != nil {
-        t.Fatalf("Failed to create model files: %v", err)
-    }
+    // Create simple model file for tests that don't depend on GGUF content
+    completionModelPath, _ := createBinFile(t, ggml.KV{
+        "general.architecture": "llama",
+    }, []*ggml.Tensor{})
+
+    // Create vision model (llama architecture with vision block count)
+    visionModelPath, _ := createBinFile(t, ggml.KV{
+        "general.architecture": "llama",
+        "llama.vision.block_count": uint32(1),
+    }, []*ggml.Tensor{})
+
+    // Create embedding model (bert architecture with pooling type)
+    embeddingModelPath, _ := createBinFile(t, ggml.KV{
+        "general.architecture": "bert",
+        "bert.pooling_type": uint32(1),
+    }, []*ggml.Tensor{})

     toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
     if err != nil {
         t.Fatalf("Failed to parse template: %v", err)
     }
     chatTemplate, err := template.Parse("{{ .prompt }}")
     if err != nil {
         t.Fatalf("Failed to parse template: %v", err)
     }
     toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
     if err != nil {
         t.Fatalf("Failed to parse template: %v", err)
@@ -261,7 +176,7 @@ func TestModelCheckCapabilities(t *testing.T) {
         {
             name: "completion model without tools capability",
             model: Model{
-                ModelPath: simpleModelPath,
+                ModelPath: completionModelPath,
                 Template:  chatTemplate,
             },
             checkCaps: []model.Capability{model.CapabilityTools},
@@ -270,7 +185,7 @@ func TestModelCheckCapabilities(t *testing.T) {
         {
             name: "model with all needed capabilities",
             model: Model{
-                ModelPath: simpleModelPath,
+                ModelPath: completionModelPath,
                 Template:  toolsInsertTemplate,
             },
             checkCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
@@ -278,7 +193,7 @@ func TestModelCheckCapabilities(t *testing.T) {
         {
             name: "model missing insert capability",
             model: Model{
-                ModelPath: simpleModelPath,
+                ModelPath: completionModelPath,
                 Template:  toolsTemplate,
             },
             checkCaps: []model.Capability{model.CapabilityInsert},
@@ -287,7 +202,7 @@ func TestModelCheckCapabilities(t *testing.T) {
         {
             name: "model missing vision capability",
             model: Model{
-                ModelPath: simpleModelPath,
+                ModelPath: completionModelPath,
                 Template:  toolsTemplate,
             },
             checkCaps: []model.Capability{model.CapabilityVision},
@@ -312,7 +227,7 @@ func TestModelCheckCapabilities(t *testing.T) {
         {
             name: "unknown capability",
             model: Model{
-                ModelPath: simpleModelPath,
+                ModelPath: completionModelPath,
                 Template:  chatTemplate,
             },
             checkCaps: []model.Capability{"unknown"},
...
@@ -59,7 +59,7 @@ type DiskCache struct {
     testHookBeforeFinalWrite func(f *os.File)
 }

-// PutString is a convenience function for c.Put(d, strings.NewReader(s), int64(len(s))).
+// PutBytes is a convenience function for c.Put(d, strings.NewReader(s), int64(len(s))).
 func PutBytes[S string | []byte](c *DiskCache, d Digest, data S) error {
     return c.Put(d, bytes.NewReader([]byte(data)), int64(len(data)))
 }
...
@@ -10,9 +10,6 @@ import (
     "log/slog"
     "net/http"
     "os"
-    "slices"
-    "strings"
-    "text/template/parse"

     "github.com/ollama/ollama/api"
     "github.com/ollama/ollama/fs/ggml"
@@ -64,7 +61,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
     }
     defer blob.Close()

-    f, _, err := ggml.Decode(blob, -1)
+    f, err := ggml.Decode(blob, -1)
     if err != nil {
         return nil, err
     }
@@ -128,124 +125,3 @@ func detectContentType(r io.Reader) (string, error) {
     return "unknown", nil
 }
func parseObjects(s string) []map[string]any {
var objs []map[string]any
for offset := 0; offset < len(s); {
var obj map[string]any
decoder := json.NewDecoder(strings.NewReader(s[offset:]))
if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
break
} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
// skip over any syntax errors
offset += int(syntax.Offset)
} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
// skip over any unmarshalable types
offset += int(unmarshalType.Offset)
} else if err != nil {
return nil
} else {
offset += int(decoder.InputOffset())
objs = append(objs, obj)
}
}
return objs
}
// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// mxyng: this only really works if the input contains tool calls in some JSON format
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
// create a subtree from the node that ranges over .ToolCalls
tmpl := m.Template.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
}
return false
})
if tmpl == nil {
return nil, false
}
var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
"ToolCalls": {
{
Function: api.ToolCallFunction{
Name: "@@name@@",
Arguments: api.ToolCallFunctionArguments{
"@@argument@@": 1,
},
},
},
},
}); err != nil {
return nil, false
}
templateObjects := parseObjects(b.String())
if len(templateObjects) == 0 {
return nil, false
}
// find the keys that correspond to the name and arguments fields
var name, arguments string
for k, v := range templateObjects[0] {
switch v.(type) {
case string:
name = k
case map[string]any:
arguments = k
}
}
if name == "" || arguments == "" {
return nil, false
}
responseObjects := parseObjects(s)
if len(responseObjects) == 0 {
return nil, false
}
// collect all nested objects
var collect func(any) []map[string]any
collect = func(obj any) (all []map[string]any) {
switch o := obj.(type) {
case map[string]any:
all = append(all, o)
for _, v := range o {
all = append(all, collect(v)...)
}
case []any:
for _, v := range o {
all = append(all, collect(v)...)
}
}
return all
}
var objs []map[string]any
for _, p := range responseObjects {
objs = append(objs, collect(p)...)
}
var toolCalls []api.ToolCall
for _, kv := range objs {
n, nok := kv[name].(string)
a, aok := kv[arguments].(map[string]any)
if nok && aok {
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: n,
Arguments: a,
},
})
}
}
return toolCalls, len(toolCalls) > 0
}
package server
import (
"bytes"
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)
func readFile(t *testing.T, base, name string) *bytes.Buffer {
t.Helper()
bts, err := os.ReadFile(filepath.Join(base, name))
if err != nil {
t.Fatal(err)
}
return bytes.NewBuffer(bts)
}
func TestExecuteWithTools(t *testing.T) {
p := filepath.Join("testdata", "tools")
cases := []struct {
model string
output string
ok bool
}{
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"command-r-plus", "Action: ```json" + `
[
{
"tool_name": "get_current_weather",
"parameters": {
"format": "fahrenheit",
"location": "San Francisco, CA"
}
},
{
"tool_name": "get_current_weather",
"parameters": {
"format": "celsius",
"location": "Toronto, Canada"
}
}
]
` + "```", true},
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"llama3-groq-tool-use", `<tool_call>
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
</tool_call>`, true},
{"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
{"nemotron", `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} </toolcall>`, true},
}
var tools []api.Tool
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
t.Fatal(err)
}
var messages []api.Message
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
t.Fatal(err)
}
calls := []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
},
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"format": "celsius",
"location": "Toronto, Canada",
},
},
},
}
for _, tt := range cases {
t.Run(tt.model, func(t *testing.T) {
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
if err != nil {
t.Fatal(err)
}
t.Run("template", func(t *testing.T) {
var actual bytes.Buffer
if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("parse", func(t *testing.T) {
m := &Model{Template: tmpl}
actual, ok := m.parseToolCalls(tt.output)
if ok != tt.ok {
t.Fatalf("expected %t, got %t", tt.ok, ok)
}
if tt.ok {
if diff := cmp.Diff(actual, calls); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
})
})
}
}
func TestParseObjects(t *testing.T) {
tests := []struct {
input string
want []map[string]any
}{
{
input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
want: []map[string]any{
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}},
},
},
{
input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall>`,
want: []map[string]any{
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
},
},
{
input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall> <toolcall>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} </toolcall>`,
want: []map[string]any{
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}},
},
},
{
input: `{"name": "get_current_weather", "arguments": `,
want: nil,
},
}
for _, tc := range tests {
t.Run(tc.input, func(t *testing.T) {
got := parseObjects(tc.input)
if diff := cmp.Diff(got, tc.want); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
}
@@ -116,7 +116,7 @@ func (mp ModelPath) BaseURL() *url.URL {
 func GetManifestPath() (string, error) {
     path := filepath.Join(envconfig.Models(), "manifests")
     if err := os.MkdirAll(path, 0o755); err != nil {
-        return "", err
+        return "", fmt.Errorf("%w: ensure path elements are traversable", err)
     }

     return path, nil
@@ -139,7 +139,7 @@ func GetBlobsPath(digest string) (string, error) {
     }

     if err := os.MkdirAll(dirPath, 0o755); err != nil {
-        return "", err
+        return "", fmt.Errorf("%w: ensure path elements are traversable", err)
     }

     return path, nil
...
@@ -3,47 +3,32 @@ package server
 import (
     "bytes"
     "context"
-    "encoding/binary"
     "errors"
     "fmt"
     "log/slog"
+    "slices"
     "strings"

     "github.com/ollama/ollama/api"
     "github.com/ollama/ollama/llm"
-    "github.com/ollama/ollama/model/models/mllama"
     "github.com/ollama/ollama/template"
 )

 type tokenizeFunc func(context.Context, string) ([]int, error)

-var errTooManyImages = errors.New("vision model only supports a single image per message")
-
 // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
 // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
 // latest message and 2) system messages
-func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *bool) (prompt string, images []llm.ImageData, _ error) {
     var system []api.Message

-    isMllama := checkMllamaModelFamily(m)
-
-    var imageNumTokens int
     // TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
-    if isMllama {
-        // Our mllama implementation packs all of the embeddings into a single token
-        imageNumTokens = 1
-    } else {
-        // Clip images are represented as 768 tokens, each an embedding
-        imageNumTokens = 768
-    }
+    // Clip images are represented as 768 tokens, each an embedding
+    imageNumTokens := 768

     n := len(msgs) - 1
     // in reverse, find all messages that fit into context window
     for i := n; i >= 0; i-- {
-        if isMllama && len(msgs[i].Images) > 1 {
-            return "", nil, errTooManyImages
-        }
-
         // always include the last message
         if i == n {
             continue
@@ -56,8 +41,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
             }
         }

+        thinkVal := false
+        if think != nil {
+            thinkVal = *think
+        }
         var b bytes.Buffer
-        if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
+        if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
             return "", nil, err
         }
@@ -84,48 +73,17 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
     currMsgIdx := n

     for cnt, msg := range msgs[currMsgIdx:] {
-        prefix := ""
-        imgPrompt := ""
+        if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {
+            return "", nil, errors.New("this model only supports one image while more than one image requested")
}
var prefix string
prompt := msg.Content prompt := msg.Content
for _, i := range msg.Images { for _, i := range msg.Images {
var imgData llm.ImageData imgData := llm.ImageData{
ID: len(images),
if isMllama { Data: i,
if len(m.ProjectorPaths) == 0 {
imgData = llm.ImageData{
ID: len(images),
Data: i,
}
} else {
data, opts, err := mllama.Preprocess(bytes.NewReader(i))
if err != nil {
return "", nil, err
}
buf := new(bytes.Buffer)
err = binary.Write(buf, binary.LittleEndian, data)
if err != nil {
return "", nil, err
}
ar, ok := opts["aspectRatioIndex"].(int)
if !ok {
return "", nil, fmt.Errorf("missing aspect ratio for image")
}
imgData = llm.ImageData{
ID: len(images),
Data: buf.Bytes(),
AspectRatioID: ar,
}
}
imgPrompt = "<|image|>"
} else {
imgData = llm.ImageData{
ID: len(images),
Data: i,
}
} }
imgTag := fmt.Sprintf("[img-%d]", imgData.ID) imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
...@@ -137,23 +95,18 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. ...@@ -137,23 +95,18 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
images = append(images, imgData) images = append(images, imgData)
} }
msgs[currMsgIdx+cnt].Content = prefix + imgPrompt + prompt msgs[currMsgIdx+cnt].Content = prefix + prompt
} }
// truncate any messages that do not fit into the context window // truncate any messages that do not fit into the context window
var b bytes.Buffer var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil { thinkVal := false
if think != nil {
thinkVal = *think
}
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
return "", nil, err return "", nil, err
} }
return b.String(), images, nil return b.String(), images, nil
} }
func checkMllamaModelFamily(m *Model) bool {
for _, arch := range m.Config.ModelFamilies {
if arch == "mllama" {
return true
}
}
return false
}
...@@ -2,8 +2,6 @@ package server ...@@ -2,8 +2,6 @@ package server
import ( import (
"bytes" "bytes"
"image"
"image/png"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
...@@ -14,10 +12,9 @@ import ( ...@@ -14,10 +12,9 @@ import (
func TestChatPrompt(t *testing.T) { func TestChatPrompt(t *testing.T) {
type expect struct { type expect struct {
prompt string prompt string
images [][]byte images [][]byte
aspectRatioID int error error
error error
} }
tmpl, err := template.Parse(` tmpl, err := template.Parse(`
...@@ -28,28 +25,6 @@ func TestChatPrompt(t *testing.T) { ...@@ -28,28 +25,6 @@ func TestChatPrompt(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}} visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
mllamaModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}, Config: ConfigV2{ModelFamilies: []string{"mllama"}}}
createImg := func(width, height int) ([]byte, error) {
img := image.NewRGBA(image.Rect(0, 0, width, height))
var buf bytes.Buffer
if err := png.Encode(&buf, img); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
imgBuf, err := createImg(5, 5)
if err != nil {
t.Fatal(err)
}
imgBuf2, err := createImg(6, 6)
if err != nil {
t.Fatal(err)
}
cases := []struct { cases := []struct {
name string name string
...@@ -227,97 +202,14 @@ func TestChatPrompt(t *testing.T) { ...@@ -227,97 +202,14 @@ func TestChatPrompt(t *testing.T) {
images: [][]byte{[]byte("one hotdog"), []byte("two hotdogs")}, images: [][]byte{[]byte("one hotdog"), []byte("two hotdogs")},
}, },
}, },
{
name: "messages with mllama (no images)",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
},
},
{
name: "messages with mllama single prompt",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}},
},
expect: expect{
prompt: "[img-0]<|image|>How many hotdogs are in this image? ",
images: [][]byte{imgBuf},
aspectRatioID: 1,
},
},
{
name: "messages with mllama",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf}},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? [img-0]<|image|>A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{imgBuf},
aspectRatioID: 1,
},
},
{
name: "multiple messages with mllama",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{imgBuf}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf2}},
},
expect: expect{
prompt: "[img-0]<|image|>You're a test, Harry! I-I'm a what? [img-1]<|image|>A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{imgBuf, imgBuf2},
aspectRatioID: 1,
},
},
{
name: "earlier image with mllama",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "How many hotdogs are in this image?", Images: []api.ImageData{imgBuf}},
{Role: "assistant", Content: "There are four hotdogs."},
{Role: "user", Content: "Which ones have mustard?"},
},
expect: expect{
prompt: "[img-0]<|image|>How many hotdogs are in this image? There are four hotdogs. Which ones have mustard? ",
images: [][]byte{imgBuf},
aspectRatioID: 1,
},
},
{
name: "too many images with mllama",
model: mllamaModel,
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{imgBuf, imgBuf}},
},
expect: expect{
error: errTooManyImages,
},
},
} }
for _, tt := range cases { for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
model := tt.model model := tt.model
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}} opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil) think := false
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &think)
if tt.error == nil && err != nil { if tt.error == nil && err != nil {
t.Fatal(err) t.Fatal(err)
} else if tt.error != nil && err != tt.error { } else if tt.error != nil && err != tt.error {
...@@ -341,10 +233,6 @@ func TestChatPrompt(t *testing.T) { ...@@ -341,10 +233,6 @@ func TestChatPrompt(t *testing.T) {
if !bytes.Equal(images[i].Data, tt.images[i]) { if !bytes.Equal(images[i].Data, tt.images[i]) {
t.Errorf("expected %q, got %q", tt.images[i], images[i].Data) t.Errorf("expected %q, got %q", tt.images[i], images[i].Data)
} }
} else {
if images[i].AspectRatioID != tt.aspectRatioID {
t.Errorf("expected aspect ratio %d, got %d", tt.aspectRatioID, images[i].AspectRatioID)
}
} }
} }
}) })
......
...@@ -70,23 +70,7 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType ...@@ -70,23 +70,7 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
newType = fsggml.TensorTypeQ6_K newType = fsggml.TensorTypeQ6_K
} }
} else if strings.Contains(name, "attn_v.weight") { } else if strings.Contains(name, "attn_v.weight") {
if ftype == fsggml.FileTypeQ2_K { if (ftype == fsggml.FileTypeQ4_K_M) &&
if kv.GQA() >= 4 {
newType = fsggml.TensorTypeQ4_K
} else {
newType = fsggml.TensorTypeQ3_K
}
} else if ftype == fsggml.FileTypeQ2_K_S && kv.GQA() >= 4 {
newType = fsggml.TensorTypeQ4_K
} else if ftype == fsggml.FileTypeQ3_K_M {
if qs.iAttnV < 2 {
newType = fsggml.TensorTypeQ5_K
} else {
newType = fsggml.TensorTypeQ4_K
}
} else if ftype == fsggml.FileTypeQ3_K_L {
newType = fsggml.TensorTypeQ5_K
} else if (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ5_K_M) &&
useMoreBits(qs.iAttnV, qs.nAttnV) { useMoreBits(qs.iAttnV, qs.nAttnV) {
newType = fsggml.TensorTypeQ6_K newType = fsggml.TensorTypeQ6_K
} else if ftype == fsggml.FileTypeQ4_K_S && qs.iAttnV < 4 { } else if ftype == fsggml.FileTypeQ4_K_S && qs.iAttnV < 4 {
...@@ -114,67 +98,52 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType ...@@ -114,67 +98,52 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
} else if strings.Contains(name, "ffn_down") { } else if strings.Contains(name, "ffn_down") {
iLayer := qs.iFfnDown iLayer := qs.iFfnDown
n_layer := qs.nFfnDown n_layer := qs.nFfnDown
if ftype == fsggml.FileTypeQ2_K { if ftype == fsggml.FileTypeQ4_K_M {
newType = fsggml.TensorTypeQ3_K
} else if ftype == fsggml.FileTypeQ2_K_S {
if iLayer < n_layer/8 {
newType = fsggml.TensorTypeQ4_K
}
} else if ftype == fsggml.FileTypeQ3_K_M {
if iLayer < n_layer/16 {
newType = fsggml.TensorTypeQ5_K
} else if useMoreBits(iLayer, n_layer) {
newType = fsggml.TensorTypeQ4_K
} else {
newType = fsggml.TensorTypeQ3_K
}
} else if ftype == fsggml.FileTypeQ3_K_L {
newType = fsggml.TensorTypeQ5_K
} else if ftype == fsggml.FileTypeQ4_K_M {
if useMoreBits(iLayer, n_layer) { if useMoreBits(iLayer, n_layer) {
newType = fsggml.TensorTypeQ6_K newType = fsggml.TensorTypeQ6_K
} }
} else if ftype == fsggml.FileTypeQ5_K_M && useMoreBits(iLayer, n_layer) {
newType = fsggml.TensorTypeQ6_K
} else if ftype == fsggml.FileTypeQ4_K_S && iLayer < n_layer/8 { } else if ftype == fsggml.FileTypeQ4_K_S && iLayer < n_layer/8 {
newType = fsggml.TensorTypeQ5_K newType = fsggml.TensorTypeQ5_K
} }
qs.iFfnDown++ qs.iFfnDown++
} else if strings.Contains(name, "attn_output.weight") { } else if strings.Contains(name, "attn_output.weight") {
if nExperts == 8 { if nExperts == 8 {
if ftype == fsggml.FileTypeQ2_K || ftype == fsggml.FileTypeQ3_K_S || ftype == fsggml.FileTypeQ3_K_M || if ftype == fsggml.FileTypeQ4_K_S || ftype == fsggml.FileTypeQ4_K_M {
ftype == fsggml.FileTypeQ4_K_S || ftype == fsggml.FileTypeQ4_K_M {
newType = fsggml.TensorTypeQ5_K
}
} else {
if ftype == fsggml.FileTypeQ2_K {
newType = fsggml.TensorTypeQ3_K
} else if ftype == fsggml.FileTypeQ3_K_M {
newType = fsggml.TensorTypeQ4_K
} else if ftype == fsggml.FileTypeQ3_K_L {
newType = fsggml.TensorTypeQ5_K newType = fsggml.TensorTypeQ5_K
} }
} }
} else if strings.Contains(name, "attn_qkv.weight") { } else if strings.Contains(name, "attn_qkv.weight") {
if ftype == fsggml.FileTypeQ3_K_M || ftype == fsggml.FileTypeQ3_K_L { if ftype == fsggml.FileTypeQ4_K_M {
newType = fsggml.TensorTypeQ4_K
} else if ftype == fsggml.FileTypeQ4_K_M {
newType = fsggml.TensorTypeQ5_K newType = fsggml.TensorTypeQ5_K
} else if ftype == fsggml.FileTypeQ5_K_M {
newType = fsggml.TensorTypeQ6_K
} }
} }
if newType.IsQuantized() { if newType.IsQuantized() {
nx := shape[0] nx := shape[0]
ny := uint64(1)
if len(shape) > 1 {
ny = shape[1]
}
qk_k := newType.BlockSize() qk_k := newType.BlockSize()
// Check if first dimension is divisible by block size
if nx%qk_k != 0 { if nx%qk_k != 0 {
slog.Warn(fmt.Sprintf("tensor cols %d x %d are not divisible by %d, required for %s. Falling back to quantization %s", nx, ny, qk_k, newType.String(), fsggml.TensorTypeF16.String())) // Store the original type for logging
newType = fsggml.TensorTypeF16 originalType := newType
// Select appropriate fallback based on original type
switch newType {
case fsggml.TensorTypeQ4_K:
newType = fsggml.TensorTypeQ5_0
case fsggml.TensorTypeQ5_K:
newType = fsggml.TensorTypeQ5_1
case fsggml.TensorTypeQ6_K:
newType = fsggml.TensorTypeQ8_0
}
// Final check - if still incompatible, fall back to F16
if nx%newType.BlockSize() != 0 {
newType = fsggml.TensorTypeF16
}
slog.Warn(fmt.Sprintf("tensor cols %d are not divisible by %d, required for %s - using fallback quantization %s",
nx, qk_k, originalType.String(), newType.String()))
} }
} }
return newType return newType
......
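The fallback above only helps when the tensor's first dimension happens to divide the smaller block size of the legacy quantization; otherwise the type still drops to F16. A worked sketch of that check, assuming the usual GGML block sizes (256 for the K-quants, 32 for Q5_0/Q5_1/Q8_0) and made-up tensor widths:

package main

import "fmt"

func main() {
	const qkK = 256 // assumed K-quant block size
	const qk0 = 32  // assumed Q5_0 / Q5_1 / Q8_0 block size

	// Hypothetical first-dimension sizes for three tensors.
	for _, nx := range []uint64{4096, 320, 300} {
		switch {
		case nx%qkK == 0:
			fmt.Println(nx, "-> keep the K-quant type")
		case nx%qk0 == 0:
			fmt.Println(nx, "-> fall back to the legacy quant (e.g. Q4_K -> Q5_0)")
		default:
			fmt.Println(nx, "-> fall back to F16")
		}
	}
}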
...@@ -42,71 +42,6 @@ func TestGetTensorNewType(t *testing.T) { ...@@ -42,71 +42,6 @@ func TestGetTensorNewType(t *testing.T) {
ftype: fsggml.FileTypeF32, ftype: fsggml.FileTypeF32,
expected: fsggml.TensorTypeQ6_K, expected: fsggml.TensorTypeQ6_K,
}, },
{
name: "attn_v.weight_q4_k",
kv: map[string]any{
"general.architecture": "foo",
"foo.attention.head_count": uint32(4),
"foo.attention.head_count_kv": uint32(1),
},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_v.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ2_K,
expected: fsggml.TensorTypeQ4_K,
},
{
name: "attn_v.weight_q3_k",
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_v.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ2_K,
expected: fsggml.TensorTypeQ3_K,
},
{
name: "attn_v.weight_q2_k_s_q4_k",
kv: map[string]any{
"general.architecture": "foo",
"foo.attention.head_count": uint32(4),
"foo.attention.head_count_kv": uint32(1),
},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_v.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ2_K_S,
expected: fsggml.TensorTypeQ4_K,
},
{
name: "attn_v.weight_q3_k_m",
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_v.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_M,
expected: fsggml.TensorTypeQ5_K,
},
{
name: "attn_v.weight_q3_k_m_i",
qs: quantizeState{
iAttnV: 2,
},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_v.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_M,
expected: fsggml.TensorTypeQ4_K,
},
{
name: "attn_v.weight_q3_k_l",
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_v.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_L,
expected: fsggml.TensorTypeQ5_K,
},
{ {
name: "attn_v.weight_q4_k_m", name: "attn_v.weight_q4_k_m",
qs: quantizeState{ qs: quantizeState{
...@@ -156,88 +91,6 @@ func TestGetTensorNewType(t *testing.T) { ...@@ -156,88 +91,6 @@ func TestGetTensorNewType(t *testing.T) {
ftype: fsggml.FileTypeF32, ftype: fsggml.FileTypeF32,
expected: fsggml.TensorTypeQ8_0, expected: fsggml.TensorTypeQ8_0,
}, },
{
name: "ffn_down_q2_k",
qs: quantizeState{},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "ffn_down",
shape: []uint64{256},
ftype: fsggml.FileTypeQ2_K,
expected: fsggml.TensorTypeQ3_K,
},
{
name: "ffn_down_q2_k_s",
qs: quantizeState{},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "ffn_down",
shape: []uint64{256},
ftype: fsggml.FileTypeQ2_K_S,
expected: fsggml.TensorTypeQ4_0,
},
{
name: "ffn_down_q2_k_s_layers",
qs: quantizeState{
iFfnDown: 2,
nFfnDown: 3 * 8,
},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "ffn_down",
shape: []uint64{256},
ftype: fsggml.FileTypeQ2_K_S,
expected: fsggml.TensorTypeQ4_K,
},
{
name: "ffn_down_q3_k_m_base",
qs: quantizeState{
iFfnDown: 1,
nFfnDown: 8,
},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "ffn_down",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_M,
expected: fsggml.TensorTypeQ3_K,
},
{
name: "ffn_down_q3_k_m_16",
qs: quantizeState{
iFfnDown: 2,
nFfnDown: 3 * 16,
},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "ffn_down",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_M,
expected: fsggml.TensorTypeQ5_K,
},
{
name: "ffn_down_q3_k_m_8",
qs: quantizeState{
iFfnDown: 2,
nFfnDown: 3 * 8,
},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "ffn_down",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_M,
expected: fsggml.TensorTypeQ4_K,
},
{
name: "ffn_down_q3_k_l",
qs: quantizeState{},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "ffn_down",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_L,
expected: fsggml.TensorTypeQ5_K,
},
{ {
name: "ffn_down_q4_k_m", name: "ffn_down_q4_k_m",
qs: quantizeState{ qs: quantizeState{
...@@ -264,19 +117,6 @@ func TestGetTensorNewType(t *testing.T) { ...@@ -264,19 +117,6 @@ func TestGetTensorNewType(t *testing.T) {
ftype: fsggml.FileTypeQ4_K_M, ftype: fsggml.FileTypeQ4_K_M,
expected: fsggml.TensorTypeQ6_K, expected: fsggml.TensorTypeQ6_K,
}, },
{
name: "ffn_down_q5_k_m",
qs: quantizeState{
iFfnDown: 2,
nFfnDown: 3 * 8,
},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "ffn_down",
shape: []uint64{256},
ftype: fsggml.FileTypeQ5_K_M,
expected: fsggml.TensorTypeQ6_K,
},
{ {
name: "ffn_down_q4_k_s", name: "ffn_down_q4_k_s",
qs: quantizeState{ qs: quantizeState{
...@@ -290,59 +130,6 @@ func TestGetTensorNewType(t *testing.T) { ...@@ -290,59 +130,6 @@ func TestGetTensorNewType(t *testing.T) {
ftype: fsggml.FileTypeQ4_K_S, ftype: fsggml.FileTypeQ4_K_S,
expected: fsggml.TensorTypeQ5_K, expected: fsggml.TensorTypeQ5_K,
}, },
{
name: "attn_output.weight_8_expert",
qs: quantizeState{},
kv: map[string]any{
"general.architecture": "foo",
"foo.expert_count": uint32(8),
},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_output.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ2_K,
expected: fsggml.TensorTypeQ5_K,
},
{
name: "attn_output.weight_q2",
qs: quantizeState{},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_output.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ2_K,
expected: fsggml.TensorTypeQ3_K,
},
{
name: "attn_output.weight_q3_k_m",
qs: quantizeState{},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_output.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_M,
expected: fsggml.TensorTypeQ4_K,
},
{
name: "attn_output.weight_q3_k_l",
qs: quantizeState{},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_output.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_L,
expected: fsggml.TensorTypeQ5_K,
},
{
name: "attn_qkv.weight_q3_k_m",
qs: quantizeState{},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_qkv.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ3_K_M,
expected: fsggml.TensorTypeQ4_K,
},
{ {
name: "attn_qkv.weight_q4_k_m", name: "attn_qkv.weight_q4_k_m",
qs: quantizeState{}, qs: quantizeState{},
...@@ -353,16 +140,6 @@ func TestGetTensorNewType(t *testing.T) { ...@@ -353,16 +140,6 @@ func TestGetTensorNewType(t *testing.T) {
ftype: fsggml.FileTypeQ4_K_M, ftype: fsggml.FileTypeQ4_K_M,
expected: fsggml.TensorTypeQ5_K, expected: fsggml.TensorTypeQ5_K,
}, },
{
name: "attn_qkv.weight_q5_k_m",
qs: quantizeState{},
kv: map[string]any{},
newType: fsggml.TensorTypeQ4_0,
tensor_name: "blk.0.attn_qkv.weight",
shape: []uint64{256},
ftype: fsggml.FileTypeQ5_K_M,
expected: fsggml.TensorTypeQ6_K,
},
} }
for _, tt := range cases { for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
...@@ -480,21 +257,13 @@ func TestQuantizeModel(t *testing.T) { ...@@ -480,21 +257,13 @@ func TestQuantizeModel(t *testing.T) {
for _, tt := range cases { for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
f, err := os.CreateTemp(t.TempDir(), tt.name) p, _ := createBinFile(t, tt.kv, tt.tensors)
if err != nil { fp, err := os.Open(p)
t.Fatal(err.Error())
}
defer f.Close()
err = fsggml.WriteGGUF(f, tt.kv, tt.tensors)
if err != nil {
t.Fatalf("failed to create initial model: %s", err)
}
fp, err := os.Open(f.Name())
if err != nil { if err != nil {
t.Fatal(err.Error()) t.Fatal(err.Error())
} }
defer fp.Close() defer fp.Close()
meta, _, err := fsggml.Decode(fp, -1) meta, err := fsggml.Decode(fp, -1)
if err != nil { if err != nil {
t.Fatal(err.Error()) t.Fatal(err.Error())
} }
...@@ -526,7 +295,7 @@ func TestQuantizeModel(t *testing.T) { ...@@ -526,7 +295,7 @@ func TestQuantizeModel(t *testing.T) {
t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err) t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err)
} }
defer fpNew.Close() defer fpNew.Close()
newMeta, _, err := fsggml.Decode(fpNew, -1) newMeta, err := fsggml.Decode(fpNew, -1)
if err != nil { if err != nil {
t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err) t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err)
} }
......
...@@ -4,10 +4,10 @@ import ( ...@@ -4,10 +4,10 @@ import (
"bytes" "bytes"
"cmp" "cmp"
"context" "context"
"encoding/binary"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"image"
"io" "io"
"io/fs" "io/fs"
"log/slog" "log/slog"
...@@ -17,8 +17,6 @@ import ( ...@@ -17,8 +17,6 @@ import (
"net/netip" "net/netip"
"os" "os"
"os/signal" "os/signal"
"path/filepath"
"regexp"
"slices" "slices"
"strings" "strings"
"syscall" "syscall"
...@@ -26,6 +24,7 @@ import ( ...@@ -26,6 +24,7 @@ import (
"github.com/gin-contrib/cors" "github.com/gin-contrib/cors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"golang.org/x/image/webp"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
...@@ -33,11 +32,13 @@ import ( ...@@ -33,11 +32,13 @@ import (
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/model/models/mllama" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/openai" "github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
"github.com/ollama/ollama/thinking"
"github.com/ollama/ollama/tools"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
...@@ -98,6 +99,10 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C ...@@ -98,6 +99,10 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return nil, nil, nil, err return nil, nil, nil, err
} }
if slices.Contains(model.Config.ModelFamilies, "mllama") && len(model.ProjectorPaths) > 0 {
return nil, nil, nil, fmt.Errorf("'llama3.2-vision' is no longer compatible with your version of Ollama and has been replaced by a newer version. To re-download, run 'ollama pull llama3.2-vision'")
}
if err := model.CheckCapabilities(caps...); err != nil { if err := model.CheckCapabilities(caps...); err != nil {
return nil, nil, nil, fmt.Errorf("%s %w", name, err) return nil, nil, nil, fmt.Errorf("%s %w", name, err)
} }
...@@ -181,6 +186,13 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -181,6 +186,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
if req.Suffix != "" { if req.Suffix != "" {
caps = append(caps, model.CapabilityInsert) caps = append(caps, model.CapabilityInsert)
} }
if req.Think != nil && *req.Think {
caps = append(caps, model.CapabilityThinking)
// TODO(drifkin): consider adding a warning if it's false and the model
// doesn't support thinking. It's not strictly required, but it can be a
// hint that the user is on an older qwen3/r1 model that doesn't have an
// updated template supporting thinking
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) { if errors.Is(err, errCapabilityCompletion) {
...@@ -204,38 +216,14 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -204,38 +216,14 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
isMllama := checkMllamaModelFamily(m) if slices.Contains(m.Config.ModelFamilies, "mllama") && len(req.Images) > 1 {
if isMllama && len(req.Images) > 1 { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image while more than one image requested"})
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"})
return return
} }
images := make([]llm.ImageData, len(req.Images)) images := make([]llm.ImageData, len(req.Images))
for i := range req.Images { for i := range req.Images {
if isMllama && len(m.ProjectorPaths) > 0 { images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i]))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
return
}
ar, ok := opts["aspectRatioIndex"].(int)
if !ok {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
return
}
buf := new(bytes.Buffer)
err = binary.Write(buf, binary.LittleEndian, data)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
return
}
images[i] = llm.ImageData{ID: i, Data: buf.Bytes(), AspectRatioID: ar}
} else {
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
}
} }
prompt := req.Prompt prompt := req.Prompt
...@@ -267,15 +255,15 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -267,15 +255,15 @@ func (s *Server) GenerateHandler(c *gin.Context) {
for _, i := range images { for _, i := range images {
imgPrompt := "" imgPrompt := ""
if isMllama {
imgPrompt = "<|image|>"
}
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]"+imgPrompt, i.ID)}) msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]"+imgPrompt, i.ID)})
} }
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt}) values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
} }
values.Think = req.Think != nil && *req.Think
values.IsThinkSet = req.Think != nil
var b bytes.Buffer var b bytes.Buffer
if req.Context != nil { if req.Context != nil {
slog.Warn("the context field is deprecated and will be removed in a future version of Ollama") slog.Warn("the context field is deprecated and will be removed in a future version of Ollama")
...@@ -295,7 +283,14 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -295,7 +283,14 @@ func (s *Server) GenerateHandler(c *gin.Context) {
prompt = b.String() prompt = b.String()
} }
slog.Debug("generate request", "images", len(images), "prompt", prompt) var thinkingState *thinking.Parser
openingTag, closingTag := thinking.InferTags(m.Template.Template)
if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
thinkingState = &thinking.Parser{
OpeningTag: openingTag,
ClosingTag: closingTag,
}
}
ch := make(chan any) ch := make(chan any)
go func() { go func() {
...@@ -321,6 +316,12 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -321,6 +316,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}, },
} }
if thinkingState != nil {
thinking, content := thinkingState.AddContent(cr.Content)
res.Thinking = thinking
res.Response = content
}
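The parser is fed the streamed chunks one at a time, and each AddContent call returns whatever can already be classified as thinking versus visible content. A rough usage sketch based only on the fields and method shown in this diff; the chunk boundaries below are invented:

package main

import (
	"fmt"

	"github.com/ollama/ollama/thinking"
)

func main() {
	p := &thinking.Parser{OpeningTag: "<think>", ClosingTag: "</think>"}

	// In the handler these chunks come from streamed completion responses.
	for _, chunk := range []string{"<think>plan the reply", "</think>", "final answer"} {
		thinkingPart, content := p.AddContent(chunk)
		fmt.Printf("thinking=%q content=%q\n", thinkingPart, content)
	}
}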
if _, err := sb.WriteString(cr.Content); err != nil { if _, err := sb.WriteString(cr.Content); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
...@@ -348,11 +349,13 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -348,11 +349,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
if req.Stream != nil && !*req.Stream { if req.Stream != nil && !*req.Stream {
var r api.GenerateResponse var r api.GenerateResponse
var sb strings.Builder var sbThinking strings.Builder
var sbContent strings.Builder
for rr := range ch { for rr := range ch {
switch t := rr.(type) { switch t := rr.(type) {
case api.GenerateResponse: case api.GenerateResponse:
sb.WriteString(t.Response) sbThinking.WriteString(t.Thinking)
sbContent.WriteString(t.Response)
r = t r = t
case gin.H: case gin.H:
msg, ok := t["error"].(string) msg, ok := t["error"].(string)
...@@ -368,7 +371,9 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -368,7 +371,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
} }
r.Response = sb.String() r.Thinking = sbThinking.String()
r.Response = sbContent.String()
c.JSON(http.StatusOK, r) c.JSON(http.StatusOK, r)
return return
} }
...@@ -1226,26 +1231,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { ...@@ -1226,26 +1231,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
} }
func Serve(ln net.Listener) error { func Serve(ln net.Listener) error {
level := slog.LevelInfo slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
if envconfig.Debug() {
level = slog.LevelDebug
}
slog.Info("server config", "env", envconfig.Values()) slog.Info("server config", "env", envconfig.Values())
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
AddSource: true,
ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
if attr.Key == slog.SourceKey {
source := attr.Value.Any().(*slog.Source)
source.File = filepath.Base(source.File)
}
return attr
},
})
slog.SetDefault(slog.New(handler))
blobsDir, err := GetBlobsPath("") blobsDir, err := GetBlobsPath("")
if err != nil { if err != nil {
...@@ -1324,6 +1311,10 @@ func Serve(ln net.Listener) error { ...@@ -1324,6 +1311,10 @@ func Serve(ln net.Listener) error {
s.sched.Run(schedCtx) s.sched.Run(schedCtx)
// register the experimental webp decoder
// so webp images can be used in multimodal inputs
image.RegisterFormat("webp", "RIFF????WEBP", webp.Decode, webp.DecodeConfig)
// At startup we retrieve GPU information so we can get log messages before loading a model // At startup we retrieve GPU information so we can get log messages before loading a model
// This will log warnings to the log in case we have problems with detected GPUs // This will log warnings to the log in case we have problems with detected GPUs
gpus := discover.GetGPUInfo() gpus := discover.GetGPUInfo()
...@@ -1341,31 +1332,29 @@ func Serve(ln net.Listener) error { ...@@ -1341,31 +1332,29 @@ func Serve(ln net.Listener) error {
func waitForStream(c *gin.Context, ch chan any) { func waitForStream(c *gin.Context, ch chan any) {
c.Header("Content-Type", "application/json") c.Header("Content-Type", "application/json")
var latest api.ProgressResponse
for resp := range ch { for resp := range ch {
switch r := resp.(type) { switch r := resp.(type) {
case api.ProgressResponse: case api.ProgressResponse:
if r.Status == "success" { latest = r
c.JSON(http.StatusOK, r)
return
}
case gin.H: case gin.H:
status, ok := r["status"].(int) status, ok := r["status"].(int)
if !ok { if !ok {
status = http.StatusInternalServerError status = http.StatusInternalServerError
} }
if errorMsg, ok := r["error"].(string); ok { errorMsg, ok := r["error"].(string)
c.JSON(status, gin.H{"error": errorMsg}) if !ok {
return errorMsg = "unknown error"
} else {
c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
return
} }
c.JSON(status, gin.H{"error": errorMsg})
return
default: default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "unknown message type"})
return return
} }
} }
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected end of progress response"})
c.JSON(http.StatusOK, latest)
} }
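waitForStream now remembers only the most recent progress message and replies once the channel is closed, rather than returning early on a literal "success" status. A toy sketch of that consume-latest-then-respond pattern, independent of the actual handler types:

package main

import "fmt"

func main() {
	ch := make(chan string)
	go func() {
		defer close(ch)
		for _, status := range []string{"pulling manifest", "verifying sha256 digest", "success"} {
			ch <- status
		}
	}()

	// Keep only the latest message; respond after the producer closes the channel.
	var latest string
	for status := range ch {
		latest = status
	}
	fmt.Println("final status:", latest) // final status: success
}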
func streamResponse(c *gin.Context, ch chan any) { func streamResponse(c *gin.Context, ch chan any) {
...@@ -1476,6 +1465,9 @@ func (s *Server) ChatHandler(c *gin.Context) { ...@@ -1476,6 +1465,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
if len(req.Tools) > 0 { if len(req.Tools) > 0 {
caps = append(caps, model.CapabilityTools) caps = append(caps, model.CapabilityTools)
} }
if req.Think != nil && *req.Think {
caps = append(caps, model.CapabilityThinking)
}
name := model.ParseName(req.Model) name := model.ParseName(req.Model)
if !name.IsValid() { if !name.IsValid() {
...@@ -1516,20 +1508,31 @@ func (s *Server) ChatHandler(c *gin.Context) { ...@@ -1516,20 +1508,31 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
msgs = filterThinkTags(msgs, m) msgs = filterThinkTags(msgs, m)
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools) prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools, req.Think)
if err != nil { if err != nil {
slog.Error("chat prompt error", "error", err) slog.Error("chat prompt error", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
slog.Debug("chat request", "images", len(images), "prompt", prompt) var thinkingState *thinking.Parser
openingTag, closingTag := thinking.InferTags(m.Template.Template)
if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
thinkingState = &thinking.Parser{
OpeningTag: openingTag,
ClosingTag: closingTag,
}
}
var toolParser *tools.Parser
if len(req.Tools) > 0 {
toolParser = tools.NewParser(m.Template.Template, req.Tools)
}
ch := make(chan any) ch := make(chan any)
go func() { go func() {
defer close(ch) defer close(ch)
var sb strings.Builder
var toolCallIndex int = 0
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
Images: images, Images: images,
...@@ -1549,43 +1552,41 @@ func (s *Server) ChatHandler(c *gin.Context) { ...@@ -1549,43 +1552,41 @@ func (s *Server) ChatHandler(c *gin.Context) {
}, },
} }
if thinkingState != nil {
thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content)
if thinkingContent == "" && remainingContent == "" && !r.Done {
// need to accumulate more to decide what to send
return
}
res.Message.Content = remainingContent
res.Message.Thinking = thinkingContent
}
if r.Done { if r.Done {
res.DoneReason = r.DoneReason.String() res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart) res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
} }
// TODO: tool call checking and filtering should be moved outside of this callback once streaming if len(req.Tools) > 0 {
// however this was a simple change for now without reworking streaming logic of this (and other) toolCalls, content := toolParser.Add(res.Message.Content)
// handlers if len(content) > 0 {
if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 { res.Message.Content = content
ch <- res } else if len(toolCalls) > 0 {
return res.Message.ToolCalls = toolCalls
} res.Message.Content = ""
} else if res.Message.Thinking != "" {
// Streaming tool calls: // don't return
// If tools are recognized, use a flag to track the sending of a tool downstream } else {
// This ensures that content is cleared from the message on the last chunk sent if r.Done {
sb.WriteString(r.Content) res.Message.Content = toolParser.Content()
if toolCalls, ok := m.parseToolCalls(sb.String()); ok { ch <- res
res.Message.ToolCalls = toolCalls }
for i := range toolCalls { return
toolCalls[i].Function.Index = toolCallIndex
toolCallIndex++
} }
res.Message.Content = ""
sb.Reset()
ch <- res
return
} }
if r.Done { ch <- res
// Send any remaining content if no tool calls were detected
if toolCallIndex == 0 {
res.Message.Content = sb.String()
}
ch <- res
}
}); err != nil { }); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
...@@ -1593,12 +1594,18 @@ func (s *Server) ChatHandler(c *gin.Context) { ...@@ -1593,12 +1594,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
if req.Stream != nil && !*req.Stream { if req.Stream != nil && !*req.Stream {
var resp api.ChatResponse var resp api.ChatResponse
var sb strings.Builder var toolCalls []api.ToolCall
var sbThinking strings.Builder
var sbContent strings.Builder
for rr := range ch { for rr := range ch {
switch t := rr.(type) { switch t := rr.(type) {
case api.ChatResponse: case api.ChatResponse:
sb.WriteString(t.Message.Content) sbThinking.WriteString(t.Message.Thinking)
sbContent.WriteString(t.Message.Content)
resp = t resp = t
if len(req.Tools) > 0 {
toolCalls = append(toolCalls, t.Message.ToolCalls...)
}
case gin.H: case gin.H:
msg, ok := t["error"].(string) msg, ok := t["error"].(string)
if !ok { if !ok {
...@@ -1613,13 +1620,11 @@ func (s *Server) ChatHandler(c *gin.Context) { ...@@ -1613,13 +1620,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
} }
resp.Message.Content = sb.String() resp.Message.Content = sbContent.String()
resp.Message.Thinking = sbThinking.String()
if len(req.Tools) > 0 { if len(toolCalls) > 0 {
if toolCalls, ok := m.parseToolCalls(sb.String()); ok { resp.Message.ToolCalls = toolCalls
resp.Message.ToolCalls = toolCalls
resp.Message.Content = ""
}
} }
c.JSON(http.StatusOK, resp) c.JSON(http.StatusOK, resp)
...@@ -1644,8 +1649,6 @@ func handleScheduleError(c *gin.Context, name string, err error) { ...@@ -1644,8 +1649,6 @@ func handleScheduleError(c *gin.Context, name string, err error) {
} }
} }
var thinkTagRegexp = regexp.MustCompile(`<think>(?s).*?</think>(\n)*`)
func filterThinkTags(msgs []api.Message, m *Model) []api.Message { func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" { if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
finalUserIndex := -1 finalUserIndex := -1
...@@ -1657,7 +1660,17 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message { ...@@ -1657,7 +1660,17 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
for i, msg := range msgs { for i, msg := range msgs {
if msg.Role == "assistant" && i < finalUserIndex { if msg.Role == "assistant" && i < finalUserIndex {
msgs[i].Content = thinkTagRegexp.ReplaceAllString(msg.Content, "") // TODO(drifkin): this is from before we added proper thinking support.
// However, even if thinking is not enabled (and therefore we shouldn't
// change the user output), we should probably perform this filtering
// for all thinking models (not just qwen3 & deepseek-r1) since it tends
// to save tokens and improve quality.
thinkingState := &thinking.Parser{
OpeningTag: "<think>",
ClosingTag: "</think>",
}
_, content := thinkingState.AddContent(msg.Content)
msgs[i].Content = content
} }
} }
} }
......