model.go 2.64 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
3
package mllama

import (
Jesse Gross's avatar
Jesse Gross committed
4
	"github.com/ollama/ollama/kvcache"
Michael Yang's avatar
Michael Yang committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
)

// Model is the mllama multimodal model. In Forward, an optional image is run
// through VisionModel and then Projector, and the result is handed to
// TextModel as its cross-attention states.
type Model struct {
	// model.Base presumably supplies the Cache field assigned in New — confirm.
	model.Base
	// BytePairEncoding is the tokenizer, built in New from the
	// "tokenizer.ggml.*" config keys.
	model.BytePairEncoding

	// VisionModel encodes preprocessed image tiles.
	// NOTE(review): the "v,vision" tag is interpreted by the weight loader
	// (presumably a tensor-name prefix plus an option) — confirm there.
	*VisionModel `gguf:"v,vision"`
	// TextModel is the decoder that produces the forward-pass output.
	*TextModel

	// Projector is applied to VisionModel's output to produce the
	// cross-attention states passed to TextModel.Forward.
	Projector *nn.Linear `gguf:"mm.0"`

	// ImageProcessor turns a raw image into float pixel values plus an
	// aspect-ratio ID, and carries the tensor dimensions used in Forward.
	ImageProcessor
}

Jesse Gross's avatar
Jesse Gross committed
22
23
24
25
26
// Layer-kind identifiers. Usage is outside this view; they are presumably
// used by the text model to choose between cross-attention and
// self-attention layers. The values are written out explicitly so they do
// not depend on declaration order.
const (
	crossAttentionLayer = 0
	selfAttentionLayer  = 1
)

Michael Yang's avatar
Michael Yang committed
27
func New(c ml.Config) (model.Model, error) {
Jesse Gross's avatar
Jesse Gross committed
28
	m := Model{
Michael Yang's avatar
Michael Yang committed
29
30
31
32
33
34
		BytePairEncoding: model.NewBytePairEncoding(
			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Types:  c.Uints("tokenizer.ggml.token_type"),
				Merges: c.Strings("tokenizer.ggml.merges"),
35
36
				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
Michael Yang's avatar
Michael Yang committed
37
38
39
40
41
			},
		),
		ImageProcessor: newImageProcessor(c),
		VisionModel:    newVisionModel(c),
		TextModel:      newTextModel(c),
Jesse Gross's avatar
Jesse Gross committed
42
43
44
45
46
	}

	m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift))

	return &m, nil
Michael Yang's avatar
Michael Yang committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
}

func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
	var crossAttentionStates ml.Tensor
	if opts.Images != nil {
		f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0])
		if err != nil {
			return nil, err
		}

		pixelValues, err := ctx.FromFloatSlice(f32s,
			m.ImageProcessor.imageSize,
			m.ImageProcessor.imageSize,
			m.ImageProcessor.numChannels,
			m.ImageProcessor.maxNumTiles,
		)
		if err != nil {
			return nil, err
		}

		aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
		if err != nil {
			return nil, err
		}

		positions := make([]int32, 1601)
		for i := range positions {
			positions[i] = int32(i)
		}

		positionIDs, err := ctx.FromIntSlice(positions, len(positions))
		if err != nil {
			return nil, err
		}

		crossAttentionStates = m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
		crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates)
	}

Jesse Gross's avatar
Jesse Gross committed
86
	inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
Michael Yang's avatar
Michael Yang committed
87
88
89
90
	if err != nil {
		return nil, err
	}

Jesse Gross's avatar
Jesse Gross committed
91
	positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
Michael Yang's avatar
Michael Yang committed
92
93
94
95
	if err != nil {
		return nil, err
	}

Jesse Gross's avatar
Jesse Gross committed
96
	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
Michael Yang's avatar
Michael Yang committed
97
98
99
100
	if err != nil {
		return nil, err
	}

101
102
	// TODO: attention mask, cross attention mask
	return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
Michael Yang's avatar
Michael Yang committed
103
104
105
106
107
}

// init registers New as the constructor for the "mllama" model name with
// the model package when this package is imported.
func init() {
	model.Register("mllama", New)
}