Commit ed443a03 authored by Jesse Gross, committed by Jesse Gross

Runner for Ollama engine

This provides integration between the new Ollama engine
(58245413 next ollama runner (#7913)) and the rest of the Ollama
infrastructure, such as the runner and the Ollama server.

In addition, it builds out the KV cache infrastructure to support
the requirements of how Ollama runs models, such as:
 - Parallel processing
 - Memory management for defragmentation and shifting
 - Multi-modal models

Both old and new engines continue to be supported. By default, only
the old engine is used. To enable the new engine:

Start the server with the OLLAMA_NEW_ENGINE environment variable set:
OLLAMA_NEW_ENGINE=1 ./ollama serve

Start a model that is supported by the Ollama engine. This one is Llama 3.1 8b Q4_K_M:
./ollama run jessegross/llama3.1
parent 6945617a
package cache
import (
"github.com/ollama/ollama/ml"
)
// Options carries per-call metadata for cache operations.
type Options struct {
// Position is the offset within the sequence of the first token in the batch.
Position int
}
// Cache stores key/value history for reuse across forward passes.
type Cache interface {
// Sub returns a view of the cache for layer i, allocating it if needed.
Sub(i int) Cache
// Put stores key and value for the current batch and returns the full
// history to attend over.
Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor)
}
// Simple is a minimal single-sequence cache backed by one tensor per layer.
type Simple struct {
DType ml.DType
Capacity int
keys, values []ml.Tensor
}
func (c *Simple) Sub(i int) Cache {
if i >= len(c.keys) {
c.keys = append(c.keys, make([]ml.Tensor, i-len(c.keys)+1)...)
c.values = append(c.values, make([]ml.Tensor, i-len(c.values)+1)...)
}
return &Simple{
keys: c.keys[i : i+1],
values: c.values[i : i+1],
Capacity: c.Capacity,
DType: c.DType,
}
}
func (c *Simple) Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor) {
if c.keys[0] == nil || c.values[0] == nil {
c.keys[0] = ctx.Zeros(c.DType, key.Dim(0)*key.Dim(1)*c.Capacity)
c.values[0] = ctx.Zeros(c.DType, value.Dim(0)*value.Dim(1)*c.Capacity)
}
ctx.Forward(key.Copy(ctx, c.keys[0].View(ctx, key.Stride(2)*opts.Position, key.Dim(0)*key.Dim(1)*key.Dim(2))))
ctx.Forward(value.Copy(ctx, c.values[0].View(ctx, value.Stride(2)*opts.Position, value.Dim(0)*value.Dim(1)*value.Dim(2))))
n := min(c.Capacity, key.Dim(2)+opts.Position)
key = c.keys[0].View(ctx, 0,
key.Dim(0), key.Stride(1),
key.Dim(1), key.Stride(2),
n,
)
value = c.values[0].View(ctx, 0,
value.Dim(0), value.Stride(1),
value.Dim(1), value.Stride(2),
n,
)
// TODO shift context if necessary
return key, value
}
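For orientation, here is a minimal sketch of how a model loop would drive this API. It is illustrative only and not part of this commit: kv is a hypothetical callback producing the current batch's key/value tensors.

package example

import (
    "github.com/ollama/ollama/cache"
    "github.com/ollama/ollama/ml"
)

// forwardLayers sketches per-layer cache use: Sub selects layer i's slot,
// and Put appends this batch, returning the full history to attend over.
func forwardLayers(ctx ml.Context, c cache.Cache, layers int, kv func(layer int) (ml.Tensor, ml.Tensor)) {
    opts := cache.Options{Position: 0} // offset of the first token in this batch
    for i := 0; i < layers; i++ {
        k, v := kv(i)
        key, value := c.Sub(i).Put(ctx, k, v, opts)
        _, _ = key, value // attention over (key, value) would go here
    }
}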
...
@@ -35,9 +35,9 @@ import (
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/llama"
-	"github.com/ollama/ollama/llama/runner"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
+	"github.com/ollama/ollama/runner"
 	"github.com/ollama/ollama/server"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
...
@@ -338,7 +338,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}

-	opts.MultiModal = len(info.ProjectorInfo) != 0
+	// TODO(jessegross): We should either find another way to know if this is
+	// a vision model or remove the logic. Also consider that other modalities will
+	// need different behavior anyways.
+	opts.MultiModal = len(info.ProjectorInfo) != 0 || envconfig.NewEngine()
 	opts.ParentModel = info.Details.ParentModel

 	if interactive {
...
...
@@ -4,7 +4,7 @@ import (
 	"fmt"
 	"os"

-	"github.com/ollama/ollama/llama/runner"
+	"github.com/ollama/ollama/runner"
 )

 func main() {
...
...
@@ -165,6 +165,8 @@ var (
 	IntelGPU = Bool("OLLAMA_INTEL_GPU")
 	// MultiUserCache optimizes prompt caching for multi-user scenarios
 	MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
+	// Enable the new Ollama engine
+	NewEngine = Bool("OLLAMA_NEW_ENGINE")
 )

 func String(s string) func() string {
...
@@ -250,6 +252,7 @@ func AsMap() map[string]EnvVar {
 	"OLLAMA_ORIGINS":         {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
 	"OLLAMA_SCHED_SPREAD":    {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
 	"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
+	"OLLAMA_NEW_ENGINE":      {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},

 	// Informational
 	"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
...
package kvcache
import (
"errors"
"github.com/ollama/ollama/ml"
)
var (
ErrKvCacheFull = errors.New("could not find a kv cache slot")
ErrNotSupported = errors.New("model does not support operation")
)
type Cache interface {
// ** used by model implementations **
// SetLayer sets the active layer of the cache
SetLayer(layer int)
// Get returns the history of key and value tensors plus a mask
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
// Put stores a batch of key and value in the cache
//
// The shape of the tensors is documented in the specific
// cache implementation used.
Put(ctx ml.Context, key, value ml.Tensor)
// ** cache management **
// Init sets up runtime parameters
Init(backend ml.Backend, dtype ml.DType, capacity int32)
// Close closes the cache and frees resources associated with it
Close()
// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs.
StartForward(ctx ml.Context, positions []int32, seqs []int) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)
// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
//
// If an error occurs, the entire context for the sequence should be
// removed by calling Remove(seq, 0, math.MaxInt32)
Remove(seq int, beginIndex, endIndex int32) error
}
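To make the contract concrete, here is a minimal sketch of the call sequence an attention implementation follows against this interface. The helper name is hypothetical and not part of this commit; the runner is assumed to have already called Init and StartForward for the current batch.

package example

import (
    "github.com/ollama/ollama/kvcache"
    "github.com/ollama/ollama/ml"
)

// attendWithCache selects a layer, stores the batch's K/V at the cells
// chosen by StartForward, and returns the full history plus the mask to
// add to the scaled QK^T scores before softmax.
func attendWithCache(ctx ml.Context, c kvcache.Cache, layer int, k, v ml.Tensor) (ml.Tensor, ml.Tensor, ml.Tensor) {
    c.SetLayer(layer)
    c.Put(ctx, k, v)
    return c.Get(ctx)
}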
package kvcache
import (
"errors"
"fmt"
"log/slog"
"math"
"slices"
"github.com/ollama/ollama/ml"
)
type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
// Causal cache stores K and V tensors according to their position in the
// sequence. Returns the history and a mask for attending to past tokens
//
// The tensors are of shape embed dim, kv heads, batch size
// The mask is of shape history size, batch size
type Causal struct {
DType ml.DType
Capacity int32
windowSize int32
// ** current forward pass **
// the active layer for Get and Put
curLayer int
// starting location for data storage for this batch
curLoc int
// size of the current batch
curBatchSize int
// mask of the cache as used by this batch
curMask ml.Tensor
// locations in the cache that are needed for this batch
curCellRange cellRange
// ** cache metadata **
// for each possible location in the cache, stores the position and set of sequences
// that reference the data there
cells []cacheCell
// maps from sequence to the range of locations where it is stored in the cache
cellRanges map[int]cellRange
// ** cache data storage **
shiftFn shiftFn
backend ml.Backend
cacheCtx ml.Context
keys, values []ml.Tensor
}
type cacheCell struct {
pos int32
sequences []int
}
type cellRange struct {
min int
max int
}
func NewCausalCache(shift shiftFn) *Causal {
return &Causal{windowSize: math.MaxInt32, shiftFn: shift}
}
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
return &Causal{windowSize: windowSize, shiftFn: shift}
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
c.DType = dtype
c.Capacity = capacity
c.cells = make([]cacheCell, capacity)
c.cellRanges = make(map[int]cellRange)
c.backend = backend
c.cacheCtx = backend.NewContext()
}
func (c *Causal) Close() {
c.cacheCtx.Close()
}
func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
c.curBatchSize = len(positions)
var err error
c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) {
c.defrag()
c.curLoc, err = c.findStartLoc()
}
if err != nil {
return err
}
c.curCellRange = newRange()
for i, pos := range positions {
seq := seqs[i]
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
seqRange, ok := c.cellRanges[seq]
if !ok {
seqRange = newRange()
}
if c.curLoc+i > seqRange.max {
seqRange.max = c.curLoc + i
}
if seqRange.max > c.curCellRange.max {
c.curCellRange.max = seqRange.max
}
if c.curLoc+i < seqRange.min {
seqRange.min = c.curLoc + i
}
if seqRange.min < c.curCellRange.min {
c.curCellRange.min = seqRange.min
}
c.cellRanges[seq] = seqRange
}
c.curMask, err = c.buildMask(ctx, positions, seqs)
return err
}
func newRange() cellRange {
return cellRange{
min: math.MaxInt,
max: 0,
}
}
// Find the first contiguous block of at least curBatchSize empty cells
func (c *Causal) findStartLoc() (int, error) {
var start, count int
for i := range c.cells {
if len(c.cells[i].sequences) == 0 {
count++
if count >= c.curBatchSize {
return start, nil
}
} else {
start = i + 1
count = 0
}
}
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
}
// Builds a mask of history x batch indicating, for each token in the batch,
// whether each token in the history should be attended to. This is based on
// both the sequence and causality (the position of the history token must not
// be ahead of the token in the batch).
func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
// TODO(jessegross): This does not do padding, which is required for flash attention
len := c.curCellRange.max - c.curCellRange.min + 1
mask := make([]float32, c.curBatchSize*len)
for i := range c.curBatchSize {
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
c.cells[j].pos < positions[i]-c.windowSize {
mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
}
}
}
return ctx.FromFloatSlice(mask, len, c.curBatchSize)
}
func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
for _, obj := range objs {
if obj == nil {
continue
}
srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len)
dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len)
ctx.Forward(srcView.Copy(ctx, dstView))
}
}
func (c *Causal) defrag() {
slog.Debug("defragmenting kv cache")
// Defrag strategy:
// - Search for empty holes at the beginning of the cache,
// filling them with active data starting at the end
// - If there are contiguous elements that need to be moved,
// combine them into a single operation by holding new moves
// until we see that the next one is non-contiguous
// - Fill up the context with the maximum number of operations it
// can hold then compute that and continue with a new context
//
// We could try to optimize placement by grouping blocks from
// the same sequences together but most likely the next forward
// pass will disrupt this anyways, so the real world benefit
// seems limited at this time.
ctx := c.backend.NewContext()
// For every move, 6 tensors are required per layer (2 views and a
// copy for each of k and v).
layers := 0
for _, key := range c.keys {
if key == nil {
continue
}
layers++
}
maxMoves := ctx.MaxTensors() / (6 * layers)
moves := 0
var pendingSrc, pendingDst, pendingLen int
src := len(c.cells) - 1
for dst := 0; dst < src; dst++ {
if len(c.cells[dst].sequences) == 0 {
for ; src > dst; src-- {
if len(c.cells[src].sequences) != 0 {
c.cells[dst] = c.cells[src]
c.cells[src] = cacheCell{}
if pendingLen > 0 {
if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
pendingSrc = src
pendingLen++
break
} else {
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
moves++
}
}
pendingSrc = src
pendingDst = dst
pendingLen = 1
break
}
}
}
if moves >= maxMoves {
ctx.Compute()
ctx.Close()
ctx = c.backend.NewContext()
moves = 0
}
}
if pendingLen > 0 {
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
moves++
}
if moves > 0 {
ctx.Compute()
}
ctx.Close()
// Reset range metadata
for seq := range c.cellRanges {
seqRange := newRange()
for i, cell := range c.cells {
if slices.Contains(cell.sequences, seq) {
if i < seqRange.min {
seqRange.min = i
}
if i > seqRange.max {
seqRange.max = i
}
}
}
c.cellRanges[seq] = seqRange
}
}
func (c *Causal) SetLayer(layer int) {
if layer >= len(c.keys) {
c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
}
c.curLayer = layer
}
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]
key = key.View(ctx, key.Stride(2)*c.curCellRange.min,
key.Dim(0), key.Stride(1),
key.Dim(1), key.Stride(2),
c.curMask.Dim(0),
)
value = value.View(ctx, value.Stride(2)*c.curCellRange.min,
value.Dim(0), value.Stride(1),
value.Dim(1), value.Stride(2),
c.curMask.Dim(0),
)
return key, value, c.curMask
}
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
if c.curBatchSize != key.Dim(2) {
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2)))
}
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity))
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
}
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))))
}
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
seqRange := newRange()
for i := range c.cells {
// Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end
if slices.Contains(c.cells[i].sequences, dstSeq) {
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
}
if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
if i < seqRange.min {
seqRange.min = i
}
if i > seqRange.max {
seqRange.max = i
}
}
}
c.cellRanges[dstSeq] = seqRange
}
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
if c.shiftFn == nil {
return ErrNotSupported
}
ctx := c.backend.NewContext()
defer ctx.Close()
seqRange := c.cellRanges[seq]
size := seqRange.max - seqRange.min + 1
offsets := make([]int32, size)
for i := range offsets {
cell := c.cells[seqRange.min+i]
if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
offsets[i] = offset
}
}
kShift, err := ctx.FromIntSlice(offsets, len(offsets))
if err != nil {
return err
}
for i, key := range c.keys {
if key == nil {
continue
}
key = key.View(ctx, key.Stride(2)*seqRange.min,
key.Dim(0), key.Stride(1),
key.Dim(1), key.Stride(2),
size,
)
roped, err := c.shiftFn(ctx, i, key, kShift)
if err != nil {
return err
}
ctx.Forward(roped.Copy(ctx, key))
}
ctx.Compute()
return nil
}
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
var offset int32
if endIndex != math.MaxInt32 {
offset = beginIndex - endIndex
}
seqRange := newRange()
for i := range c.cells {
if slices.Contains(c.cells[i].sequences, seq) {
if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
} else {
if c.cells[i].pos >= endIndex {
if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
// TODO(jessegross): Need to be careful about data shared between sequences
return errors.New("shifting on cells shared by multiple sequences not yet implemented")
}
c.cells[i].pos += offset
}
if i < seqRange.min {
seqRange.min = i
}
if i > seqRange.max {
seqRange.max = i
}
}
}
}
if seqRange == newRange() {
delete(c.cellRanges, seq)
return nil
}
c.cellRanges[seq] = seqRange
if endIndex != math.MaxInt32 {
err := c.shift(seq, endIndex+offset, offset)
if err != nil {
return err
}
}
return nil
}
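The defrag strategy described in the comments of (*Causal).defrag above is easier to see on plain data. The following standalone sketch is illustrative only and not part of this commit: ints stand in for cache cells, 0 marks a hole, and the returned {src, dst, len} triples correspond to the coalesced moveCell calls.

package main

import "fmt"

// defragSketch mirrors the hole-filling loop in (*Causal).defrag: dst scans
// forward for holes, src scans backward for occupied cells, and adjacent
// moves are coalesced into a single {src, dst, len} range where possible.
func defragSketch(cells []int) (moves [][3]int) {
    var pendingSrc, pendingDst, pendingLen int

    src := len(cells) - 1
    for dst := 0; dst < src; dst++ {
        if cells[dst] != 0 {
            continue // already occupied
        }
        for ; src > dst; src-- {
            if cells[src] == 0 {
                continue // skip holes at the end
            }
            cells[dst], cells[src] = cells[src], 0

            if pendingLen > 0 {
                if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
                    pendingSrc = src // extend the pending contiguous move
                    pendingLen++
                    break
                }
                moves = append(moves, [3]int{pendingSrc, pendingDst, pendingLen})
            }
            pendingSrc, pendingDst, pendingLen = src, dst, 1
            break
        }
    }

    if pendingLen > 0 {
        moves = append(moves, [3]int{pendingSrc, pendingDst, pendingLen})
    }
    return moves
}

func main() {
    // Holes at 1, 2, and 5 are filled from the end: cell 6 moves to 1,
    // then cell 4 moves to 2, printing [[6 1 1] [4 2 1]].
    fmt.Println(defragSketch([]int{1, 0, 0, 2, 3, 0, 4}))
}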
package kvcache
import (
"math"
"slices"
"testing"
"github.com/ollama/ollama/ml"
)
type testCase struct {
name string
in []float32
inShape []int
seqs []int
pos []int32
expected []float32
expectedShape []int
expectedMask []float32
}
func TestStore(t *testing.T) {
backend := &testBackend{}
cache := NewCausalCache(nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
name: "FirstBatch",
in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
inShape: []int{2, 3, 4},
seqs: []int{0, 0, 0, 0},
pos: []int32{0, 1, 2, 3},
expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
expectedShape: []int{2, 3, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
},
{
name: "SecondBatch",
in: []float32{115, 215, 125, 225, 135, 235},
inShape: []int{2, 3, 1},
seqs: []int{0},
pos: []int32{4},
expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
expectedShape: []int{2, 3, 5},
expectedMask: []float32{0, 0, 0, 0, 0},
},
}
testCache(t, backend, cache, tests)
}
func TestSWA(t *testing.T) {
backend := &testBackend{}
cache := NewSWACache(1, nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF32, 16)
tests := []testCase{
{
name: "SlidingWindow",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
seqs: []int{0, 0, 0, 0},
pos: []int32{0, 1, 2, 3},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
},
}
testCache(t, backend, cache, tests)
}
func TestSequences(t *testing.T) {
backend := &testBackend{}
cache := NewCausalCache(nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
seqs: []int{0, 0, 1, 1},
pos: []int32{0, 1, 0, 1},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
},
{
name: "SecondBatch",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
seqs: []int{0, 1},
pos: []int32{2, 2},
expected: []float32{1, 2, 3, 4, 5, 6},
expectedShape: []int{1, 1, 6},
expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
},
}
testCache(t, backend, cache, tests)
}
func TestRemove(t *testing.T) {
backend := &testBackend{}
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.Add(ctx, shift), nil
})
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
seqs: []int{0, 0, 1, 1},
pos: []int32{0, 1, 0, 1},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
},
}
testCache(t, backend, cache, tests)
err := cache.Remove(0, 1, math.MaxInt32)
if err != nil {
panic(err)
}
tests = []testCase{
{
name: "RemoveEnd",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
seqs: []int{0, 1},
pos: []int32{1, 2},
expected: []float32{1, 2, 3, 4, 5, 6},
expectedShape: []int{1, 1, 6},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
},
}
testCache(t, backend, cache, tests)
err = cache.Remove(0, 0, 1)
if err != nil {
panic(err)
}
tests = []testCase{
{
name: "RemoveMiddle",
in: []float32{7, 8},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{1, 2},
expected: []float32{7, 8, 3, 4, 4},
expectedShape: []int{1, 1, 5},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0},
},
}
testCache(t, backend, cache, tests)
}
func TestDefrag(t *testing.T) {
backend := &testBackend{}
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.Add(ctx, shift), nil
})
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
inShape: []int{1, 1, 16},
seqs: []int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
pos: []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
expectedShape: []int{1, 1, 16},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
}
testCache(t, backend, cache, tests)
err := cache.Remove(0, 2, 4)
if err != nil {
panic(err)
}
err = cache.Remove(0, 13, math.MaxInt32)
if err != nil {
panic(err)
}
tests = []testCase{
{
name: "Defrag",
in: []float32{17, 18, 19},
inShape: []int{1, 1, 3},
seqs: []int{0, 0, 0},
pos: []int32{16, 17, 18},
expected: []float32{1, 2, 12, 13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19},
expectedShape: []int{1, 1, 16},
expectedMask: []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
}
testCache(t, backend, cache, tests)
}
func TestCopy(t *testing.T) {
backend := &testBackend{}
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
seqs: []int{0, 0, 0, 0},
pos: []int32{0, 1, 2, 3},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
},
}
testCache(t, backend, cache, tests)
cache.CopyPrefix(0, 1, 2)
tests = []testCase{
{
name: "Copy",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
seqs: []int{1, 1},
pos: []int32{3, 4},
expected: []float32{1, 2, 3, 4, 5, 6},
expectedShape: []int{1, 1, 6},
expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
},
}
testCache(t, backend, cache, tests)
}
func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
context := backend.NewContext()
defer context.Close()
err := cache.StartForward(context, test.pos, test.seqs)
if err != nil {
panic(err)
}
cache.SetLayer(0)
tensor, _ := context.FromFloatSlice(test.in, test.inShape...)
cache.Put(context, tensor, tensor)
out, _, mask := cache.Get(context)
context.Forward(out)
context.Forward(mask)
context.Compute(out, mask)
if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
}
})
}
}
type testBackend struct{}
func (b *testBackend) Config() ml.Config {
panic("not implemented")
}
func (b *testBackend) Get(name string) ml.Tensor {
panic("not implemented")
}
func (b *testBackend) NewContext() ml.Context {
return &testContext{}
}
type testContext struct{}
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
total := 0
if len(shape) > 0 {
total = 1
for _, s := range shape {
total *= s
}
}
return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
}
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
t := c.Zeros(ml.DTypeF32, shape...).(*testTensor)
copy(t.data, s)
return t, nil
}
func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
f := make([]float32, len(s))
for i := range f {
f[i] = float32(s[i])
}
out, _ := c.FromFloatSlice(f, shape...)
out.(*testTensor).dtype = ml.DTypeI32
return out, nil
}
func (c *testContext) Forward(ml.Tensor) {}
func (c *testContext) Compute(...ml.Tensor) {}
func (c *testContext) MaxTensors() int {
return 10
}
func (c *testContext) Close() {}
type testTensor struct {
dtype ml.DType
elementSize int
data []float32
shape []int
}
func (t *testTensor) Dim(n int) int {
return t.shape[n]
}
func (t *testTensor) Stride(n int) int {
stride := t.elementSize
for i := range n {
stride *= t.shape[i]
}
return stride
}
func (t *testTensor) Shape() []int {
return t.shape
}
func (t *testTensor) DType() ml.DType {
return t.dtype
}
func (t *testTensor) Bytes() []byte {
panic("not implemented")
}
func (t *testTensor) Floats() []float32 {
out := make([]float32, len(t.data))
copy(out, t.data)
return out
}
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor)
for i := range out.data {
out.data[i] = t.data[i] + t2.(*testTensor).data[i]
}
return out
}
func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) GELU(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) SILU(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
offset /= t.elementSize
var s []int
switch len(shape) {
case 1:
s = []int{shape[0]}
case 5:
s = []int{shape[0], shape[2], shape[4]}
default:
panic("unsupported number of dimensions")
}
context := &testContext{}
view := context.Zeros(t.dtype, s...).(*testTensor)
view.data = t.data[offset : offset+len(view.data)]
return view
}
func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
copy(t2.(*testTensor).data, t.data)
return nil
}
package kvcache
import (
"github.com/ollama/ollama/ml"
)
// Encoder cache stores K and V tensors that are position independent
//
// The tensors can be of any shape and will be returned as they were stored
// The mask is currently always nil
//
// Not currently safe for multiple sequences
type EncoderCache struct {
// ** current forward pass **
// the active layer for Get and Put
curLayer int
// if something is stored during this pass, this
// will be the position (but there is no guarantee
// anything will be stored)
curPos int32
// ** cache metadata **
// was something stored in the cache?
encoderCached bool
// position of the cached data
encoderPos int32
// ** cache data storage **
cacheCtx ml.Context
keys, values []ml.Tensor
}
func NewEncoderCache() *EncoderCache {
return &EncoderCache{}
}
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
c.cacheCtx = backend.NewContext()
}
func (c *EncoderCache) Close() {
c.cacheCtx.Close()
}
func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
// The image is always in the first position
c.curPos = positions[0]
return nil
}
func (c *EncoderCache) SetLayer(layer int) {
if layer >= len(c.keys) {
c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
}
c.curLayer = layer
}
func (c *EncoderCache) EncoderCached() bool {
return c.encoderCached
}
func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return c.keys[c.curLayer], c.values[c.curLayer], nil
}
func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.encoderPos = c.curPos
c.encoderCached = true
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
}
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer]))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer]))
}
func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
panic("encoder cache does not support multiple sequences")
}
func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
c.encoderCached = false
}
return nil
}
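As a usage sketch (the helper names are hypothetical, not part of this commit): a multi-modal model checks EncoderCached to decide whether the vision tower needs to run for this batch, stores the projected K/V once per image, and reuses them on every subsequent decode step.

package example

import (
    "github.com/ollama/ollama/kvcache"
    "github.com/ollama/ollama/ml"
)

// needsVision reports whether the vision encoder must run for this batch.
func needsVision(c *kvcache.EncoderCache) bool {
    return !c.EncoderCached()
}

// storeEncoderKV caches one cross-attention layer's K/V; the tensors are
// copied into the cache's long-lived context.
func storeEncoderKV(ctx ml.Context, c *kvcache.EncoderCache, layer int, k, v ml.Tensor) {
    c.SetLayer(layer)
    c.Put(ctx, k, v)
}

// readEncoderKV returns the cached K/V; the mask is always nil for this cache.
func readEncoderKV(ctx ml.Context, c *kvcache.EncoderCache, layer int) (ml.Tensor, ml.Tensor) {
    c.SetLayer(layer)
    k, v, _ := c.Get(ctx)
    return k, v
}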
package kvcache
import (
"math"
"github.com/ollama/ollama/ml"
)
// Wrapper cache is a container for multiple types of caches,
// such as for the encoding and decoding portions of a model.
type WrapperCache struct {
// caches we are wrapping
caches []Cache
// cache to be used for this layer
curType int
}
func NewWrapperCache(caches ...Cache) *WrapperCache {
return &WrapperCache{
caches: caches,
}
}
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
for _, cache := range c.caches {
cache.Init(backend, dtype, capacity)
}
}
func (c *WrapperCache) Close() {
for _, cache := range c.caches {
cache.Close()
}
}
func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
for i, cache := range c.caches {
err := cache.StartForward(ctx, positions, seqs)
if err != nil {
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
for j := i - 1; j >= 0; j-- {
for k := range positions {
_ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32)
}
}
return err
}
}
c.curType = 0
return nil
}
func (c *WrapperCache) SetLayer(layer int) {
for _, cache := range c.caches {
cache.SetLayer(layer)
}
}
func (c *WrapperCache) SetLayerType(layerType int) {
c.curType = layerType
}
func (c *WrapperCache) UnderlyingCache() Cache {
return c.caches[c.curType]
}
func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
return c.caches[c.curType].Get(ctx)
}
func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.caches[c.curType].Put(ctx, key, value)
}
func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
for _, cache := range c.caches {
cache.CopyPrefix(srcSeq, dstSeq, len)
}
}
func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
// If one of these fails, the caller is expected to retry with endIndex set to math.MaxInt32, which should not fail
for _, cache := range c.caches {
err := cache.Remove(seq, beginIndex, endIndex)
if err != nil {
return err
}
}
return nil
}
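A sketch of how a hybrid model drives this wrapper (illustrative only, not part of this commit): layerTypes maps each layer to an index into the wrapped caches, e.g. 0 for the encoder/cross-attention cache and 1 for the causal self-attention cache, matching the constants mllama defines below.

package example

import (
    "github.com/ollama/ollama/kvcache"
    "github.com/ollama/ollama/ml"
)

// forwardHybrid routes each layer's Put/Get to the matching wrapped cache.
func forwardHybrid(ctx ml.Context, c *kvcache.WrapperCache, layerTypes []int, kv func(layer int) (ml.Tensor, ml.Tensor)) {
    for i, layerType := range layerTypes {
        c.SetLayer(i)             // forwarded to every wrapped cache
        c.SetLayerType(layerType) // selects which wrapped cache Put/Get use
        k, v := kv(i)
        c.Put(ctx, k, v)
    }
}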
...
@@ -275,6 +275,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
 		port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
 	}
 	finalParams := []string{"runner"}
+	if envconfig.NewEngine() {
+		finalParams = append(finalParams, "--ollama-engine")
+	}
 	finalParams = append(finalParams, params...)
 	finalParams = append(finalParams, "--port", strconv.Itoa(port))
...
...
@@ -50,6 +50,7 @@ type Context interface {
 	Forward(Tensor)
 	Compute(...Tensor)
+	MaxTensors() int
 	Close()
 }
...
@@ -118,7 +119,7 @@ type DumpOptions struct {
 	Precision int
 }

-func Dump(t Tensor, opts ...DumpOptions) string {
+func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
 	if len(opts) < 1 {
 		opts = append(opts, DumpOptions{
 			Items: 3,
...
@@ -128,11 +129,17 @@ func Dump(t Tensor, opts ...DumpOptions) string {
 	switch t.DType() {
 	case DTypeF32:
-		return dump[[]float32](t, opts[0].Items, func(f float32) string {
+		return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
+			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
+		})
+	case DTypeF16:
+		f32 := ctx.Zeros(DTypeF32, t.Shape()...)
+		f32 = t.Copy(ctx, f32)
+		return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
 			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
 		})
 	case DTypeI32:
-		return dump[[]int32](t, opts[0].Items, func(i int32) string {
+		return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string {
 			return strconv.FormatInt(int64(i), 10)
 		})
 	default:
...
@@ -140,10 +147,10 @@ func Dump(t Tensor, opts ...DumpOptions) string {
 	}
 }

-func dump[S ~[]E, E number](t Tensor, items int, fn func(E) string) string {
-	bts := t.Bytes()
-	if bts == nil {
-		return "<nil>"
+func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
+	if t.Bytes() == nil {
+		ctx.Forward(t)
+		ctx.Compute(t)
 	}

 	s := make(S, mul(t.Shape()...))
...
@@ -191,7 +198,8 @@ func dump[S ~[]E, E number](t Tensor, items int, fn func(E) string) string {
 type DType int

 const (
-	DTypeF32 DType = iota
+	DTypeOther DType = iota
+	DTypeF32
+	DTypeF16
 	DTypeI32
-	DTypeOther
 )
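With the reworked signature, Dump takes the context so it can Forward and Compute a tensor that has no materialized bytes yet before printing it. A small usage sketch (the wrapper function is hypothetical, not part of this commit):

package example

import (
    "fmt"

    "github.com/ollama/ollama/ml"
)

// dumpTensor prints a tensor mid-graph for debugging; Dump computes it
// on demand if its bytes are not yet materialized.
func dumpTensor(ctx ml.Context, t ml.Tensor) {
    fmt.Println(ml.Dump(ctx, t, ml.DumpOptions{Items: 3, Precision: 4}))
}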
...
@@ -258,6 +258,10 @@ func (c *Context) Compute(tensors ...ml.Tensor) {
 	}
 }

+func (c *Context) MaxTensors() int {
+	return c.nodes
+}
+
 func shapeToGGML(shape []int) *C.int64_t {
 	sh := make([]C.int64_t, len(shape))
 	for i, s := range shape {
...
@@ -282,6 +286,8 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
 	switch dtype {
 	case ml.DTypeF32:
 		t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
+	case ml.DTypeF16:
+		t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
 	case ml.DTypeI32:
 		t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
 	default:
...
@@ -389,6 +395,8 @@ func (t *Tensor) DType() ml.DType {
 	switch t.t._type {
 	case C.GGML_TYPE_F32:
 		return ml.DTypeF32
+	case C.GGML_TYPE_F16:
+		return ml.DTypeF16
 	case C.GGML_TYPE_I32:
 		return ml.DTypeI32
 	default:
...
@@ -580,9 +588,14 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
 		ropeFactors = &Tensor{}
 	}

+	dequant := t.t
+	if C.ggml_is_quantized(t.t._type) {
+		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
+	}
+
 	return &Tensor{
 		t: C.ggml_rope_ext(
-			ctx.(*Context).ctx, t.t, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
+			ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
 			C.int(ropeDim),
 			131072,       // YaRN n_ctx_train
 			ropeTypeNorm, // ROPE_TYPE_NORM
...
 package model

 import (
+	"errors"
 	"fmt"
 	"image"
 	_ "image/jpeg"
...
@@ -15,102 +16,42 @@ import (
 	_ "golang.org/x/image/tiff"
 	_ "golang.org/x/image/webp"

-	"github.com/ollama/ollama/cache"
+	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	_ "github.com/ollama/ollama/ml/backend"
 )

-type Cache struct {
-	cache.Cache
-	cache.Options
-}
-
-func (c Cache) Sub(i int) Cache {
-	if c.Cache != nil {
-		return Cache{
-			Cache:   c.Cache.Sub(i),
-			Options: c.Options,
-		}
-	}
-	return c
-}
-
-func (c Cache) Put(ctx ml.Context, key, value ml.Tensor, opts cache.Options) (ml.Tensor, ml.Tensor) {
-	if c.Cache != nil {
-		return c.Cache.Put(ctx, key, value, opts)
-	}
-	return key, value
-}
-
 type Options struct {
-	inputs []int32
-
-	Offset int
+	Inputs    []int32
+	Positions []int32
+	Sequences []int
+	Outputs   []int32

 	Images []image.Image
-
-	Cache
 }

-func (opts Options) Inputs() []int32 {
-	return opts.inputs[opts.Offset:]
-}
-
-func (opts Options) Positions() []int32 {
-	positions := make([]int32, len(opts.inputs)-opts.Offset)
-	for i := range positions {
-		positions[i] = int32(opts.Offset + i)
-	}
-	return positions
-}
-
-type OptionsFunc func(Model, *Options)
-
-func WithInputIDs(ids []int32) OptionsFunc {
-	return func(m Model, opts *Options) {
-		opts.inputs = ids
-	}
-}
-
-func WithOffset(offset int) OptionsFunc {
-	return func(m Model, opts *Options) {
-		opts.Offset = offset
-		opts.Cache.Position = offset
-	}
-}
-
-func WithImage(img image.Image) OptionsFunc {
-	return func(m Model, opts *Options) {
-		opts.Images = append(opts.Images, img)
-	}
-}
-
-func WithCache(c cache.Cache) OptionsFunc {
-	return func(m Model, opts *Options) {
-		opts.Cache = Cache{
-			Cache: c,
-			Options: cache.Options{
-				Position: opts.Offset,
-			},
-		}
-	}
-}
+type config struct {
+	Cache kvcache.Cache
+}

 type Base struct {
 	b ml.Backend
+	config
 }

 func (m *Base) Backend() ml.Backend {
 	return m.b
 }

+func (m *Base) Config() config {
+	return m.config
+}
+
 type Model interface {
 	Forward(ml.Context, Options) (ml.Tensor, error)

 	Backend() ml.Backend
+	Config() config
 }

 var models = make(map[string]func(ml.Config) (Model, error))
...
@@ -146,12 +87,14 @@ func New(s string) (Model, error) {
 		return nil, err
 	}

+	base := Base{b: b, config: m.Config()}
+
 	v := reflect.ValueOf(m)
-	v.Elem().Set(populateFields(b, v.Elem()))
+	v.Elem().Set(populateFields(base, v.Elem()))
+
 	return m, nil
 }

-func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
+func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
 	t := v.Type()

 	if t.Kind() == reflect.Struct {
...
@@ -170,7 +113,7 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
 		}

 		if tt == reflect.TypeOf((*Base)(nil)).Elem() {
-			vv.Set(reflect.ValueOf(Base{b: b}))
+			vv.Set(reflect.ValueOf(base))
 		} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
 			var fn func([]Tag) [][]string
 			fn = func(tags []Tag) (values [][]string) {
...
@@ -196,21 +139,21 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
 			names := fn(tagsCopy)

 			for _, name := range names {
-				if tensor := b.Get(strings.Join(name, ".")); tensor != nil {
+				if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
 					slog.Debug("found tensor", "", tensor)
 					vv.Set(reflect.ValueOf(tensor))
 					break
 				}
 			}
 		} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
-			setPointer(b, vv, tagsCopy)
+			setPointer(base, vv, tagsCopy)
 		} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
 			for i := range vv.Len() {
 				vvv := vv.Index(i)
 				if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
-					setPointer(b, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
+					setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
 				} else {
-					vvv.Set(populateFields(b, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
+					vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
 				}
 			}
 		}
...
@@ -228,7 +171,7 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
 	return v
 }

-func setPointer(b ml.Backend, v reflect.Value, tags []Tag) {
+func setPointer(base Base, v reflect.Value, tags []Tag) {
 	vv := v
 	if v.Kind() == reflect.Interface {
 		if v.IsNil() {
...
@@ -243,7 +186,7 @@ func setPointer(b ml.Backend, v reflect.Value, tags []Tag) {
 		vv = reflect.New(v.Type().Elem()).Elem()
 	}

-	if f := populateFields(b, vv, tags...); f.CanAddr() {
+	if f := populateFields(base, vv, tags...); f.CanAddr() {
 		v.Set(f.Addr())
 	}
 }
...
@@ -277,18 +220,27 @@ func canNil(t reflect.Type) bool {
 		t.Kind() == reflect.Slice
 }

-func Forward(m Model, optsFuncs ...OptionsFunc) (ml.Tensor, error) {
-	var opts Options
-	for _, optsFunc := range optsFuncs {
-		optsFunc(m, &opts)
+func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
+	if len(opts.Positions) != len(opts.Sequences) {
+		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
+	}
+
+	if len(opts.Positions) < 1 {
+		return nil, errors.New("batch size cannot be less than 1")
+	}
+
+	cache := m.Config().Cache
+	if cache != nil {
+		err := cache.StartForward(ctx, opts.Positions, opts.Sequences)
+		if err != nil {
+			return nil, err
+		}
 	}

-	ctx := m.Backend().NewContext()
 	t, err := m.Forward(ctx, opts)
 	if err != nil {
 		return nil, err
 	}
-	defer ctx.Close()

 	ctx.Forward(t)
 	ctx.Compute(t)
...
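With the reworked entry point, the caller owns the ml.Context and passes explicit positions, sequences, and output indices. A sketch of a single prompt-processing step (the helper is hypothetical, not part of this commit):

package example

import (
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/model"
)

// prefill runs one batch for a single sequence, requesting logits only at
// the last token.
func prefill(ctx ml.Context, m model.Model, tokens []int32) (ml.Tensor, error) {
    positions := make([]int32, len(tokens))
    sequences := make([]int, len(tokens)) // all zero: one sequence
    for i := range tokens {
        positions[i] = int32(i)
    }

    return model.Forward(ctx, m, model.Options{
        Inputs:    tokens,
        Positions: positions,
        Sequences: sequences,
        Outputs:   []int32{int32(len(tokens) - 1)},
    })
}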
...
@@ -78,7 +78,7 @@ func TestPopulateFields(t *testing.T) {
 	var m fakeModel
 	v := reflect.ValueOf(&m)
-	v.Elem().Set(populateFields(&fakeBackend{
+	v.Elem().Set(populateFields(Base{b: &fakeBackend{
 		names: []string{
 			"input.weight",
 			"blk.0.attn_q.weight",
...
@@ -90,7 +90,7 @@ func TestPopulateFields(t *testing.T) {
 			"output_norm.weight",
 			"output.weight",
 		},
-	}, v.Elem()))
+	}}, v.Elem()))

 	if diff := cmp.Diff(fakeModel{
 		Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
...
@@ -121,11 +121,11 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
 	m := fakeModel{}
 	v := reflect.ValueOf(&m)
-	v.Elem().Set(populateFields(&fakeBackend{
+	v.Elem().Set(populateFields(Base{b: &fakeBackend{
 		names: []string{
 			"input.weight",
 		},
-	}, v.Elem()))
+	}}, v.Elem()))

 	if diff := cmp.Diff(fakeModel{
 		Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
...
...@@ -3,6 +3,7 @@ package llama ...@@ -3,6 +3,7 @@ package llama
import ( import (
"math" "math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
...@@ -28,7 +29,7 @@ type Model struct { ...@@ -28,7 +29,7 @@ type Model struct {
} }
func New(c ml.Config) (model.Model, error) { func New(c ml.Config) (model.Model, error) {
return &Model{ m := Model{
BytePairEncoding: model.NewBytePairEncoding( BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{ &model.Vocabulary{
...@@ -49,7 +50,11 @@ func New(c ml.Config) (model.Model, error) { ...@@ -49,7 +50,11 @@ func New(c ml.Config) (model.Model, error) {
ropeScale: c.Float("rope.freq_scale", 1), ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"), ropeDim: c.Uint("rope.dimension_count"),
}, },
}, nil }
m.Cache = kvcache.NewCausalCache(m.Shift)
return &m, nil
} }
type SelfAttention struct { type SelfAttention struct {
...@@ -59,7 +64,7 @@ type SelfAttention struct { ...@@ -59,7 +64,7 @@ type SelfAttention struct {
Output *nn.Linear `gguf:"attn_output"` Output *nn.Linear `gguf:"attn_output"`
} }
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor { func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1) batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads headDim := opts.hiddenSize / opts.numHeads
...@@ -74,7 +79,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten ...@@ -74,7 +79,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
v := sa.Value.Forward(ctx, hiddenState) v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k, v = cache.Put(ctx, k, v, cache.Options) cache.Put(ctx, k, v)
k, v, mask := cache.Get(ctx)
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
...@@ -82,6 +88,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten ...@@ -82,6 +88,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
kq := k.MulmatFullPrec(ctx, q) kq := k.MulmatFullPrec(ctx, q)
kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim))) kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
kq = kq.Add(ctx, mask)
kq = kq.Softmax(ctx) kq = kq.Softmax(ctx)
kqv := v.Mulmat(ctx, kq) kqv := v.Mulmat(ctx, kq)
...@@ -91,6 +98,10 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten ...@@ -91,6 +98,10 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
return sa.Output.Forward(ctx, kqv) return sa.Output.Forward(ctx, kqv)
} }
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
}
type MLP struct { type MLP struct {
Up *nn.Linear `gguf:"ffn_up"` Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"` Down *nn.Linear `gguf:"ffn_down"`
...@@ -109,7 +120,7 @@ type Layer struct { ...@@ -109,7 +120,7 @@ type Layer struct {
MLP *MLP MLP *MLP
} }
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor { func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenState residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
...@@ -123,12 +134,12 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cach ...@@ -123,12 +134,12 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cach
} }
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs())) inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions())) positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -136,13 +147,14 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { ...@@ -136,13 +147,14 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs) hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
for i, layer := range m.Layers { for i, layer := range m.Layers {
hiddenState = layer.Forward(ctx, hiddenState, positions, opts.Cache.Sub(i), m.Options) m.Cache.SetLayer(i)
hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
} }
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = m.Output.Forward(ctx, hiddenState) hiddenState = m.Output.Forward(ctx, hiddenState)
outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1) outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
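Also visible in this hunk: model.Options now exposes Inputs, Positions, and Outputs as plain slices rather than accessor methods, and the final FromIntSlice call builds the output indices from opts.Outputs instead of hardcoding the last position. That lets callers request logits for any subset of batch positions, which parallel sequences require. A sketch of the assumed shape (field types inferred from the FromIntSlice calls above; not the definitive definition):

type Options struct {
	Inputs    []int32 // token IDs for this batch
	Positions []int32 // absolute position of each token
	Outputs   []int32 // batch indices whose logits the caller wants
}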
package mllama package mllama
import ( import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
...@@ -18,8 +19,13 @@ type Model struct { ...@@ -18,8 +19,13 @@ type Model struct {
ImageProcessor ImageProcessor
} }
const (
crossAttentionLayer = iota
selfAttentionLayer
)
func New(c ml.Config) (model.Model, error) { func New(c ml.Config) (model.Model, error) {
return &Model{ m := Model{
BytePairEncoding: model.NewBytePairEncoding( BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{ &model.Vocabulary{
...@@ -33,7 +39,11 @@ func New(c ml.Config) (model.Model, error) { ...@@ -33,7 +39,11 @@ func New(c ml.Config) (model.Model, error) {
ImageProcessor: newImageProcessor(c), ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c), VisionModel: newVisionModel(c),
TextModel: newTextModel(c), TextModel: newTextModel(c),
}, nil }
m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift))
return &m, nil
} }
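The wrapper cache wiring relies on the constant order declared above: kvcache.NewWrapperCache presumably indexes its arguments by layer type, so crossAttentionLayer (iota value 0) selects the encoder cache and selfAttentionLayer (value 1) selects the causal cache built around TextModel.Shift. A condensed restatement of that correspondence (a reading of the code above, not additional API):

m.Cache = kvcache.NewWrapperCache(
	kvcache.NewEncoderCache(),                 // index 0 == crossAttentionLayer
	kvcache.NewCausalCache(m.TextModel.Shift), // index 1 == selfAttentionLayer
)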
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
...@@ -73,20 +83,20 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { ...@@ -73,20 +83,20 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates) crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates)
} }
inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs())) inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions())) positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO: attention mask, cross attention mask // TODO: attention mask, cross attention mask
hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, opts.Cache) hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache))
outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1) outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -4,9 +4,9 @@ import ( ...@@ -4,9 +4,9 @@ import (
"math" "math"
"slices" "slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
) )
type TextSelfAttention struct { type TextSelfAttention struct {
...@@ -16,7 +16,7 @@ type TextSelfAttention struct { ...@@ -16,7 +16,7 @@ type TextSelfAttention struct {
Output *nn.Linear `gguf:"attn_output"` Output *nn.Linear `gguf:"attn_output"`
} }
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor { func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenState.Dim(1) batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads headDim := opts.hiddenSize / opts.numHeads
...@@ -31,7 +31,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas ...@@ -31,7 +31,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas
value := sa.Value.Forward(ctx, hiddenState) value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key, value = cache.Put(ctx, key, value, cache.Options) cache.Put(ctx, key, value)
key, value, mask := cache.Get(ctx)
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
...@@ -39,11 +40,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas ...@@ -39,11 +40,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas
scores := key.MulmatFullPrec(ctx, query) scores := key.MulmatFullPrec(ctx, query)
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim))) scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
if mask != nil {
scores = scores.Add(ctx, mask) scores = scores.Add(ctx, mask)
}
scores = scores.Softmax(ctx) scores = scores.Softmax(ctx)
attention := value.Mulmat(ctx, scores) attention := value.Mulmat(ctx, scores)
...@@ -53,6 +50,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas ...@@ -53,6 +50,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas
return sa.Output.Forward(ctx, attention) return sa.Output.Forward(ctx, attention)
} }
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
}
type TextMLP struct { type TextMLP struct {
Up *nn.Linear `gguf:"ffn_up"` Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"` Down *nn.Linear `gguf:"ffn_down"`
...@@ -72,7 +74,7 @@ type TextSelfAttentionDecoderLayer struct { ...@@ -72,7 +74,7 @@ type TextSelfAttentionDecoderLayer struct {
MLP *TextMLP MLP *TextMLP
} }
func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor { func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
residual := hiddenState residual := hiddenState
hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps) hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
...@@ -94,23 +96,29 @@ type TextCrossAttention struct { ...@@ -94,23 +96,29 @@ type TextCrossAttention struct {
Output *nn.Linear `gguf:"cross_attn_o_proj"` Output *nn.Linear `gguf:"cross_attn_o_proj"`
} }
func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor { func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenState.Dim(1) batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads headDim := opts.hiddenSize / opts.numHeads
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
query := ca.Query.Forward(ctx, hiddenState) query := ca.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = ca.QueryNorm.Forward(ctx, query, opts.eps) query = ca.QueryNorm.Forward(ctx, query, opts.eps)
key := ca.Key.Forward(ctx, crossAttentionStates) var key, value ml.Tensor
if crossAttentionStates != nil {
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
key = ca.Key.Forward(ctx, crossAttentionStates)
key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles) key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
key = ca.KeyNorm.Forward(ctx, key, opts.eps) key = ca.KeyNorm.Forward(ctx, key, opts.eps)
value := ca.Value.Forward(ctx, crossAttentionStates) value = ca.Value.Forward(ctx, crossAttentionStates)
value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles) value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
// TODO cache key, value cache.Put(ctx, key, value)
} else {
key, value, _ = cache.Get(ctx)
}
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
...@@ -137,7 +145,7 @@ type TextCrossAttentionDecoderLayer struct { ...@@ -137,7 +145,7 @@ type TextCrossAttentionDecoderLayer struct {
MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"` MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
} }
func (d TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor { func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
residual := hiddenState residual := hiddenState
hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps) hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
...@@ -153,17 +161,25 @@ func (d TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, ...@@ -153,17 +161,25 @@ func (d TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _,
} }
type TextDecoderLayer interface { type TextDecoderLayer interface {
Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
} }
type TextDecoder struct { type TextDecoder struct {
Layers []TextDecoderLayer Layers []TextDecoderLayer
} }
func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor { func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
for i, layer := range d.Layers { for i, layer := range d.Layers {
if !slices.Contains(opts.crossAttentionLayers, uint32(i)) || crossAttentionStates != nil { layerType := selfAttentionLayer
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache.Sub(i), opts) if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
layerType = crossAttentionLayer
}
cache.SetLayer(i)
cache.SetLayerType(layerType)
if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
} }
} }
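The dispatch in this loop encodes two rules: self-attention layers always run, while a cross-attention layer runs only when it has something to attend to, i.e. fresh vision states in this batch or encoder K/V cached from an earlier image (EncoderCached); otherwise the layer is skipped and hiddenState passes through unchanged. Since cross-attention K/V are written to the encoder cache once per image (see TextCrossAttention above), the vision encoder's output is reused on every later text-only step. A hedged restatement of the condition as a standalone helper (illustrative only):

func shouldRunLayer(layerType int, crossAttentionStates ml.Tensor, encoderCached bool) bool {
	if layerType == selfAttentionLayer {
		return true // text layers always run
	}
	// Cross-attention layers need fresh vision states or previously
	// cached encoder K/V; with neither, the layer is skipped.
	return crossAttentionStates != nil || encoderCached
}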
...@@ -189,7 +205,7 @@ type TextModel struct { ...@@ -189,7 +205,7 @@ type TextModel struct {
*TextModelOptions *TextModelOptions
} }
func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache) ml.Tensor { func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs) hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions) hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
......
package runner package common
import ( import (
"strings" "strings"
) )
func findStop(sequence string, stops []string) (bool, string) { func FindStop(sequence string, stops []string) (bool, string) {
for _, stop := range stops { for _, stop := range stops {
if strings.Contains(sequence, stop) { if strings.Contains(sequence, stop) {
return true, stop return true, stop
...@@ -14,7 +14,7 @@ func findStop(sequence string, stops []string) (bool, string) { ...@@ -14,7 +14,7 @@ func findStop(sequence string, stops []string) (bool, string) {
return false, "" return false, ""
} }
func containsStopSuffix(sequence string, stops []string) bool { func ContainsStopSuffix(sequence string, stops []string) bool {
for _, stop := range stops { for _, stop := range stops {
for i := 1; i <= len(stop); i++ { for i := 1; i <= len(stop); i++ {
if strings.HasSuffix(sequence, stop[:i]) { if strings.HasSuffix(sequence, stop[:i]) {
...@@ -29,7 +29,7 @@ func containsStopSuffix(sequence string, stops []string) bool { ...@@ -29,7 +29,7 @@ func containsStopSuffix(sequence string, stops []string) bool {
// truncateStop removes the provided stop string from pieces, // truncateStop removes the provided stop string from pieces,
// returning the partial pieces with stop removed, including truncating // returning the partial pieces with stop removed, including truncating
// the last piece if required (and signalling if this was the case) // the last piece if required (and signalling if this was the case)
func truncateStop(pieces []string, stop string) ([]string, bool) { func TruncateStop(pieces []string, stop string) ([]string, bool) {
joined := strings.Join(pieces, "") joined := strings.Join(pieces, "")
index := strings.Index(joined, stop) index := strings.Index(joined, stop)
...@@ -65,7 +65,7 @@ func truncateStop(pieces []string, stop string) ([]string, bool) { ...@@ -65,7 +65,7 @@ func truncateStop(pieces []string, stop string) ([]string, bool) {
return result, tokenTruncated return result, tokenTruncated
} }
func incompleteUnicode(token string) bool { func IncompleteUnicode(token string) bool {
incomplete := false incomplete := false
// check if there is incomplete UTF-8 character at the end // check if there is incomplete UTF-8 character at the end
......
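With the move from package runner to a shared common package, these helpers are exported so both the llama runner and the new Ollama-engine runner can use them. A usage sketch, assuming the package lives at github.com/ollama/ollama/runner/common (the import path is an assumption; the signatures come from the diff above):

package main

import (
	"fmt"

	"github.com/ollama/ollama/runner/common"
)

func main() {
	stops := []string{"<stop>"}

	if found, stop := common.FindStop("Hello world<stop>", stops); found {
		pieces := []string{"Hello", " wor", "ld<st", "op>"}
		trimmed, truncated := common.TruncateStop(pieces, stop)
		// trimmed holds the pieces with the stop string removed;
		// truncated reports whether the last piece was cut mid-token.
		fmt.Println(trimmed, truncated)
	}

	// ContainsStopSuffix reports whether the sequence could still grow into
	// a stop string, so the runner holds those bytes back from streaming.
	fmt.Println(common.ContainsStopSuffix("Hello <st", stops)) // true
}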
package runner package common
import ( import (
"reflect" "reflect"
...@@ -52,7 +52,7 @@ func TestTruncateStop(t *testing.T) { ...@@ -52,7 +52,7 @@ func TestTruncateStop(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result, resultTrunc := truncateStop(tt.pieces, tt.stop) result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc { if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc) t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
} }
...@@ -120,7 +120,7 @@ func TestIncompleteUnicode(t *testing.T) { ...@@ -120,7 +120,7 @@ func TestIncompleteUnicode(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result := incompleteUnicode(tt.input) result := IncompleteUnicode(tt.input)
if result != tt.expected { if result != tt.expected {
t.Errorf("incompleteUnicode(%s): have %v; want %v", tt.input, result, tt.expected) t.Errorf("incompleteUnicode(%s): have %v; want %v", tt.input, result, tt.expected)
} }
......