model_text.go

package gemma3

import (
	"math"
	"slices"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/rope"
	"github.com/ollama/ollama/model/input"
)

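// TextConfig holds the text decoder hyperparameters read from the model
// metadata: attention head geometry, RoPE bases and scaling options, the
// sliding window layer pattern, and the final logit softcap.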
type TextConfig struct {
	hiddenSize, numHeads, numKVHeads int
	attnKeyLen, attnValLen           int
	eps, ropeScale                   float32
	ropeLocalBase                    float32
	largeModelScaling                bool
	slidingWindowPattern             []bool
	ropeBase                         float32
	ropeType                         string
	ropeOriginalContext              int
	ropeExtrapolation                float32
	ropeBetaFast                     float32
	ropeBetaSlow                     float32
	finalLogitSoftcap                float32
}

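// applyRotaryPositionEmbeddings applies NeoX-style RoPE to states at the
// given positions. When the metadata requests YaRN scaling, the original
// context length, extrapolation factor, attention factor, and beta
// parameters are forwarded as rope options.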
func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor {
	ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
	if o.ropeType == "yarn" {
		attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
		ropeOpts = append(ropeOpts,
			rope.WithOriginalContextLength(o.ropeOriginalContext),
			rope.WithExtrapolationFactor(o.ropeExtrapolation),
			rope.WithAttentionFactor(attnFactor),
			rope.WithBetaFast(o.ropeBetaFast),
			rope.WithBetaSlow(o.ropeBetaSlow),
		)
	}

	return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, ropeOpts...)
}

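// TextModel is the Gemma 3 text decoder: token embeddings, a stack of
// transformer layers, a final RMS norm, and the output projection, which
// falls back to the token embedding weights when no separate output tensor
// is present.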
type TextModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []TextLayer   `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*TextConfig
}

const (
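	// every gemmaGlobalCacheCount-th layer uses global (full causal) attention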
	gemmaGlobalCacheCount = 6
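	// layer count of the 27B model, which enables largeModelScaling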
	gemma27BLayerCount    = 62
)

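// Cache types for the wrapper cache: sliding window attention (local) layers
// and full causal attention (global) layers.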
const (
	cacheTypeSWA = iota
	cacheTypeCausal
)

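// newTextModel builds a TextModel from the model configuration, filling
// TextConfig with the attention, RoPE, and softcap settings and applying
// defaults where keys are absent.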
func newTextModel(c fs.Config) *TextModel {
	numBlocks := int(c.Uint("block_count"))

	m := TextModel{
		Layers: make([]TextLayer, numBlocks),
		TextConfig: &TextConfig{
			hiddenSize:           int(c.Uint("embedding_length")),
			numHeads:             int(c.Uint("attention.head_count")),
			numKVHeads:           int(c.Uint("attention.head_count_kv")),
			attnKeyLen:           int(c.Uint("attention.key_length", 256)),
			attnValLen:           int(c.Uint("attention.value_length", 256)),
			eps:                  c.Float("attention.layer_norm_rms_epsilon", 1e-06),
			ropeLocalBase:        c.Float("rope.local.freq_base", 10000.0),
			ropeBase:             c.Float("rope.freq_base", 1000000.0),
			slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
			ropeType:             c.String("rope.scaling.type"),
			ropeOriginalContext:  int(c.Uint("rope.scaling.original_context_length")),
			ropeExtrapolation:    c.Float("rope.scaling.extrapolation_factor", 1.0),
			ropeBetaFast:         c.Float("rope.scaling.beta_fast", 64.0),
			ropeBetaSlow:         c.Float("rope.scaling.beta_slow", 1.0),
			ropeScale:            c.Float("rope.scaling.factor", 1.0),
			finalLogitSoftcap:    c.Float("final_logit_softcapping", 0.0),
		},
	}

	// Google's Gemma 3 release with sliding window attention does
	// not use final logit softcapping, so it is forced to 0.0 here.
	// TODO (jmorganca): this should ideally be set to 0.0 in the
	// model configuration instead of here, as future versions of
	// models may include both sliding window attention and final
	// logit softcapping.
	if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
		m.TextConfig.finalLogitSoftcap = 0.0
	}

	if numBlocks == gemma27BLayerCount {
		m.largeModelScaling = true
	}

	return &m
}

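// TextSelfAttention implements grouped-query attention with RMS
// normalization of the query and key projections (QK-norm).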
type TextSelfAttention struct {
	Query     *nn.Linear  `gguf:"attn_q"`
	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
	Key       *nn.Linear  `gguf:"attn_k"`
	KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
	Value     *nn.Linear  `gguf:"attn_v"`
	Output    *nn.Linear  `gguf:"attn_output"`
}

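// ropeBaseForLayer returns the RoPE frequency base for a layer: the local
// base for layers flagged as sliding window in the metadata, otherwise the
// base implied by the standard every-6th-layer global pattern below.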
func (opts *TextConfig) ropeBaseForLayer(layer int) float32 {
	if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] {
		return opts.ropeLocalBase
	}

	// Standard Gemma 3: only every gemmaGlobalCacheCount-th layer is global;
	// all other layers use the local rope base.
	if (layer+1)%gemmaGlobalCacheCount > 0 {
		return opts.ropeLocalBase
	}

	// default to global rope base
	return opts.ropeBase
}

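// Forward computes self-attention for one layer: project and normalize the
// queries and keys, apply RoPE with the layer's frequency base, scale the
// queries, and run attention through the kv cache before the output
// projection.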
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	batchSize := hiddenState.Dim(1)

	ropeBase := opts.ropeBaseForLayer(layer)

	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
	q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase)

	if opts.largeModelScaling {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
	} else {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
	}

	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
	k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase)

	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

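	// The query was already scaled above, so pass a neutral scale factor
	// to the attention call.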
	scaleFactor := 1.0
	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
	kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)

	return sa.Output.Forward(ctx, kqv)
}

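// Shift re-applies RoPE to cached keys when the cache shifts positions,
// using the same per-layer frequency base as the forward pass.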
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.TextConfig.ropeBaseForLayer(layer)), nil
}

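// TextMLP is the gated feed-forward block: the gate projection is passed
// through GELU together with the up projection, then projected back down.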
type TextMLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

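// Forward applies the gated GELU feed-forward to hiddenState.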
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

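// TextLayer is one transformer block: self-attention and MLP sub-blocks,
// each wrapped in pre- and post-RMS normalization with residual connections.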
type TextLayer struct {
	AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention     *TextSelfAttention
	PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
	MLPNorm           *nn.RMSNorm `gguf:"ffn_norm"`
	MLP               *TextMLP
	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
}

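// Forward runs one transformer block. When outputs is non-nil (final layer
// only), the hidden states and residual are pruned to the rows that need
// logits before the MLP sub-block.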
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
	return hiddenState.Add(ctx, residual)
}

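// Forward embeds the batch tokens, copies any vision embeddings into their
// image positions, runs every layer against the appropriate (sliding window
// or causal) kv cache, and returns the normalized final hidden states.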
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
	positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))

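	// Gemma scales the token embeddings by the square root of the hidden size.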
	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))

	// Copy the image embeddings into their token positions; except collects
	// those positions so they are exempted from causal masking below.
	var except []int
	for _, image := range batch.Multimodal {
		visionOutputs := image.Multimodal[0].Tensor
		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))

		for i := range visionOutputs.Dim(1) {
			except = append(except, image.Index+i)
		}
	}

	for i, layer := range m.Layers {
		// Every gemmaGlobalCacheCount-th layer uses the causal (global)
		// kv cache; all other layers use the sliding window (local) cache.
		if cache != nil {
			cacheType := cacheTypeSWA
			if (i+1)%gemmaGlobalCacheCount == 0 {
				cacheType = cacheTypeCausal
			}
			cache.SetLayer(i)
			wc := cache.(*kvcache.WrapperCache)
			wc.SetLayerType(cacheType)

			if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
				causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
			}
		}

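		// Only the final layer needs logits, so pass the output row
		// indices there for pruning.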
		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = batch.Outputs
		}

		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
	}

	return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
}