model_text.go

package gemma3

import (
	"math"
	"slices"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/rope"
	"github.com/ollama/ollama/model/input"
)

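// TextConfig holds the text decoder hyperparameters read from the GGUF
// metadata: attention shapes, RMS norm epsilon, rope bases and scaling
// parameters, the sliding window pattern, and the final logit softcap.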
type TextConfig struct {
	hiddenSize, numHeads, numKVHeads int
	attnKeyLen, attnValLen           int
	eps, ropeScale                   float32
	ropeLocalBase                    float32
	largeModelScaling                bool
	slidingWindowPattern             []bool
	ropeBase                         float32
	ropeType                         string
	ropeOriginalContext              int
	ropeExtrapolation                float32
	ropeBetaFast                     float32
	ropeBetaSlow                     float32
	finalLogitSoftcap                float32
}

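// applyRotaryPositionEmbeddings applies NeoX-style rotary position embeddings
// to states at the given positions. When YaRN scaling is configured
// (ropeType == "yarn"), the extended-context options (original context length,
// extrapolation factor, attention factor, beta fast/slow) are passed through
// to the rope implementation.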
func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base, scale float32) ml.Tensor {
	ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
	if o.ropeType == "yarn" {
		attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(scale))))
		ropeOpts = append(ropeOpts,
			rope.WithOriginalContextLength(o.ropeOriginalContext),
			rope.WithExtrapolationFactor(o.ropeExtrapolation),
			rope.WithAttentionFactor(attnFactor),
			rope.WithBetaFast(o.ropeBetaFast),
			rope.WithBetaSlow(o.ropeBetaSlow),
		)
	}

	return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./scale, ropeOpts...)
}

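// TextModel is the Gemma 3 text decoder: token embeddings, a stack of
// transformer layers, and the output norm plus an output projection that
// falls back to the tied token embedding weights.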
type TextModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []TextLayer   `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*TextConfig
}

const (
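	// every gemmaGlobalCacheCount-th layer attends with the global (causal)
	// kv cache; all other layers use sliding window attention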
	gemmaGlobalCacheCount = 6
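	// the 27B model is identified by its layer count and enables per-head
	// query scaling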
	gemma27BLayerCount    = 62
)

const (
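	// layer types for the wrapped kv cache: sliding window attention (local)
	// or fully causal (global)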
	cacheTypeSWA = iota
	cacheTypeCausal
)

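// newTextModel builds the text decoder from the GGUF config and applies
// Gemma 3 specific adjustments: models with sliding window attention disable
// final logit softcapping and reset the rope scale, and the 27B variant
// enables per-head query scaling.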
func newTextModel(c fs.Config) *TextModel {
	numBlocks := int(c.Uint("block_count"))

	m := TextModel{
		Layers: make([]TextLayer, numBlocks),
		TextConfig: &TextConfig{
			hiddenSize:           int(c.Uint("embedding_length")),
			numHeads:             int(c.Uint("attention.head_count")),
			numKVHeads:           int(c.Uint("attention.head_count_kv")),
			attnKeyLen:           int(c.Uint("attention.key_length", 256)),
			attnValLen:           int(c.Uint("attention.value_length", 256)),
			eps:                  c.Float("attention.layer_norm_rms_epsilon", 1e-06),
			ropeLocalBase:        c.Float("rope.local.freq_base", 10000.0),
			ropeBase:             c.Float("rope.freq_base", 1000000.0),
			slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
			ropeType:             c.String("rope.scaling.type"),
			ropeOriginalContext:  int(c.Uint("rope.scaling.original_context_length")),
			ropeExtrapolation:    c.Float("rope.scaling.extrapolation_factor", 1.0),
			ropeBetaFast:         c.Float("rope.scaling.beta_fast", 64.0),
			ropeBetaSlow:         c.Float("rope.scaling.beta_slow", 1.0),
			ropeScale:            c.Float("rope.scaling.factor", 8.0),
			finalLogitSoftcap:    c.Float("final_logit_softcapping", 0.0),
		},
	}

	// Google's Gemma 3 release with sliding window attention does not use
	// final logit softcapping, so we force it to 0.0 here.
	// The QAT weights for Gemma 3 also included an incorrect
	// value for the rope scale, so we need to set it to 1.0 here.
	// TODO (jmorganca): this should ideally be set to 0.0 in the
	// model configuration instead of here, as future versions of
	// models may include both sliding window attention and final
	// logit softcapping.
	if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
		m.TextConfig.finalLogitSoftcap = 0.0
		m.TextConfig.ropeScale = 1.0
	}

	if numBlocks == gemma27BLayerCount {
		m.largeModelScaling = true
	}

	return &m
}

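// TextSelfAttention implements grouped-query attention with RMS-normalized
// queries and keys and rotary position embeddings.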
type TextSelfAttention struct {
	Query     *nn.Linear  `gguf:"attn_q"`
	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
	Key       *nn.Linear  `gguf:"attn_k"`
	KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
	Value     *nn.Linear  `gguf:"attn_v"`
	Output    *nn.Linear  `gguf:"attn_output"`
}

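// ropeValuesForLayer returns the rope frequency base and scale for a layer:
// sliding window (local) layers use ropeLocalBase with no scaling, while
// global layers use ropeBase and ropeScale.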
func (opts *TextConfig) ropeValuesForLayer(layer int) (base float32, scale float32) {
	if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] {
		return opts.ropeLocalBase, 1.0
	}

	// Standard Gemma 3: only every n-th layer (n = gemmaGlobalCacheCount) is
	// global; all other layers use the local rope base.
	if (layer+1)%gemmaGlobalCacheCount > 0 {
		return opts.ropeLocalBase, 1.0
	}

	// default to global rope base
	return opts.ropeBase, opts.ropeScale
}

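// Forward computes self-attention for a single layer: it projects, normalizes,
// and applies rope to the queries and keys, scales the queries (per head on
// the 27B model), and attends over the layer's kv cache.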
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	batchSize := hiddenState.Dim(1)

	ropeBase, ropeScale := opts.ropeValuesForLayer(layer)

	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
	q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase, ropeScale)

	if opts.largeModelScaling {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
	} else {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
	}

	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
	k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase, ropeScale)

	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

	scaleFactor := 1.0
	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
	kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)

	return sa.Output.Forward(ctx, kqv)
}

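// Shift re-applies rotary position embeddings to cached keys at their shifted
// positions, using the same rope base and scale as the layer that produced
// them.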
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	ropeBase, ropeScale := m.TextConfig.ropeValuesForLayer(layer)
	return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase, ropeScale), nil
}

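// TextMLP is the gated feed-forward block: a GELU-gated combination of the
// gate and up projections followed by the down projection.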
type TextMLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

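// Forward applies the gated GELU feed-forward transformation.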
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

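// TextLayer is a single transformer block: self-attention and an MLP, each
// wrapped in pre- and post-norms with a residual connection.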
type TextLayer struct {
	AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention     *TextSelfAttention
	PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
	MLPNorm           *nn.RMSNorm `gguf:"ffn_norm"`
	MLP               *TextMLP
	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
}

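// Forward runs one transformer block. When outputs is non-nil (final layer
// only), the hidden states and residual are pruned to just the positions that
// need logits before the MLP runs.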
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
	return hiddenState.Add(ctx, residual)
}

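// Forward runs the text decoder over a batch: it embeds and scales the input
// tokens, copies any vision embeddings into their positions in the sequence,
// runs each layer against the appropriate (sliding window or global) kv cache
// with image positions marked as exceptions to the causal mask, and returns
// the normalized hidden states.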
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
	positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))

	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))

	// set image embeddings
	var except []int
	for _, image := range batch.Multimodal {
		visionOutputs := image.Multimodal[0].Tensor
		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))

		for i := range visionOutputs.Dim(1) {
			except = append(except, image.Index+i)
		}
	}

	for i, layer := range m.Layers {
		// gemma uses the causal (global) kv cache on every sixth layer
		// (gemmaGlobalCacheCount) and the sliding window (local) cache otherwise
		if cache != nil {
			cacheType := cacheTypeSWA
			if (i+1)%gemmaGlobalCacheCount == 0 {
				cacheType = cacheTypeCausal
			}
			cache.SetLayer(i)
			wc := cache.(*kvcache.WrapperCache)
			wc.SetLayerType(cacheType)

			if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
				causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
			}
		}

		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = batch.Outputs
		}

		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
	}

	return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
Patrick Devine's avatar
Patrick Devine committed
258
}