"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "50c81df4e7bcd8210351096ee1051f7255bb8dd4"
Commit 2c40c4d3 authored by Jesse Gross's avatar Jesse Gross Committed by Michael Yang
Browse files

Fix follow up images and images split across batches

parent e9527893
...@@ -5,7 +5,6 @@ import ( ...@@ -5,7 +5,6 @@ import (
"encoding/binary" "encoding/binary"
"hash/fnv" "hash/fnv"
"image" "image"
"slices"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
...@@ -99,49 +98,43 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er ...@@ -99,49 +98,43 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return visionOutputs, nil return visionOutputs, nil
} }
type imageToken struct {
embedding ml.Tensor
index int
}
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) { func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
var images []input.Input var result []input.Input
fnvHash := fnv.New64a() fnvHash := fnv.New64a()
for i := range inputs { for _, inp := range inputs {
if inputs[i].Multimodal == nil { if inp.Multimodal == nil {
for j := range images { result = append(result, inp)
if j == 0 {
inputs[i].Multimodal = images[j].Multimodal
inputs[i].MultimodalHash = images[j].MultimodalHash
} else {
inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
binary.Write(fnvHash, binary.NativeEndian, images[j].MultimodalHash)
inputs[i].MultimodalHash = fnvHash.Sum64()
}
}
images = nil
} else { } else {
images = append(images, inputs[i])
inputs[i].Token = -1
}
}
for i := range inputs {
if inputs[i].Token == -1 {
imageInputs := []input.Input{ imageInputs := []input.Input{
{Token: 108}, // "\n\n" {Token: 108}, // "\n\n"
{Token: 255999}, // "<start_of_image>"" {Token: 255999}, // "<start_of_image>""
} }
result = append(result, imageInputs...)
// pad inputs with placeholders for image embeddings // add image embeddings
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 0}}, 256)...) inputMultimodal := inp.Multimodal.(ml.Tensor)
// <end_of_image>
imageInputs = append(imageInputs, input.Input{Token: 256000}) for i := range inputMultimodal.Dim(1) {
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash)
fnvHash.Write([]byte{byte(i)})
inputs = append(inputs[:i], append(imageInputs, inputs[i+1:]...)...) imageToken := imageToken{embedding: inputMultimodal, index: i}
result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()})
}
// <end_of_image>
result = append(result, input.Input{Token: 256000})
} }
} }
return inputs, nil return result, nil
} }
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
...@@ -160,7 +153,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { ...@@ -160,7 +153,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
return nil, err return nil, err
} }
return m.TextModel.Forward(ctx, inputs, positions, outputs, opts.Multimodal, m.Cache), nil return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
} }
func init() { func init() {
......
...@@ -173,25 +173,54 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, ...@@ -173,25 +173,54 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual) return hiddenState.Add(ctx, residual)
} }
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, multimodal []input.MultimodalIndex, cache kvcache.Cache) ml.Tensor { func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex, positions []int32) []int32 {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs) var embedding ml.Tensor
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize))) var src, dst, length int
var except []int32
if multimodal != nil {
visionOutputs := multimodal[0].Multimodal.(ml.Tensor) for _, image := range multimodal {
offset := multimodal[0].Index - 1 - visionOutputs.Dim(1) imageToken := image.Multimodal.(imageToken)
hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1)) imageSrc := imageToken.index
imageDst := image.Index
if embedding == nil {
embedding = imageToken.embedding
src = imageSrc
dst = imageDst
length = 1
} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
src = imageSrc
dst = imageDst
length++
} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
length++
} else {
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
embedding = imageToken.embedding
src = imageSrc
dst = imageDst
length = 1
}
if causal, ok := cache.(*kvcache.WrapperCache).UnderlyingCache().(*kvcache.Causal); ok { except = append(except, positions[imageDst])
except := make([]int32, visionOutputs.Dim(1)) }
for i := 0; i < visionOutputs.Dim(1); i++ {
except[i] = int32(offset + i)
}
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except}) if embedding != nil {
} visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
} }
return except
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal, opts.Positions)
for i, layer := range m.Layers { for i, layer := range m.Layers {
// gemma alternates between the sliding window (local) and causal (global) // gemma alternates between the sliding window (local) and causal (global)
// kv cache every 6 layers // kv cache every 6 layers
...@@ -203,6 +232,10 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor ...@@ -203,6 +232,10 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
wc := cache.(*kvcache.WrapperCache) wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType) wc.SetLayerType(cacheType)
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
}
var lastLayerOutputs ml.Tensor var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 { if i == len(m.Layers)-1 {
lastLayerOutputs = outputs lastLayerOutputs = outputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment