package mllama

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"hash/fnv"
	"image"
	"slices"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

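// Model combines a byte pair tokenizer, the mllama vision encoder, and the
// text decoder. Projector maps vision encoder output into the
// cross-attention states consumed by the text model.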
type Model struct {
	model.Base
	model.BytePairEncoding

	*VisionModel `gguf:"v,vision"`
	*TextModel

	Projector *nn.Linear `gguf:"mm.0"`

	ImageProcessor
}

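// Layer types for the wrapped KV cache: cross-attention layers use the
// encoder cache and self-attention layers the causal cache, matching the
// argument order of kvcache.NewWrapperCache in New.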
const (
	crossAttentionLayer = iota
	selfAttentionLayer
)

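// New constructs an mllama model from a unified GGUF config: tokenizer,
// image processor, vision encoder, text decoder, and a wrapped KV cache
// combining an encoder cache for images with a causal cache for text.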
func New(c ml.Config) (model.Model, error) {
	// Verify unified config
	if c.Uint("vision.block_count") == 0 {
		return nil, fmt.Errorf("non-unified vision model not supported")
	}
	m := Model{
		BytePairEncoding: model.NewBytePairEncoding(
			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Types:  c.Uints("tokenizer.ggml.token_type"),
				Merges: c.Strings("tokenizer.ggml.merges"),
				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
			},
		),
		ImageProcessor: newImageProcessor(c),
		VisionModel:    newVisionModel(c),
		TextModel:      newTextModel(c),
	}

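	// Image cross-attention states are computed once per image and reused
	// across decode steps, so they live in an encoder cache; token state
	// uses a shifting causal cache.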
	encoderCache := kvcache.NewEncoderCache()
	encoderCache.SetConfig(ml.CacheConfig{})
	m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))

	return &m, nil
}

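// EncodeMultimodal decodes an image from raw bytes, preprocesses it into
// tiled float32 pixel values, runs the vision encoder, and returns the
// projected cross-attention states.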
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
	image, _, err := image.Decode(bytes.NewReader(multimodalData))
	if err != nil {
		return nil, err
	}

	f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(image)
	if err != nil {
		return nil, err
	}

	pixelValues, err := ctx.Input().FromFloatSlice(f32s,
		m.ImageProcessor.imageSize,
		m.ImageProcessor.imageSize,
		m.ImageProcessor.numChannels,
		m.ImageProcessor.maxNumTiles,
	)
	if err != nil {
		return nil, err
	}

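	// aspectRatioID identifies the tile layout chosen during preprocessing;
	// the vision model uses it to select aspect-ratio-specific embeddings.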
	aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(aspectRatioID)}, 1)
	if err != nil {
		return nil, err
	}

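	// 1601 position IDs per tile: one per patch plus the class embedding
	// (with the standard mllama geometry, (560/14)^2 + 1 = 1601).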
	positions := make([]int32, 1601)
	for i := range positions {
		positions[i] = int32(i)
	}

	positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
	if err != nil {
		return nil, err
	}

	crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
	return m.Projector.Forward(ctx, crossAttentionStates), nil
}

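// PostTokenize attaches pending image embeddings to the first non-image
// input that follows them, concatenating consecutive images into a single
// tensor and combining their hashes, then drops the placeholder image
// inputs.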
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
	var images []input.Input
	fnvHash := fnv.New64a()

	for i := range inputs {
		if inputs[i].Multimodal == nil {
			if len(images) > 0 {
				inputs[i].Multimodal = images[0].Multimodal
				inputs[i].MultimodalHash = images[0].MultimodalHash
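				// Chain the hashes so the combined hash covers every
				// concatenated image.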
				for j := 1; j < len(images); j++ {
					inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
					fnvHash.Reset()
					binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
					binary.Write(fnvHash, binary.NativeEndian, images[j].MultimodalHash)
					inputs[i].MultimodalHash = fnvHash.Sum64()
				}
				images = nil
			}
		} else {
			images = append(images, inputs[i])
			inputs[i].Token = -1
		}
	}

	inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })

	return inputs, nil
}

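// Forward runs one decode step over the batch, passing the latest
// cross-attention states (if any) to the text model.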
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
	var crossAttentionStates ml.Tensor
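	// Cross-attention states from the most recent image entry are used;
	// earlier entries in opts.Multimodal are ignored here.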
	if len(opts.Multimodal) > 0 {
		crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
	}

	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
	if err != nil {
		return nil, err
	}

	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
	if err != nil {
		return nil, err
	}

	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
	if err != nil {
		return nil, err
	}

	// TODO: attention mask, cross attention mask
	return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
}

func init() {
	model.Register("mllama", New)
}