package gemma3

import (
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/rope"
	"github.com/ollama/ollama/model/input"
)

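// TextConfig holds the text-decoder hyperparameters read from GGUF metadata:
// attention geometry, RoPE bases and scaling, the sliding-window layout, and
// final logit softcapping.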
type TextConfig struct {
	hiddenSize, contextLength, numHeads, numKVHeads int
	attnKeyLen, attnValLen                          int
	eps, ropeScale                                  float32
	ropeLocalBase                                   float32
	largeModelScaling                               bool
	slidingWindow                                   uint32
	slidingWindowPattern                            []bool
	ropeBase                                        float32
	ropeType                                        string
	ropeOriginalContext                             int
	ropeExtrapolation                               float32
	ropeBetaFast                                    float32
	ropeBetaSlow                                    float32
	finalLogitSoftcap                               float32
}

func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base, scale float32) ml.Tensor {
	ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
	if o.ropeType == "yarn" {
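		// YaRN attention factor: 1 / (1 + 0.1*ln(scale)). As an illustration
		// (not a value taken from any model config), a scale of 8 gives
		// roughly 1/(1 + 0.1*ln 8) ≈ 0.83.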
		attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(scale))))
		ropeOpts = append(ropeOpts,
			rope.WithOriginalContextLength(o.ropeOriginalContext),
			rope.WithExtrapolationFactor(o.ropeExtrapolation),
			rope.WithAttentionFactor(attnFactor),
			rope.WithBetaFast(o.ropeBetaFast),
			rope.WithBetaSlow(o.ropeBetaSlow),
		)
	}

	return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./scale, ropeOpts...)
}

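// TextModel is the Gemma 3 text decoder: token embeddings, a stack of decoder
// blocks, a final RMS norm, and an output projection that falls back to the
// tied token-embedding weights when no separate output tensor is present.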
type TextModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []TextLayer   `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*TextConfig
}

const (
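	// Every gemmaGlobalCacheCount-th layer uses the global (causal) cache and
	// the global rope base; all other layers use sliding-window attention with
	// the local rope base. The layer counts below identify the Gemma 3 model
	// sizes so newTextModel can apply per-size corrections.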
	gemmaGlobalCacheCount = 6
	gemma1BLayerCount     = 26
	gemma4BLayerCount     = 34
	gemma12BLayerCount    = 48
	gemma27BLayerCount    = 62
)

const (
	cacheTypeSWA = iota
	cacheTypeCausal
)

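// newTextModel builds the text decoder from GGUF metadata, filling in defaults
// and applying compatibility corrections for older Gemma 3 conversions.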
func newTextModel(c fs.Config) *TextModel {
	numBlocks := int(c.Uint("block_count"))

	m := TextModel{
		Layers: make([]TextLayer, numBlocks),
		TextConfig: &TextConfig{
			hiddenSize:           int(c.Uint("embedding_length")),
			contextLength:        int(c.Uint("context_length")),
			numHeads:             int(c.Uint("attention.head_count")),
			numKVHeads:           int(c.Uint("attention.head_count_kv")),
			attnKeyLen:           int(c.Uint("attention.key_length", 256)),
			attnValLen:           int(c.Uint("attention.value_length", 256)),
			eps:                  c.Float("attention.layer_norm_rms_epsilon", 1e-06),
			ropeLocalBase:        c.Float("rope.local.freq_base", 10000.0),
			ropeBase:             c.Float("rope.freq_base", 1000000.0),
			slidingWindow:        c.Uint("attention.sliding_window"),
			slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
			ropeType:             c.String("rope.scaling.type"),
			ropeOriginalContext:  int(c.Uint("rope.scaling.original_context_length")),
			ropeExtrapolation:    c.Float("rope.scaling.extrapolation_factor", 1.0),
			ropeBetaFast:         c.Float("rope.scaling.beta_fast", 64.0),
			ropeBetaSlow:         c.Float("rope.scaling.beta_slow", 1.0),
			ropeScale:            c.Float("rope.scaling.factor", 1.0),
			finalLogitSoftcap:    c.Float("final_logit_softcapping", 0.0),
		},
	}

	// Apply corrections for older conversions of the Gemma 3 models,
	// identified by whether they use sliding-window attention and by
	// their layer counts.
	if m.TextConfig.slidingWindow < uint32(m.TextConfig.contextLength) {
		switch numBlocks {
		case gemma1BLayerCount:
			// The 1B model has final logit softcapping set to 30.0
			// but it should be 0.0
			m.TextConfig.finalLogitSoftcap = 0.0
		case gemma4BLayerCount, gemma12BLayerCount, gemma27BLayerCount:
			// The 4B, 12B, and 27B models have rope scale unset
			// but it should be set to 8.0
			m.TextConfig.ropeScale = 8.0
		}
	}

	if numBlocks == gemma27BLayerCount {
		m.largeModelScaling = true
	}

	return &m
}

type TextSelfAttention struct {
	Query     *nn.Linear  `gguf:"attn_q"`
	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
	Key       *nn.Linear  `gguf:"attn_k"`
	KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
	Value     *nn.Linear  `gguf:"attn_v"`
	Output    *nn.Linear  `gguf:"attn_output"`
}

func (opts *TextConfig) ropeValuesForLayer(layer int) (base float32, scale float32) {
	if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] {
		return opts.ropeLocalBase, 1.0
	}

	// In standard Gemma 3 (no explicit sliding-window pattern), only every
	// n-th layer is global, where n = gemmaGlobalCacheCount; all other
	// layers use the local rope base.
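	// With gemmaGlobalCacheCount = 6, for example, layers 5, 11, 17, ...
	// (0-indexed) use the global base and scale, while the rest get
	// ropeLocalBase with a scale of 1.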
	if (layer+1)%gemmaGlobalCacheCount > 0 {
		return opts.ropeLocalBase, 1.0
	}

	// default to global rope base
	return opts.ropeBase, opts.ropeScale
}

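// Forward computes self-attention for a single layer: Q/K/V projections with
// RMS-normed queries and keys, per-layer rotary embeddings (local or global
// base), query scaling, and cached attention followed by the output projection.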
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	batchSize := hiddenState.Dim(1)

	ropeBase, ropeScale := opts.ropeValuesForLayer(layer)

	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
	q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase, ropeScale)

	if opts.largeModelScaling {
		// The 27B model scales queries by the model head dimension
		// (hiddenSize/numHeads) rather than by attnKeyLen.
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
	} else {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
	}

	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
	k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase, ropeScale)

	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

	// Queries were already scaled above, so attention uses a scale factor of 1.
	scaleFactor := 1.0
	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
	kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)

	return sa.Output.Forward(ctx, kqv)
}

// Shift re-applies rotary position embeddings to cached keys for the given
// position shift, using the same per-layer rope base and scale as Forward.
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	ropeBase, ropeScale := m.TextConfig.ropeValuesForLayer(layer)
	return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase, ropeScale), nil
}

type TextMLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

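// Forward applies the gated feed-forward block: a GELU-gated combination of
// the gate and up projections, followed by the down projection.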
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

type TextLayer struct {
	AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention     *TextSelfAttention
	PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
	MLPNorm           *nn.RMSNorm `gguf:"ffn_norm"`
	MLP               *TextMLP
	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
}

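// Forward runs one decoder block: pre-norm self-attention with a residual
// connection, then a pre/post-normed MLP with a second residual. When outputs
// is non-nil (the final layer), rows are pruned to just the positions that
// need logits.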
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
	return hiddenState.Add(ctx, residual)
}

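// Forward embeds the batch tokens, scales them by sqrt(hiddenSize), splices in
// any vision embeddings, runs each decoder layer against the appropriate
// sliding-window or global cache, and returns the output-normalized hidden
// states.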
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
	positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))

	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
	// Gemma scales token embeddings by sqrt(hiddenSize).
	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))

	// Splice image embeddings into the hidden states at their token positions.
	var except []int
	for _, image := range batch.Multimodal {
		visionOutputs := image.Multimodal[0].Tensor
		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))

		for i := range visionOutputs.Dim(1) {
			except = append(except, image.Index+i)
		}
	}

	for i, layer := range m.Layers {
		// Gemma 3 alternates between the sliding-window (local) and causal
		// (global) KV caches: every gemmaGlobalCacheCount-th layer is global.
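		// Image-embedding positions collected in except above are passed to the
		// causal cache as exceptions, so those positions are not causally masked.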
		if cache != nil {
			cacheType := cacheTypeSWA
			if (i+1)%gemmaGlobalCacheCount == 0 {
				cacheType = cacheTypeCausal
			}
			cache.SetLayer(i)
			wc := cache.(*kvcache.WrapperCache)
			wc.SetLayerType(cacheType)

			if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
				causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
			}
		}

		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = batch.Outputs
		}

		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
	}

	return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
}