model.go 4.44 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
3
package mllama

import (
4
5
	"bytes"
	"encoding/binary"
6
	"fmt"
7
8
9
	"hash/fnv"
	"image"
	"slices"
10

Jesse Gross's avatar
Jesse Gross committed
11
	"github.com/ollama/ollama/kvcache"
Michael Yang's avatar
Michael Yang committed
12
13
14
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
15
	"github.com/ollama/ollama/model/input"
Michael Yang's avatar
Michael Yang committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
)

// Model is the "mllama" multimodal model (registered in init). It combines
// an image preprocessor, a vision encoder, a linear projector and a text
// decoder, plus a byte-pair-encoding tokenizer.
type Model struct {
	model.Base
	model.BytePairEncoding

	// Vision encoder. NOTE(review): the gguf tag is interpreted by the model
	// loader (presumably tensor-name prefixes "v" / "vision") — confirm there.
	*VisionModel `gguf:"v,vision"`
	*TextModel

	// Projector is applied to the vision encoder output in EncodeMultimodal.
	// Its result is later handed to TextModel.Forward as the cross-attention
	// states.
	Projector *nn.Linear `gguf:"mm.0"`

	ImageProcessor
}

Jesse Gross's avatar
Jesse Gross committed
30
31
32
33
34
// Layer-kind identifiers distinguishing cross-attention layers from
// self-attention layers. They are not referenced in this file; presumably the
// text model uses them when building its layers. The numeric values (0 and 1)
// match the original iota assignment and are kept explicit.
const (
	crossAttentionLayer = 0
	selfAttentionLayer  = 1
)

Michael Yang's avatar
Michael Yang committed
35
func New(c ml.Config) (model.Model, error) {
36
37
38
39
	// Verify unified config
	if c.Uint("vision.block_count") == 0 {
		return nil, fmt.Errorf("non-unified vision model not supported")
	}
Jesse Gross's avatar
Jesse Gross committed
40
	m := Model{
Michael Yang's avatar
Michael Yang committed
41
42
43
44
45
46
		BytePairEncoding: model.NewBytePairEncoding(
			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Types:  c.Uints("tokenizer.ggml.token_type"),
				Merges: c.Strings("tokenizer.ggml.merges"),
47
				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
48
				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
49
				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
50
				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
Michael Yang's avatar
Michael Yang committed
51
52
53
54
55
			},
		),
		ImageProcessor: newImageProcessor(c),
		VisionModel:    newVisionModel(c),
		TextModel:      newTextModel(c),
Jesse Gross's avatar
Jesse Gross committed
56
57
	}

58
59
60
	encoderCache := kvcache.NewEncoderCache()
	encoderCache.SetConfig(ml.CacheConfig{})
	m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
Jesse Gross's avatar
Jesse Gross committed
61
62

	return &m, nil
Michael Yang's avatar
Michael Yang committed
63
64
}

65
// EncodeMultimodal turns one encoded image into projected cross-attention
// states.
//
// The bytes are decoded with image.Decode, so the format must be registered
// with the image package. The image is then preprocessed by ImageProcessor
// into tiles plus an aspect-ratio id, run through the vision encoder, and
// finally through Projector. The result is returned as `any`; PostTokenize
// later asserts it to ml.Tensor.
//
// It returns model.ErrNoVisionModel when the vision transformer or global
// transformer has no layers. Decode and preprocessing failures are wrapped
// with context; the original error stays reachable via errors.Is/As.
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
	if len(m.VisionModel.Transformer.Layers) == 0 || len(m.GlobalTransformer.Layers) == 0 {
		return nil, model.ErrNoVisionModel
	}

	// Named img (not image) so the image package is not shadowed.
	img, _, err := image.Decode(bytes.NewReader(multimodalData))
	if err != nil {
		return nil, fmt.Errorf("decoding image: %w", err)
	}

	f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(img)
	if err != nil {
		return nil, fmt.Errorf("processing image: %w", err)
	}

	pixelValues, err := ctx.Input().FromFloatSlice(f32s,
		m.ImageProcessor.imageSize,
		m.ImageProcessor.imageSize,
		m.ImageProcessor.numChannels,
		m.ImageProcessor.maxNumTiles,
	)
	if err != nil {
		return nil, err
	}

	aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(aspectRatioID)}, 1)
	if err != nil {
		return nil, err
	}

	// Position ids 0..numPositions-1 for the vision encoder.
	// NOTE(review): 1601 is presumably (imageSize/patchSize)^2 + 1 class token
	// (40*40+1 for 560px tiles with 14px patches); confirm, and consider
	// deriving it from the config instead of hard-coding it.
	const numPositions = 1601
	positions := make([]int32, numPositions)
	for i := range positions {
		positions[i] = int32(i)
	}

	positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
	if err != nil {
		return nil, err
	}

	crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
	return m.Projector.Forward(ctx, crossAttentionStates), nil
}

109
// PostTokenize folds image inputs into the text input that follows them.
//
// Every input carrying Multimodal data (the tensor produced by
// EncodeMultimodal) is buffered and removed from the returned slice. When the
// next text input is reached, two things happen:
//   - its Multimodal field is set to a []ml.Tensor holding the buffered
//     images in order;
//   - its MultimodalHash is set to the first image's hash, folded with each
//     subsequent image's hash through FNV-1a.
//
// Images that are not followed by any text input are dropped.
//
// Fix: the original attached images[0] repeatedly instead of each buffered
// image. It also mixed in inputs[j].MultimodalHash (an unrelated input)
// instead of the j-th image's hash, so different image sets could collide in
// the cache.
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
	var pending []input.Input
	hasher := fnv.New64a()

	for i := range inputs {
		if inputs[i].Multimodal != nil {
			// Buffer a copy of the image input, then mark its slot for removal.
			pending = append(pending, inputs[i])
			inputs[i].Token = -1
			continue
		}
		if len(pending) == 0 {
			continue
		}

		tensors := make([]ml.Tensor, 0, len(pending))
		combined := pending[0].MultimodalHash
		for j, img := range pending {
			tensors = append(tensors, img.Multimodal.(ml.Tensor))
			if j == 0 {
				continue
			}
			hasher.Reset()
			// hash.Hash.Write never returns an error, so these writes cannot fail.
			binary.Write(hasher, binary.NativeEndian, combined)
			binary.Write(hasher, binary.NativeEndian, img.MultimodalHash)
			combined = hasher.Sum64()
		}

		inputs[i].Multimodal = tensors
		inputs[i].MultimodalHash = combined
		pending = nil
	}

	inputs = slices.DeleteFunc(inputs, func(in input.Input) bool { return in.Token == -1 })

	return inputs, nil
}

138
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
139
	var crossAttentionStates ml.Tensor
140
	if len(opts.Multimodal) > 0 {
141
142
143
144
		images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
		if len(images) > 0 {
			crossAttentionStates = images[len(images)-1]
		}
Michael Yang's avatar
Michael Yang committed
145
146
	}

147
	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
Michael Yang's avatar
Michael Yang committed
148
149
150
151
	if err != nil {
		return nil, err
	}

152
	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
Michael Yang's avatar
Michael Yang committed
153
154
155
156
	if err != nil {
		return nil, err
	}

157
	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
Michael Yang's avatar
Michael Yang committed
158
159
160
161
	if err != nil {
		return nil, err
	}

162
163
	// TODO: attention mask, cross attention mask
	return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
Michael Yang's avatar
Michael Yang committed
164
165
166
167
168
}

// init registers New as the constructor for the "mllama" architecture.
func init() {
	const arch = "mllama"
	model.Register(arch, New)
}