// model_text.go

package gemma3

import (
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/rope"
	"github.com/ollama/ollama/model/input"
)

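// TextConfig holds the text-model hyperparameters read from the GGUF metadata.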
type TextConfig struct {
	hiddenSize, numHeads, numKVHeads int
	attnKeyLen, attnValLen           int
	eps, ropeScale                   float32
	ropeLocalBase, ropeGlobalBase    float32
	largeModelScaling                bool
}

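// applyRotaryPositionEmbeddings applies NeoX-style rotary position embeddings
// to the query or key states, using the given frequency base (local or global,
// depending on the layer).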
func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor {
	return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, rope.WithTypeNeoX())
}

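// TextModel is the Gemma 3 text decoder. The gguf tags map each module to its
// tensor name; the alt:token_embd tag lets the output projection fall back to
// the token embedding weights when no separate output tensor is present.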
type TextModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []TextLayer   `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*TextConfig
}

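// gemmaGlobalCacheCount is the interval at which layers use the global
// (causal) cache; all other layers use the sliding-window (local) cache.
// gemma27BLayerCount is the block count of the 27B variant, which enables
// largeModelScaling.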
const (
	gemmaGlobalCacheCount = 6
	gemma27BLayerCount    = 62
)

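// Layer types for the wrapped KV cache: sliding-window attention for local
// layers, a standard causal cache for global layers.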
const (
	cacheTypeSWA = iota
	cacheTypeCausal
)

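// newTextModel reads the text-model hyperparameters from the GGUF config and
// allocates one TextLayer per transformer block.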
func newTextModel(c fs.Config) *TextModel {
	numBlocks := int(c.Uint("block_count"))

	m := TextModel{
		Layers: make([]TextLayer, numBlocks),
		TextConfig: &TextConfig{
			hiddenSize:     int(c.Uint("embedding_length")),
			numHeads:       int(c.Uint("attention.head_count")),
			numKVHeads:     int(c.Uint("attention.head_count_kv")),
			attnKeyLen:     int(c.Uint("attention.key_length", 256)),
			attnValLen:     int(c.Uint("attention.value_length", 256)),
			eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-06),
			ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),
			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
			ropeScale:      1,
			// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
			//       (8 instead of 1)
			// ropeScale:      c.Float("rope.scaling.factor", 1.0),
		},
	}

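	// The 62-layer (27B) variant scales attention queries by the per-head
	// hidden size rather than the attention key length (see TextSelfAttention.Forward).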
	if numBlocks == gemma27BLayerCount {
		m.largeModelScaling = true
	}

	return &m
}

type TextSelfAttention struct {
	Query     *nn.Linear  `gguf:"attn_q"`
	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
	Key       *nn.Linear  `gguf:"attn_k"`
	KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
	Value     *nn.Linear  `gguf:"attn_v"`
	Output    *nn.Linear  `gguf:"attn_output"`
}

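// Forward computes self-attention for a single block. Queries and keys are
// RMS-normalized per head and rotated with either the local or the global RoPE
// base, depending on the layer index.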
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	batchSize := hiddenState.Dim(1)

	ropeBase := opts.ropeLocalBase
	if (layer+1)%gemmaGlobalCacheCount == 0 {
		ropeBase = opts.ropeGlobalBase
	}

	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
	q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase)

	if opts.largeModelScaling {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
	} else {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
	}

	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
	k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase)

	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

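	// The query was already scaled above, so no additional scaling is applied
	// inside the attention operation.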
	scaleFactor := 1.0
	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
	kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)

	return sa.Output.Forward(ctx, kqv)
}

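// Shift re-applies rotary position embeddings to cached keys when the KV cache
// shifts positions, using the same per-layer RoPE base as Forward.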
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	ropeBase := m.TextConfig.ropeLocalBase
	if (layer+1)%gemmaGlobalCacheCount == 0 {
		ropeBase = m.TextConfig.ropeGlobalBase
	}

	return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase), nil
}

type TextMLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

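// Forward applies the gated feed-forward network: the gate projection is
// passed through GELU and combined with the up projection, then projected
// back down.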
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

type TextLayer struct {
	AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention     *TextSelfAttention
	PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
	MLPNorm           *nn.RMSNorm `gguf:"ffn_norm"`
	MLP               *TextMLP
	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
}

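// Forward runs one transformer block: pre/post-normalized self-attention and
// MLP sub-layers, each wrapped in a residual connection.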
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
	return hiddenState.Add(ctx, residual)
}

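// Forward embeds the batch tokens (scaled by the square root of the hidden
// size), copies any image embeddings into their reserved positions, runs each
// transformer layer against the appropriate local or global cache, and returns
// the normalized hidden states.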
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
	positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))

	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))

	// set image embeddings
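	// Track the positions covered by image embeddings so they can be excluded
	// from causal masking below, letting image tokens attend to each other.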
	var except []int
	for _, image := range batch.Multimodal {
		visionOutputs := image.Multimodal[0].Tensor
		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))

		for i := range visionOutputs.Dim(1) {
			except = append(except, image.Index+i)
		}
	}

	for i, layer := range m.Layers {
		// gemma alternates between the sliding window (local) and causal (global)
		// kv cache: every sixth layer uses the global cache, the rest use the
		// local sliding-window cache
		if cache != nil {
			cacheType := cacheTypeSWA
			if (i+1)%gemmaGlobalCacheCount == 0 {
				cacheType = cacheTypeCausal
			}
			cache.SetLayer(i)
			wc := cache.(*kvcache.WrapperCache)
			wc.SetLayerType(cacheType)

			if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
				causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
			}
		}

		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = batch.Outputs
Jesse Gross's avatar
Jesse Gross committed
205
206
		}

		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
	}

	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return hiddenState
}