Commit 2c40c4d3 authored by Jesse Gross's avatar Jesse Gross Committed by Michael Yang
Browse files

Fix follow up images and images split across batches

parent e9527893
...@@ -5,7 +5,6 @@ import ( ...@@ -5,7 +5,6 @@ import (
"encoding/binary" "encoding/binary"
"hash/fnv" "hash/fnv"
"image" "image"
"slices"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
...@@ -99,49 +98,43 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er ...@@ -99,49 +98,43 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return visionOutputs, nil return visionOutputs, nil
} }
type imageToken struct {
embedding ml.Tensor
index int
}
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) { func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
var images []input.Input var result []input.Input
fnvHash := fnv.New64a() fnvHash := fnv.New64a()
for i := range inputs { for _, inp := range inputs {
if inputs[i].Multimodal == nil { if inp.Multimodal == nil {
for j := range images { result = append(result, inp)
if j == 0 {
inputs[i].Multimodal = images[j].Multimodal
inputs[i].MultimodalHash = images[j].MultimodalHash
} else {
inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
binary.Write(fnvHash, binary.NativeEndian, images[j].MultimodalHash)
inputs[i].MultimodalHash = fnvHash.Sum64()
}
}
images = nil
} else { } else {
images = append(images, inputs[i])
inputs[i].Token = -1
}
}
for i := range inputs {
if inputs[i].Token == -1 {
imageInputs := []input.Input{ imageInputs := []input.Input{
{Token: 108}, // "\n\n" {Token: 108}, // "\n\n"
{Token: 255999}, // "<start_of_image>"" {Token: 255999}, // "<start_of_image>""
} }
result = append(result, imageInputs...)
// pad inputs with placeholders for image embeddings // add image embeddings
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 0}}, 256)...) inputMultimodal := inp.Multimodal.(ml.Tensor)
// <end_of_image>
imageInputs = append(imageInputs, input.Input{Token: 256000}) for i := range inputMultimodal.Dim(1) {
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash)
fnvHash.Write([]byte{byte(i)})
inputs = append(inputs[:i], append(imageInputs, inputs[i+1:]...)...) imageToken := imageToken{embedding: inputMultimodal, index: i}
result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()})
}
// <end_of_image>
result = append(result, input.Input{Token: 256000})
} }
} }
return inputs, nil return result, nil
} }
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
...@@ -160,7 +153,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { ...@@ -160,7 +153,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
return nil, err return nil, err
} }
return m.TextModel.Forward(ctx, inputs, positions, outputs, opts.Multimodal, m.Cache), nil return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
} }
func init() { func init() {
......
...@@ -173,25 +173,54 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, ...@@ -173,25 +173,54 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual) return hiddenState.Add(ctx, residual)
} }
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, multimodal []input.MultimodalIndex, cache kvcache.Cache) ml.Tensor { func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex, positions []int32) []int32 {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs) var embedding ml.Tensor
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize))) var src, dst, length int
var except []int32
if multimodal != nil {
visionOutputs := multimodal[0].Multimodal.(ml.Tensor) for _, image := range multimodal {
offset := multimodal[0].Index - 1 - visionOutputs.Dim(1) imageToken := image.Multimodal.(imageToken)
hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1)) imageSrc := imageToken.index
imageDst := image.Index
if embedding == nil {
embedding = imageToken.embedding
src = imageSrc
dst = imageDst
length = 1
} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
src = imageSrc
dst = imageDst
length++
} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
length++
} else {
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
embedding = imageToken.embedding
src = imageSrc
dst = imageDst
length = 1
}
if causal, ok := cache.(*kvcache.WrapperCache).UnderlyingCache().(*kvcache.Causal); ok { except = append(except, positions[imageDst])
except := make([]int32, visionOutputs.Dim(1)) }
for i := 0; i < visionOutputs.Dim(1); i++ {
except[i] = int32(offset + i)
}
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except}) if embedding != nil {
} visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
} }
return except
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal, opts.Positions)
for i, layer := range m.Layers { for i, layer := range m.Layers {
// gemma alternates between the sliding window (local) and causal (global) // gemma alternates between the sliding window (local) and causal (global)
// kv cache every 6 layers // kv cache every 6 layers
...@@ -203,6 +232,10 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor ...@@ -203,6 +232,10 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
wc := cache.(*kvcache.WrapperCache) wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType) wc.SetLayerType(cacheType)
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
}
var lastLayerOutputs ml.Tensor var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 { if i == len(m.Layers)-1 {
lastLayerOutputs = outputs lastLayerOutputs = outputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment