"magic_pdf/vscode:/vscode.git/clone" did not exist on "3bd0ecf16655ee5774c7c089ec6e181d11dd8004"
model_text.go 6.77 KB
Newer Older
Patrick Devine's avatar
Patrick Devine committed
1
2
3
4
5
package gemma3

import (
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

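// TextConfig holds the text decoder hyperparameters read from the model's GGUF metadata.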
type TextConfig struct {
	hiddenSize, numHeads, numKVHeads int
	attnKeyLen, attnValLen           int
	eps, ropeScale                   float32
	ropeLocalBase, ropeGlobalBase    float32
	largeModelScaling                bool
}

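// TextModel is the Gemma 3 text decoder: a token embedding, a stack of
// transformer layers, and a final RMS norm plus output projection.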
type TextModel struct {
	model.Base
	model.SentencePieceModel

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []TextLayer   `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*TextConfig
}

const (
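	// every sixth layer attends with the global (causal) KV cache; all other
	// layers use the sliding-window cache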
	gemmaGlobalCacheCount = 6
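	// the 27B variant is identified by its layer count and uses a different query scale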
	gemma27BLayerCount    = 62
)

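// layer types within the wrapper cache: sliding-window attention (local) or
// fully causal (global)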
const (
	cacheTypeSWA = iota
	cacheTypeCausal
)

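// newTextModel builds the text decoder from GGUF metadata: the SentencePiece
// tokenizer, the layer slice, and the attention/RoPE hyperparameters
// (falling back to defaults when a key is absent).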
func newTextModel(c fs.Config) *TextModel {
	numBlocks := int(c.Uint("block_count"))

	m := TextModel{
		SentencePieceModel: model.NewSentencePieceModel(
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Scores: c.Floats("tokenizer.ggml.scores"),
				Types:  c.Uints("tokenizer.ggml.token_type"),
				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
			},
		),
		Layers: make([]TextLayer, numBlocks),
		TextConfig: &TextConfig{
			hiddenSize:     int(c.Uint("embedding_length")),
			numHeads:       int(c.Uint("attention.head_count")),
			numKVHeads:     int(c.Uint("attention.head_count_kv")),
			attnKeyLen:     int(c.Uint("attention.key_length", 256)),
			attnValLen:     int(c.Uint("attention.value_length", 256)),
			eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-06),
			ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),
			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
			ropeScale:      c.Float("rope.freq_scale", 1.0),
		},
	}

	if numBlocks == gemma27BLayerCount {
		m.largeModelScaling = true
	}

	return &m
}

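// TextSelfAttention implements grouped-query attention with RMS-normalized
// queries and keys.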
type TextSelfAttention struct {
	Query     *nn.Linear  `gguf:"attn_q"`
	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
	Key       *nn.Linear  `gguf:"attn_k"`
	KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
	Value     *nn.Linear  `gguf:"attn_v"`
	Output    *nn.Linear  `gguf:"attn_output"`
}

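// Forward projects the hidden state to queries, keys, and values, applies RoPE
// with the layer's local or global frequency base, and attends through the
// layer's KV cache.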
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	ropeType := uint32(2)

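	// every sixth layer uses the global RoPE frequency base; the rest use the
	// local (sliding-window) base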
	ropeBase := opts.ropeLocalBase
	if (layer+1)%gemmaGlobalCacheCount == 0 {
		ropeBase = opts.ropeGlobalBase
	}

	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)

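	// the 27B model scales queries by the per-head hidden dimension;
	// smaller variants scale by the attention key length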
	if opts.largeModelScaling {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
	} else {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
	}

	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)

	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

	scaleFactor := 1.0
	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
	kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)

	return sa.Output.Forward(ctx, kqv)
}

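// Shift re-applies RoPE to cached keys with the given position shift, using
// the same frequency base the layer applied originally.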
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	ropeBase := m.TextConfig.ropeLocalBase
	if (layer+1)%gemmaGlobalCacheCount == 0 {
		ropeBase = m.TextConfig.ropeGlobalBase
	}

	return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil
}

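// TextMLP is the GELU-gated feed-forward block.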
type TextMLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

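// TextLayer is a single transformer block: self-attention and an MLP, each
// wrapped in pre- and post-RMS norms with a residual connection.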
type TextLayer struct {
	AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention     *TextSelfAttention
	PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
	MLPNorm           *nn.RMSNorm `gguf:"ffn_norm"`
	MLP               *TextMLP
	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
}

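// Forward runs one transformer block over hiddenState; outputs, when non-nil,
// selects the positions to keep for the final layer.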
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
	return hiddenState.Add(ctx, residual)
}

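// Forward embeds and scales the input tokens, copies any image embeddings into
// place, runs every transformer layer against its KV cache, and projects the
// final hidden state to logits.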
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
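	// Gemma scales the token embeddings by sqrt(hiddenSize)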
	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))

	// splice the vision embeddings into the hidden state and record their
	// positions so they can be exempted from the causal mask below
	var except []int
	for _, image := range batch.Multimodal {
		visionOutputs := image.Multimodal.(ml.Tensor)
		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))

		for i := range visionOutputs.Dim(1) {
			except = append(except, image.Index+i)
		}
	}

	for i, layer := range m.Layers {
		// every sixth gemma layer uses the causal (global) kv cache; the other
		// layers use the sliding window (local) cache
		cacheType := cacheTypeSWA
		if (i+1)%gemmaGlobalCacheCount == 0 {
			cacheType = cacheTypeCausal
		}
		cache.SetLayer(i)
		wc := cache.(*kvcache.WrapperCache)
		wc.SetLayerType(cacheType)

		if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
			causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
		}

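		// only the final layer prunes the hidden state to the requested output positions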
		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = outputs
		}

		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
	}

	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return m.Output.Forward(ctx, hiddenState)
}