package gemma3

import (
	"math"
	"slices"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/rope"
	"github.com/ollama/ollama/model/input"
)

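// TextConfig holds the text-model hyperparameters read from the GGUF
// metadata: attention shape, RMS-norm epsilon, RoPE frequency bases and
// scaling parameters, the sliding-window layer pattern, and the final
// logit softcap.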
type TextConfig struct {
	hiddenSize, numHeads, numKVHeads int
	attnKeyLen, attnValLen           int
	eps, ropeScale                   float32
	ropeLocalBase                    float32
	largeModelScaling                bool
	slidingWindowPattern             []bool
	ropeBase                         float32
	ropeType                         string
	ropeOriginalContext              int
	ropeExtrapolation                float32
	ropeBetaFast                     float32
	ropeBetaSlow                     float32
	finalLogitSoftcap                float32
}

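// applyRotaryPositionEmbeddings applies rotary position embeddings with the
// NeoX layout to states. When the model uses YaRN scaling (rope.scaling.type
// is "yarn"), the YaRN context-extension options and attention factor are
// added as well.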
func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor {
	ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
	if o.ropeType == "yarn" {
		attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
		ropeOpts = append(ropeOpts,
			rope.WithOriginalContextLength(o.ropeOriginalContext),
			rope.WithExtrapolationFactor(o.ropeExtrapolation),
			rope.WithAttentionFactor(attnFactor),
			rope.WithBetaFast(o.ropeBetaFast),
			rope.WithBetaSlow(o.ropeBetaSlow),
		)
	}

	return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, ropeOpts...)
}

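// TextModel is the Gemma 3 text decoder: token embeddings, transformer
// blocks, a final RMS norm, and an output projection that falls back to the
// tied token embeddings when no separate output tensor is present.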
type TextModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []TextLayer   `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*TextConfig
}

const (
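	// every gemmaGlobalCacheCount-th layer uses the global (causal) KV cache
	// and the global rope base; all other layers use sliding-window attention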
	gemmaGlobalCacheCount = 6
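	// the 62-layer (27B) variant scales queries by the head dimension
	// (hiddenSize/numHeads) instead of attnKeyLen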
	gemma27BLayerCount    = 62
)

const (
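	// layer types for the kvcache.WrapperCache: cacheTypeSWA selects the
	// sliding-window cache, cacheTypeCausal the global causal cache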
	cacheTypeSWA = iota
	cacheTypeCausal
)

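// newTextModel builds a TextModel from GGUF metadata, applying Gemma 3
// defaults for any keys that are missing.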
func newTextModel(c fs.Config) *TextModel {
	numBlocks := int(c.Uint("block_count"))

	m := TextModel{
		Layers: make([]TextLayer, numBlocks),
		TextConfig: &TextConfig{
			hiddenSize:           int(c.Uint("embedding_length")),
			numHeads:             int(c.Uint("attention.head_count")),
			numKVHeads:           int(c.Uint("attention.head_count_kv")),
			attnKeyLen:           int(c.Uint("attention.key_length", 256)),
			attnValLen:           int(c.Uint("attention.value_length", 256)),
			eps:                  c.Float("attention.layer_norm_rms_epsilon", 1e-06),
			ropeLocalBase:        c.Float("rope.local.freq_base", 10000.0),
			ropeBase:             c.Float("rope.freq_base", 1000000.0),
			slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
			ropeType:             c.String("rope.scaling.type"),
			ropeOriginalContext:  int(c.Uint("rope.scaling.original_context_length")),
			ropeExtrapolation:    c.Float("rope.scaling.extrapolation_factor", 1.0),
			ropeBetaFast:         c.Float("rope.scaling.beta_fast", 64.0),
			ropeBetaSlow:         c.Float("rope.scaling.beta_slow", 1.0),
			ropeScale:            c.Float("rope.scaling.factor", 1.0),
			finalLogitSoftcap:    c.Float("final_logit_softcapping", 0.0),
		},
	}

	// Google's Gemma 3 release with sliding window attention does
	// not use final logit softcapping, so we force it to 0.0 here.
	// The QAT weights for Gemma 3 also included an incorrect
	// value for the rope scale, so we need to set it to 1.0 here.
	// TODO (jmorganca): this should ideally be set to 0.0 in the
	// model configuration instead of here, as future versions of
	// models may include both sliding window attention and final
	// logit softcapping.
	if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
		m.TextConfig.finalLogitSoftcap = 0.0
		m.TextConfig.ropeScale = 1.0
	}

	if numBlocks == gemma27BLayerCount {
		m.largeModelScaling = true
	}

	return &m
}

type TextSelfAttention struct {
	Query     *nn.Linear  `gguf:"attn_q"`
	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
	Key       *nn.Linear  `gguf:"attn_k"`
	KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
	Value     *nn.Linear  `gguf:"attn_v"`
	Output    *nn.Linear  `gguf:"attn_output"`
}

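// ropeBaseForLayer returns the RoPE frequency base for a layer: the local
// base for sliding-window layers and the global base for full-attention
// layers.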
func (opts *TextConfig) ropeBaseForLayer(layer int) float32 {
	if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] {
		return opts.ropeLocalBase
	}

	// Standard Gemma 3: only every n-th layer (n = gemmaGlobalCacheCount)
	// is global; all other layers use the local rope base
	if (layer+1)%gemmaGlobalCacheCount > 0 {
		return opts.ropeLocalBase
	}

	// default to global rope base
	return opts.ropeBase
}

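// Forward computes self-attention for one layer: project and normalize the
// queries and keys, apply RoPE with the per-layer frequency base, scale the
// queries, and run attention against the layer's KV cache.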
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	batchSize := hiddenState.Dim(1)

	ropeBase := opts.ropeBaseForLayer(layer)

	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
	q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase)

	if opts.largeModelScaling {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
	} else {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
	}

	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
	k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase)

	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

	scaleFactor := 1.0
	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
	kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)

	return sa.Output.Forward(ctx, kqv)
}

func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
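	// rotate the cached keys by the position shift, using the same per-layer
	// frequency base as the forward pass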
	return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.TextConfig.ropeBaseForLayer(layer)), nil
}

type TextMLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

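// Forward applies the gated GELU feed-forward block: the gate and up
// projections are combined through GELU, then mapped back down by the down
// projection.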
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

type TextLayer struct {
	AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention     *TextSelfAttention
	PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
	MLPNorm           *nn.RMSNorm `gguf:"ffn_norm"`
	MLP               *TextMLP
	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
}

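// Forward runs one transformer block: pre-norm self-attention with a
// post-attention norm, then a pre-norm MLP with a post-MLP norm, each
// wrapped in a residual connection.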
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
	return hiddenState.Add(ctx, residual)
}

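// Forward embeds and scales the batch tokens, copies any image embeddings
// into their token positions, runs every transformer layer against the
// appropriate KV cache, and returns the output-normalized hidden states.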
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
	positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))

	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))

	// copy image embeddings into their token positions and record those
	// positions so the causal mask can except them below
	var except []int
	for _, image := range batch.Multimodal {
		visionOutputs := image.Multimodal[0].Tensor
		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))

		for i := range visionOutputs.Dim(1) {
			except = append(except, image.Index+i)
		}
	}

	for i, layer := range m.Layers {
		// gemma uses the sliding window (local) kv cache for most layers and
		// the causal (global) kv cache for every gemmaGlobalCacheCount-th layer
		if cache != nil {
			cacheType := cacheTypeSWA
			if (i+1)%gemmaGlobalCacheCount == 0 {
				cacheType = cacheTypeCausal
			}
			cache.SetLayer(i)
			wc := cache.(*kvcache.WrapperCache)
			wc.SetLayerType(cacheType)

			if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
				causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
			}
		}

		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = batch.Outputs
Jesse Gross's avatar
Jesse Gross committed
251
252
		}

		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
	}

	return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
}