model.go 4.37 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
3
package mllama

import (
4
5
	"bytes"
	"encoding/binary"
6
	"fmt"
7
8
9
	"hash/fnv"
	"image"
	"slices"
10

11
	"github.com/ollama/ollama/fs"
Jesse Gross's avatar
Jesse Gross committed
12
	"github.com/ollama/ollama/kvcache"
Michael Yang's avatar
Michael Yang committed
13
14
15
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
16
	"github.com/ollama/ollama/model/input"
Michael Yang's avatar
Michael Yang committed
17
18
19
20
21
22
23
24
25
26
27
28
29
30
)

// Model is the mllama multimodal model. It combines a byte-pair tokenizer, a
// vision model, a text model, a linear projector and an image preprocessor.
type Model struct {
	model.Base
	model.BytePairEncoding

	// NOTE(review): the gguf tag presumably selects the tensor-name prefix
	// ("v", with "vision" as an alternative) — confirm against the loader.
	*VisionModel `gguf:"v,vision"`
	*TextModel

	// Projector is applied to the vision model's output in EncodeMultimodal.
	Projector *nn.Linear `gguf:"mm.0"`

	// ImageProcessor turns a decoded image into the float data and aspect
	// ratio ID consumed by EncodeMultimodal.
	ImageProcessor
}

Jesse Gross's avatar
Jesse Gross committed
31
32
33
34
35
// Identifiers for the two kinds of attention layer. The consumers of these
// values are not in this file.
const (
	crossAttentionLayer = 0
	selfAttentionLayer  = 1
)

36
func New(c fs.Config) (model.Model, error) {
37
38
39
40
	// Verify unified config
	if c.Uint("vision.block_count") == 0 {
		return nil, fmt.Errorf("non-unified vision model not supported")
	}
Jesse Gross's avatar
Jesse Gross committed
41
	m := Model{
Michael Yang's avatar
Michael Yang committed
42
43
44
45
46
47
		BytePairEncoding: model.NewBytePairEncoding(
			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Types:  c.Uints("tokenizer.ggml.token_type"),
				Merges: c.Strings("tokenizer.ggml.merges"),
48
				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
49
				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
50
				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
51
				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
Michael Yang's avatar
Michael Yang committed
52
53
54
55
56
			},
		),
		ImageProcessor: newImageProcessor(c),
		VisionModel:    newVisionModel(c),
		TextModel:      newTextModel(c),
Jesse Gross's avatar
Jesse Gross committed
57
58
	}

59
60
61
	encoderCache := kvcache.NewEncoderCache()
	encoderCache.SetConfig(ml.CacheConfig{})
	m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
Jesse Gross's avatar
Jesse Gross committed
62
63

	return &m, nil
Michael Yang's avatar
Michael Yang committed
64
65
}

66
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
67
68
69
70
	if len(m.VisionModel.Transformer.Layers) == 0 || len(m.GlobalTransformer.Layers) == 0 {
		return nil, model.ErrNoVisionModel
	}

71
72
73
74
	image, _, err := image.Decode(bytes.NewReader(multimodalData))
	if err != nil {
		return nil, err
	}
Michael Yang's avatar
Michael Yang committed
75

76
77
78
79
	f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(image)
	if err != nil {
		return nil, err
	}
Michael Yang's avatar
Michael Yang committed
80

81
	pixelValues, err := ctx.Input().FromFloatSlice(f32s,
82
83
84
85
86
87
88
89
		m.ImageProcessor.imageSize,
		m.ImageProcessor.imageSize,
		m.ImageProcessor.numChannels,
		m.ImageProcessor.maxNumTiles,
	)
	if err != nil {
		return nil, err
	}
Michael Yang's avatar
Michael Yang committed
90

91
	aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(aspectRatioID)}, 1)
92
93
94
95
96
97
98
99
100
	if err != nil {
		return nil, err
	}

	positions := make([]int32, 1601)
	for i := range positions {
		positions[i] = int32(i)
	}

101
	positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
102
103
104
	if err != nil {
		return nil, err
	}
Michael Yang's avatar
Michael Yang committed
105

106
107
108
109
	crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
	return m.Projector.Forward(ctx, crossAttentionStates), nil
}

110
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
111
	var images []input.Input
112
113
114
115
116
	fnvHash := fnv.New64a()

	for i := range inputs {
		if inputs[i].Multimodal == nil {
			if len(images) > 0 {
117
				inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)}
118
119
				inputs[i].MultimodalHash = images[0].MultimodalHash
				for j := 1; j < len(images); j++ {
120
					inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[0].Multimodal.(ml.Tensor))
121
122
123
124
125
126
127
128
129
130
					fnvHash.Reset()
					binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
					binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
					inputs[i].MultimodalHash = fnvHash.Sum64()
				}
				images = nil
			}
		} else {
			images = append(images, inputs[i])
			inputs[i].Token = -1
Michael Yang's avatar
Michael Yang committed
131
		}
132
133
	}

134
	inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })
Michael Yang's avatar
Michael Yang committed
135

136
137
138
	return inputs, nil
}

Jesse Gross's avatar
Jesse Gross committed
139
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
140
	var crossAttentionStates ml.Tensor
Jesse Gross's avatar
Jesse Gross committed
141
142
	if len(batch.Multimodal) > 0 {
		images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
143
144
145
		if len(images) > 0 {
			crossAttentionStates = images[len(images)-1]
		}
Michael Yang's avatar
Michael Yang committed
146
147
	}

Jesse Gross's avatar
Jesse Gross committed
148
	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
Michael Yang's avatar
Michael Yang committed
149
150
151
152
	if err != nil {
		return nil, err
	}

Jesse Gross's avatar
Jesse Gross committed
153
	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
Michael Yang's avatar
Michael Yang committed
154
155
156
157
	if err != nil {
		return nil, err
	}

158
	// TODO: attention mask, cross attention mask
159
	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
Michael Yang's avatar
Michael Yang committed
160
161
162
163
164
}

// init registers New as the constructor for the "mllama" architecture.
func init() {
	const architecture = "mllama"
	model.Register(architecture, New)
}