model.go 3.41 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
3
package mllama

import (
4
5
	"bytes"
	"image"
6
	"slices"
7

8
	"github.com/ollama/ollama/fs"
Jesse Gross's avatar
Jesse Gross committed
9
	"github.com/ollama/ollama/kvcache"
Michael Yang's avatar
Michael Yang committed
10
11
12
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
13
	"github.com/ollama/ollama/model/input"
Michael Yang's avatar
Michael Yang committed
14
15
16
17
18
19
20
21
22
23
24
25
26
27
)

// Model is the mllama multimodal model. It bundles the tokenizer, the vision
// and text sub-models, the projector applied to the vision output, and the
// image preprocessor.
type Model struct {
	model.Base
	model.BytePairEncoding

	// VisionModel produces the states that EncodeMultimodal passes to Projector.
	*VisionModel `gguf:"v,vision"`
	// TextModel runs the forward pass in Forward.
	*TextModel

	// Projector is applied to the vision model output in EncodeMultimodal.
	Projector *nn.Linear `gguf:"mm.0"`

	// ImageProcessor turns a decoded image into float pixel data via ProcessImage.
	ImageProcessor
}

Jesse Gross's avatar
Jesse Gross committed
28
29
30
31
32
// Layer kinds, numbered from zero via iota.
// NOTE(review): neither constant is referenced in this file; presumably they
// tag per-layer cache types for the text model — confirm against the callers.
const (
	crossAttentionLayer = iota
	selfAttentionLayer
)

33
func New(c fs.Config) (model.Model, error) {
Jesse Gross's avatar
Jesse Gross committed
34
	m := Model{
Michael Yang's avatar
Michael Yang committed
35
36
37
38
		BytePairEncoding: model.NewBytePairEncoding(
			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
Michael Yang's avatar
Michael Yang committed
39
				Types:  c.Ints("tokenizer.ggml.token_type"),
Michael Yang's avatar
Michael Yang committed
40
				Merges: c.Strings("tokenizer.ggml.merges"),
41
				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
42
				BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
43
				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
44
45
46
47
				EOS: append(
					[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
					c.Ints("tokenizer.ggml.eos_token_ids")...,
				),
Michael Yang's avatar
Michael Yang committed
48
49
50
51
52
			},
		),
		ImageProcessor: newImageProcessor(c),
		VisionModel:    newVisionModel(c),
		TextModel:      newTextModel(c),
Jesse Gross's avatar
Jesse Gross committed
53
54
	}

55
56
57
	encoderCache := kvcache.NewEncoderCache()
	encoderCache.SetConfig(ml.CacheConfig{})
	m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
Jesse Gross's avatar
Jesse Gross committed
58
59

	return &m, nil
Michael Yang's avatar
Michael Yang committed
60
61
}

62
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
63
64
65
66
	if len(m.VisionModel.Transformer.Layers) == 0 || len(m.GlobalTransformer.Layers) == 0 {
		return nil, model.ErrNoVisionModel
	}

67
68
69
70
	image, _, err := image.Decode(bytes.NewReader(multimodalData))
	if err != nil {
		return nil, err
	}
Michael Yang's avatar
Michael Yang committed
71

72
	f32s, ratio, err := m.ImageProcessor.ProcessImage(image)
73
74
75
	if err != nil {
		return nil, err
	}
Michael Yang's avatar
Michael Yang committed
76

77
78
79
80
81
82
	if ratio.numTiles() < m.maxNumTiles {
		// Pad tiles to maxNumTiles
		f32s = slices.Grow(f32s, m.imageSize*m.imageSize*m.numChannels*m.maxNumTiles)
		f32s = f32s[:m.imageSize*m.imageSize*m.numChannels*m.maxNumTiles]
	}

83
84
	pixelValues := ctx.Input().FromFloatSlice(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles)
	aspectRatio := ctx.Input().FromIntSlice([]int32{int32(ratio.rank)}, 1)
85

Michael Yang's avatar
arange  
Michael Yang committed
86
	positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32)
87
	crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
88
89
90
	projectedOutputs := m.Projector.Forward(ctx, crossAttentionStates)

	return []input.Multimodal{{Tensor: projectedOutputs}}, nil
91
92
}

93
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
94
	for i := range inputs {
95
96
		if inputs[i].Multimodal != nil {
			inputs[i].Token = 128256 // <|image|>
Michael Yang's avatar
Michael Yang committed
97
		}
98
99
100
101
102
	}

	return inputs, nil
}

Jesse Gross's avatar
Jesse Gross committed
103
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
104
	var crossAttentionStates ml.Tensor
Jesse Gross's avatar
Jesse Gross committed
105
	if len(batch.Multimodal) > 0 {
106
		crossAttentionStates = batch.Multimodal[len(batch.Multimodal)-1].Multimodal[0].Tensor
Michael Yang's avatar
Michael Yang committed
107
108
	}

109
110
	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
Michael Yang's avatar
Michael Yang committed
111

112
	// TODO: attention mask, cross attention mask
113
	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
Michael Yang's avatar
Michael Yang committed
114
115
116
117
118
}

// init registers the New constructor under the "mllama" name.
func init() {
	model.Register("mllama", New)
}