model.go 2.67 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
3
package mllama

import (
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
)

// Model is the mllama multimodal model: a vision encoder whose projected
// outputs are fed to a text decoder as cross-attention states.
type Model struct {
	model.Base
	model.BytePairEncoding

	// Vision tower; tensors load from the "v" (vision) prefix in the GGUF file.
	*VisionModel `gguf:"v,vision"`
	*TextModel

	// Projector maps vision-encoder outputs into the text model's hidden
	// space (the "mm.0" multimodal projection tensor).
	Projector *nn.Linear `gguf:"mm.0"`

	ImageProcessor
}

Jesse Gross's avatar
Jesse Gross committed
22
23
24
25
26
// Layer kinds used to pick the cache type per transformer layer. The
// ordering matches the wrapper cache built in New: encoder cache first
// (cross-attention), causal cache second (self-attention).
const (
	crossAttentionLayer = 0 // attends over vision encoder states
	selfAttentionLayer  = 1 // attends over the token sequence
)

Michael Yang's avatar
Michael Yang committed
27
func New(c ml.Config) (model.Model, error) {
Jesse Gross's avatar
Jesse Gross committed
28
	m := Model{
Michael Yang's avatar
Michael Yang committed
29
30
31
32
33
34
		BytePairEncoding: model.NewBytePairEncoding(
			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Types:  c.Uints("tokenizer.ggml.token_type"),
				Merges: c.Strings("tokenizer.ggml.merges"),
35
36
				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
Michael Yang's avatar
Michael Yang committed
37
38
39
40
41
			},
		),
		ImageProcessor: newImageProcessor(c),
		VisionModel:    newVisionModel(c),
		TextModel:      newTextModel(c),
Jesse Gross's avatar
Jesse Gross committed
42
43
44
45
46
	}

	m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift))

	return &m, nil
Michael Yang's avatar
Michael Yang committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
}

func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
	var crossAttentionStates ml.Tensor
	if opts.Images != nil {
		f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0])
		if err != nil {
			return nil, err
		}

		pixelValues, err := ctx.FromFloatSlice(f32s,
			m.ImageProcessor.imageSize,
			m.ImageProcessor.imageSize,
			m.ImageProcessor.numChannels,
			m.ImageProcessor.maxNumTiles,
		)
		if err != nil {
			return nil, err
		}

		aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
		if err != nil {
			return nil, err
		}

		positions := make([]int32, 1601)
		for i := range positions {
			positions[i] = int32(i)
		}

		positionIDs, err := ctx.FromIntSlice(positions, len(positions))
		if err != nil {
			return nil, err
		}

		crossAttentionStates = m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
		crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates)
	}

Jesse Gross's avatar
Jesse Gross committed
86
	inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
Michael Yang's avatar
Michael Yang committed
87
88
89
90
	if err != nil {
		return nil, err
	}

Jesse Gross's avatar
Jesse Gross committed
91
	positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
Michael Yang's avatar
Michael Yang committed
92
93
94
95
96
	if err != nil {
		return nil, err
	}

	// TODO: attention mask, cross attention mask
Jesse Gross's avatar
Jesse Gross committed
97
	hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache))
Michael Yang's avatar
Michael Yang committed
98

Jesse Gross's avatar
Jesse Gross committed
99
	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
Michael Yang's avatar
Michael Yang committed
100
101
102
103
104
105
106
107
108
109
	if err != nil {
		return nil, err
	}

	return hiddenState.Rows(ctx, outputs), nil
}

// init registers the mllama architecture under the name "mllama" so the
// model registry can construct it via New.
func init() {
	model.Register("mllama", New)
}