Commit ed443a03 authored by Jesse Gross's avatar Jesse Gross Committed by Jesse Gross
Browse files

Runner for Ollama engine

This provides integration with the new Ollama engine
(58245413 next ollama runner (#7913)) and the rest of the Ollama
infrastructure such as the runner and Ollama server.

In addition, it also builds out the KV cache infrastructure to
support requirements of how Ollama runs models such as:
 - Parallel processing
 - Memory management for defragmentation and shifting
 - Multi-modal modals

Both old and new engines continue to be supported. By default, only
the old engine is used. To enable the new engine:

Start the server with the OLLAMA_NEW_ENGINE environment variable set:
OLLAMA_NEW_ENGINE=1 ./ollama serve

Start a model that is supported by the Ollama engine. This one is Llama 3.1 8b Q4_K_M:
./ollama run jessegross/llama3.1
parent 6945617a
package runner package llamarunner
import ( import (
"errors" "errors"
......
package runner package llamarunner
import ( import (
"testing" "testing"
......
package runner package llamarunner
import ( import (
"errors" "errors"
......
package runner package llamarunner
import ( import (
"reflect" "reflect"
......
package runner package llamarunner
import ( import (
"context" "context"
...@@ -24,6 +24,7 @@ import ( ...@@ -24,6 +24,7 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/llama" "github.com/ollama/ollama/llama"
"github.com/ollama/ollama/runner/common"
) )
// input is an element of the prompt to process, either // input is an element of the prompt to process, either
...@@ -498,12 +499,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) ...@@ -498,12 +499,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
seq.pendingResponses = append(seq.pendingResponses, piece) seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "") sequence := strings.Join(seq.pendingResponses, "")
if ok, stop := findStop(sequence, seq.stop); ok { if ok, stop := common.FindStop(sequence, seq.stop); ok {
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop) slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
var tokenTruncated bool var tokenTruncated bool
origLen := len(seq.pendingResponses) origLen := len(seq.pendingResponses)
seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop) seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop)
newLen := len(seq.pendingResponses) newLen := len(seq.pendingResponses)
// Update the cache based on the tokens that will be returned: // Update the cache based on the tokens that will be returned:
...@@ -524,11 +525,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) ...@@ -524,11 +525,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
continue continue
} }
if containsStopSuffix(sequence, seq.stop) { if common.ContainsStopSuffix(sequence, seq.stop) {
continue continue
} }
if incompleteUnicode(sequence) { if common.IncompleteUnicode(sequence) {
continue continue
} }
...@@ -885,9 +886,6 @@ func (s *Server) loadModel( ...@@ -885,9 +886,6 @@ func (s *Server) loadModel(
} }
func Execute(args []string) error { func Execute(args []string) error {
if args[0] == "runner" {
args = args[1:]
}
fs := flag.NewFlagSet("runner", flag.ExitOnError) fs := flag.NewFlagSet("runner", flag.ExitOnError)
mpath := fs.String("model", "", "Path to model binary file") mpath := fs.String("model", "", "Path to model binary file")
ppath := fs.String("mmproj", "", "Path to projector binary file") ppath := fs.String("mmproj", "", "Path to projector binary file")
......
package ollamarunner
import (
"errors"
"fmt"
"log/slog"
"math"
"reflect"
"time"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
)
type InputCache struct {
// context window size (per slot)
numCtx int32
// does the cache store data or do we need to always send the full input?
// note that when enabled is false the underlying cache may either be nil
// or a non-nil dummy that doesn't actually store anything
enabled bool
// individual KV caches
slots []InputCacheSlot
// optimize cache eviction for multiple users
multiUserCache bool
cache kvcache.Cache
}
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) {
if kvSize/int32(numSlots) < 1 {
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
}
slots := make([]InputCacheSlot, numSlots)
for i := range slots {
slots[i] = InputCacheSlot{
Id: i,
Inputs: make([]input, 0),
}
}
cache := model.Config().Cache
if cache != nil {
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize)
}
return &InputCache{
numCtx: kvSize / int32(numSlots),
enabled: cache != nil,
slots: slots,
multiUserCache: multiUserCache,
cache: cache,
}, nil
}
func kvCacheTypeFromStr(s string) ml.DType {
switch s {
case "q8_0":
panic("kv cache quantization not yet implemented")
case "q4_0":
panic("kv cache quantization not yet implemented")
default:
return ml.DTypeF16
}
}
func (c *InputCache) Close() {
c.cache.Close()
}
// Locking: Operations on InputCacheSlot (including finding one
// through LoadCacheSlot) require a lock to be be held that serializes
// these operations with each other and processBatch
type InputCacheSlot struct {
// Index in the KV cache
Id int
// Inputs that are stored in the KV cache
Inputs []input
// is this cache actively being processed as part of a sequence?
InUse bool
// last time this cache was used (as of start of processing)
lastUsed time.Time
}
func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) {
var slot *InputCacheSlot
var numPast int32
var err error
// In single-user scenarios, the longest cache slot works fine for getting good input
// cache hit rates and it keeps the footprint of the cache small, which improves throughput.
// For multiple users, the "best" cache slot produces better input cache hit rates
// at the cost of worse performance when we miss the input cache.
if !c.multiUserCache {
slot, numPast, err = c.findLongestCacheSlot(prompt)
} else {
slot, numPast, err = c.findBestCacheSlot(prompt)
}
if err != nil {
return nil, nil, err
}
if !cachePrompt {
numPast = 0
}
slot.InUse = true
slot.lastUsed = time.Now()
if numPast == int32(len(prompt)) {
// Leave one input to sample so we can get a response
numPast--
}
if c.cache != nil {
err = c.cache.Remove(slot.Id, numPast, math.MaxInt32)
if err != nil {
// Some models don't support partial erasure
err = c.cache.Remove(slot.Id, 0, math.MaxInt32)
if err != nil {
return nil, nil, err
}
numPast = 0
}
}
slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
"used", numPast, "remaining", int32(len(prompt))-numPast)
prompt = prompt[numPast:]
slot.Inputs = slot.Inputs[:numPast]
return slot, prompt, nil
}
func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
longest := int32(-1)
var longestSlot *InputCacheSlot
for i, s := range c.slots {
if s.InUse {
continue
}
count := countCommonPrefix(s.Inputs, prompt)
if count > longest {
longest = count
longestSlot = &c.slots[i]
}
}
if longestSlot == nil {
return nil, 0, errors.New("no available cache slots")
}
return longestSlot, longest, nil
}
func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
oldest := time.Now()
var oldestSlot *InputCacheSlot
longest := int32(-1)
var longestSlot *InputCacheSlot
for i, s := range c.slots {
count := countCommonPrefix(s.Inputs, prompt)
if count > longest {
longest = count
longestSlot = &c.slots[i]
}
if s.lastUsed.Compare(oldest) < 0 && !s.InUse {
oldest = s.lastUsed
oldestSlot = &c.slots[i]
}
}
if longest == int32(len(longestSlot.Inputs)) && !longestSlot.InUse {
return longestSlot, longest, nil
}
if oldestSlot.InUse {
return nil, 0, errors.New("no available cache slots")
}
if len(oldestSlot.Inputs) != 0 {
slog.Debug("evicting cache slot", "id", oldestSlot.Id, "inputs", len(oldestSlot.Inputs),
"used", oldestSlot.lastUsed)
}
if longest > 0 && longestSlot != oldestSlot {
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
len(longestSlot.Inputs))
oldestSlot.Inputs = make([]input, longest)
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
if c.cache != nil {
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
}
}
return oldestSlot, longest, nil
}
func countCommonPrefix(a []input, b []input) int32 {
var count int32
for i := range a {
if i >= len(b) {
break
}
if !reflect.DeepEqual(a[i], b[i]) {
break
}
count++
}
return count
}
func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
targetFree := (c.numCtx - numKeep) / 2
targetFree = max(targetFree, 1)
currentFree := c.numCtx - inputLen
discard := targetFree - currentFree
if discard < 0 {
discard = 0
}
return discard
}
// Frees up space in the KV cache by deleting the oldest half of history and shifting
// the newest half into that space (saving numKeep inputs at the beginning).
//
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
if numKeep >= c.numCtx {
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
}
inputLen := int32(len(slot.Inputs))
discard := c.ShiftDiscard(inputLen, numKeep)
if discard <= 0 {
return nil
}
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
"keep", numKeep, "discard", discard)
// TODO (jessegross): KV cache removal can fail for certain types of models
if c.cache != nil {
err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
if err != nil {
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err)
}
}
for i := numKeep + discard; i < inputLen; i++ {
slot.Inputs[i-discard] = slot.Inputs[i]
}
slot.Inputs = slot.Inputs[:inputLen-discard]
return nil
}
package ollamarunner
import (
"image"
"testing"
"time"
)
func TestCountCommon(t *testing.T) {
imgA := image.NewRGBA(image.Rect(0, 0, 100, 100))
imgB := image.NewRGBA(image.Rect(0, 0, 50, 50))
imgC := image.NewRGBA(image.Rect(50, 50, 100, 100))
tests := []struct {
name string
t1 []input
t2 []input
expected int32
}{
{
name: "Equal",
t1: []input{{token: 1}, {token: 2}, {token: 3}},
t2: []input{{token: 1}, {token: 2}, {token: 3}},
expected: 3,
},
{
name: "Prefix",
t1: []input{{token: 1}},
t2: []input{{token: 1}, {token: 2}, {token: 3}},
expected: 1,
},
{
name: "Image Prefix",
t1: []input{{image: imgA}},
t2: []input{{image: imgA}, {image: imgB}, {image: imgC}},
expected: 1,
},
{
name: "Mixed",
t1: []input{{token: 1}, {image: imgA}},
t2: []input{{token: 1}, {image: imgA}, {token: 5}},
expected: 2,
},
{
name: "Empty",
t1: []input{},
t2: []input{{token: 1}, {token: 2}, {token: 3}},
expected: 0,
},
{
name: "Both Empty",
t1: []input{},
t2: []input{},
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := countCommonPrefix(tt.t1, tt.t2)
if result != tt.expected {
t.Errorf("countCommonPrefix(%v, %v): have %v; want %v", tt.t1, tt.t2, result, tt.expected)
}
})
}
}
func TestFindCacheSlot(t *testing.T) {
type expected struct {
result int
len int32
}
tests := []struct {
name string
cache InputCache
prompt []input
longest expected
best expected
}{
{
name: "Empty",
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{},
InUse: false,
lastUsed: time.Time{},
},
{
Id: 1,
Inputs: []input{},
InUse: false,
lastUsed: time.Time{},
},
}},
prompt: []input{{token: 1}},
longest: expected{result: 0, len: 0},
best: expected{result: 0, len: 0},
},
{
name: "Extend",
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{{token: 1}, {token: 2}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []input{{token: 1}, {token: 2}},
longest: expected{result: 1, len: 2},
best: expected{result: 1, len: 2},
},
{
name: "New",
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}, {token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{},
InUse: false,
lastUsed: time.Time{},
},
}},
prompt: []input{{token: 2}},
longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0},
},
{
name: "Fork",
cache: InputCache{
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}, {token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{},
InUse: false,
lastUsed: time.Time{},
},
},
},
prompt: []input{{token: 1}},
longest: expected{result: 0, len: 1},
best: expected{result: 1, len: 1},
},
{
name: "Evict",
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{{token: 1}, {token: 2}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []input{{token: 2}, {token: 3}},
longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0},
},
{
name: "In use",
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input{{token: 1}, {token: 2}},
InUse: true,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input{{token: 1}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []input{{token: 1}, {token: 2}},
longest: expected{result: 1, len: 1},
best: expected{result: 1, len: 2},
},
}
for _, tt := range tests {
t.Run("Longest-"+tt.name, func(t *testing.T) {
result, resultLen, err := tt.cache.findLongestCacheSlot(tt.prompt)
if err != nil {
t.Errorf("findLongestCacheSlot: err %v", err)
} else if result.Id != tt.longest.result || resultLen != tt.longest.len {
t.Errorf("findLongestCacheSlot: slot have %v, want %v len have %v, want %v",
result.Id, tt.longest.result, resultLen, tt.longest.len)
}
})
}
for _, tt := range tests {
t.Run("Best-"+tt.name, func(t *testing.T) {
result, resultLen, err := tt.cache.findBestCacheSlot(tt.prompt)
if err != nil {
t.Errorf("findBestCacheSlot: err %v", err)
} else if result.Id != tt.best.result || resultLen != tt.best.len {
t.Errorf("findBestCacheSlot: slot have %v, want %v len have %v, want %v",
result.Id, tt.best.result, resultLen, tt.best.len)
}
})
}
}
func TestShiftDiscard(t *testing.T) {
tests := []struct {
name string
numCtx int32
numKeep int32
inputLen int32
expected int32
}{
{
name: "Shift",
numCtx: 2048,
numKeep: 5,
inputLen: 2048,
expected: 1021,
},
{
name: "Max Keep",
numCtx: 2048,
numKeep: 2047,
inputLen: 2048,
expected: 1,
},
{
name: "No Keep",
numCtx: 2048,
numKeep: 0,
inputLen: 2048,
expected: 1024,
},
{
name: "Truncate",
numCtx: 2048,
numKeep: 5,
inputLen: 5000,
expected: 3973,
},
{
name: "Truncate Keep",
numCtx: 2048,
numKeep: 2047,
inputLen: 5000,
expected: 2953,
},
{
name: "No Op",
numCtx: 2048,
numKeep: 5,
inputLen: 512,
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := InputCache{numCtx: tt.numCtx}
result := c.ShiftDiscard(tt.inputLen, tt.numKeep)
if result != tt.expected {
t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected)
}
})
}
}
This diff is collapsed.
package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
)
func Execute(args []string) error {
if args[0] == "runner" {
args = args[1:]
}
var newRunner bool
if args[0] == "--ollama-engine" {
args = args[1:]
newRunner = true
}
if newRunner {
return ollamarunner.Execute(args)
} else {
return llamarunner.Execute(args)
}
}
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"strings" "strings"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/model/models/mllama" "github.com/ollama/ollama/model/models/mllama"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
...@@ -92,26 +93,33 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. ...@@ -92,26 +93,33 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
var imgData llm.ImageData var imgData llm.ImageData
if isMllama { if isMllama {
data, opts, err := mllama.Preprocess(bytes.NewReader(i)) if envconfig.NewEngine() {
if err != nil { imgData = llm.ImageData{
return "", nil, err ID: len(images),
} Data: i,
}
buf := new(bytes.Buffer) } else {
err = binary.Write(buf, binary.LittleEndian, data) data, opts, err := mllama.Preprocess(bytes.NewReader(i))
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
ar, ok := opts["aspectRatioIndex"].(int) buf := new(bytes.Buffer)
if !ok { err = binary.Write(buf, binary.LittleEndian, data)
return "", nil, fmt.Errorf("missing aspect ratio for image") if err != nil {
} return "", nil, err
}
imgData = llm.ImageData{
ID: len(images), ar, ok := opts["aspectRatioIndex"].(int)
Data: buf.Bytes(), if !ok {
AspectRatioID: ar, return "", nil, fmt.Errorf("missing aspect ratio for image")
}
imgData = llm.ImageData{
ID: len(images),
Data: buf.Bytes(),
AspectRatioID: ar,
}
} }
imgPrompt = "<|image|>" imgPrompt = "<|image|>"
} else { } else {
......
...@@ -203,7 +203,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { ...@@ -203,7 +203,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
images := make([]llm.ImageData, len(req.Images)) images := make([]llm.ImageData, len(req.Images))
for i := range req.Images { for i := range req.Images {
if isMllama { if isMllama && !envconfig.NewEngine() {
data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i])) data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i]))
if err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"}) c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment