model_text.go 8.22 KB
Newer Older
Patrick Devine's avatar
Patrick Devine committed
1
2
3
4
5
6
7
8
9
package gemma3

import (
	"math"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
Michael Yang's avatar
Michael Yang committed
10
	"github.com/ollama/ollama/model/input"
Patrick Devine's avatar
Patrick Devine committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
)

// TextOptions holds the text-decoder hyperparameters that newTextModel reads
// from the model config. Field order matches the original declaration.
type TextOptions struct {
	// Model width and the number of query and key/value attention heads.
	hiddenSize int
	numHeads   int
	numKVHeads int

	// Per-head widths used when reshaping keys/queries and values.
	attnKeyLen int
	attnValLen int

	// eps is the RMSNorm epsilon; ropeScale is the RoPE frequency scale.
	eps       float32
	ropeScale float32

	// RoPE base frequencies: ropeGlobalBase is used on every
	// gemmaGlobalCacheCount-th layer, ropeLocalBase on all other layers.
	ropeLocalBase  float32
	ropeGlobalBase float32

	// finalLogitSoftcap is the bound used to tanh-softcap the output logits
	// in TextModel.Forward.
	finalLogitSoftcap float32

	// largeModelScaling makes queries scale by 1/sqrt(hiddenSize/numHeads)
	// instead of 1/sqrt(attnKeyLen); newTextModel sets it when block_count
	// equals gemma27BLayerCount.
	largeModelScaling bool
}

// TextModel is the Gemma 3 text decoder. It embeds the SentencePiece tokenizer
// and its hyperparameters (*TextOptions), and holds the token embedding, the
// per-block TextLayer weights, the final RMSNorm and the output projection.
// The gguf struct tags name the tensors each field is populated from; the
// Output tag lists token_embd as an alternative name (presumably so models
// with tied embeddings reuse the embedding weights — confirm with the loader).
type TextModel struct {
	model.Base
	model.SentencePieceModel

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []TextLayer   `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*TextOptions
}

// gemmaGlobalCacheCount is the period of the global-attention layers: a layer
// whose 1-based index is a multiple of this value uses the global RoPE base and
// the causal cache; every other layer uses the local RoPE base and the
// sliding-window cache.
const gemmaGlobalCacheCount = 6

// gemma27BLayerCount is the block_count at which newTextModel enables
// TextOptions.largeModelScaling.
const gemma27BLayerCount = 62

// Layer cache types passed to kvcache.WrapperCache.SetLayerType by
// TextModel.Forward. NOTE(review): the values presumably select among the
// caches in the order they were given to the WrapperCache (constructed
// outside this file) — confirm before reordering.
const (
	cacheTypeSWA    = 0 // sliding-window (local) attention layers
	cacheTypeCausal = 1 // causal (global) attention layers
)

func newTextModel(c ml.Config) *TextModel {
Patrick Devine's avatar
Patrick Devine committed
45
46
	numBlocks := int(c.Uint("block_count"))

Patrick Devine's avatar
Patrick Devine committed
47
48
49
50
51
52
53
54
55
56
57
	m := TextModel{
		SentencePieceModel: model.NewSentencePieceModel(
			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Scores: c.Floats("tokenizer.ggml.scores"),
				Types:  c.Uints("tokenizer.ggml.token_type"),
				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
			},
		),
Patrick Devine's avatar
Patrick Devine committed
58
		Layers: make([]TextLayer, numBlocks),
Patrick Devine's avatar
Patrick Devine committed
59
60
		TextOptions: &TextOptions{
			hiddenSize:        int(c.Uint("embedding_length")),
Patrick Devine's avatar
Patrick Devine committed
61
62
			numHeads:          int(c.Uint("attention.head_count")),
			numKVHeads:        int(c.Uint("attention.head_count_kv")),
Patrick Devine's avatar
Patrick Devine committed
63
64
			attnKeyLen:        int(c.Uint("attention.key_length", 256)),
			attnValLen:        int(c.Uint("attention.value_length", 256)),
Michael Yang's avatar
Michael Yang committed
65
66
67
68
69
			eps:               c.Float("attention.layer_norm_rms_epsilon", 1e-06),
			ropeLocalBase:     c.Float("rope.local.freq_base", 10000.0),
			ropeGlobalBase:    c.Float("rope.global.freq_base", 1000000.0),
			ropeScale:         c.Float("rope.freq_scale", 1.0),
			finalLogitSoftcap: c.Float("final_logit_softcapping", 30.0),
Patrick Devine's avatar
Patrick Devine committed
70
71
72
		},
	}

Patrick Devine's avatar
Patrick Devine committed
73
74
75
76
	if numBlocks == gemma27BLayerCount {
		m.largeModelScaling = true
	}

Patrick Devine's avatar
Patrick Devine committed
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
	return &m
}

// TextSelfAttention holds one layer's attention weights: the query, key, value
// and output projections, plus RMSNorms that Forward applies to the query and
// key after they are reshaped into heads. The gguf struct tags name the
// tensors each field is populated from.
type TextSelfAttention struct {
	Query     *nn.Linear  `gguf:"attn_q"`
	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
	Key       *nn.Linear  `gguf:"attn_k"`
	KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
	Value     *nn.Linear  `gguf:"attn_v"`
	Output    *nn.Linear  `gguf:"attn_output"`
}

func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	ropeType := uint32(2)

	ropeBase := opts.ropeLocalBase
Patrick Devine's avatar
Patrick Devine committed
94
	if (layer+1)%gemmaGlobalCacheCount == 0 {
Patrick Devine's avatar
Patrick Devine committed
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
		ropeBase = opts.ropeGlobalBase
	}

	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)

	if opts.largeModelScaling {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
	} else {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
	}

	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)

	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

	scaleFactor := 1.0
	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
	kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)

	return sa.Output.Forward(ctx, kqv)
}

func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	ropeBase := m.TextOptions.ropeLocalBase
Patrick Devine's avatar
Patrick Devine committed
126
	if (layer+1)%gemmaGlobalCacheCount == 0 {
Patrick Devine's avatar
Patrick Devine committed
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
		ropeBase = m.TextOptions.ropeGlobalBase
	}

	return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
}

// TextMLP holds one layer's feed-forward weights. Forward combines them as
// Down(GELU(Gate(x)) * Up(x)). The gguf struct tags name the tensors each field
// is populated from.
type TextMLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

// Forward applies the gated feed-forward network to hiddenState and returns
// Down(GELU(Gate(x)) * Up(x)), where * is the tensor Mul op. opts is accepted
// for signature symmetry with the other layer components but is not used.
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
	gate := mlp.Gate.Forward(ctx, hiddenState).GELU(ctx)
	up := mlp.Up.Forward(ctx, hiddenState)
	return mlp.Down.Forward(ctx, gate.Mul(ctx, up))
}

// TextLayer is one decoder block. Its Forward applies
// AttentionNorm -> SelfAttention -> PostAttentionNorm and adds the residual,
// then MLPNorm -> MLP -> PostMLPNorm and adds the residual again. The gguf
// struct tags name the tensors each field is populated from.
type TextLayer struct {
	AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention     *TextSelfAttention
	PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
	MLPNorm           *nn.RMSNorm `gguf:"ffn_norm"`
	MLP               *TextMLP
	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
}

Jesse Gross's avatar
Jesse Gross committed
153
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
Patrick Devine's avatar
Patrick Devine committed
154
155
156
157
158
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
Jesse Gross's avatar
Jesse Gross committed
159
160
161
162
163
164
165
166

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

Patrick Devine's avatar
Patrick Devine committed
167
168
169
170
171
172
173
174
175
	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
	return hiddenState.Add(ctx, residual)
}

176
func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
177
178
	var embedding ml.Tensor
	var src, dst, length int
179
	var except []int
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205

	for _, image := range multimodal {
		imageToken := image.Multimodal.(imageToken)
		imageSrc := imageToken.index
		imageDst := image.Index

		if embedding == nil {
			embedding = imageToken.embedding
			src = imageSrc
			dst = imageDst
			length = 1
		} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
			src = imageSrc
			dst = imageDst
			length++
		} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
			length++
		} else {
			visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
			ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))

			embedding = imageToken.embedding
			src = imageSrc
			dst = imageDst
			length = 1
		}
206

207
		except = append(except, imageDst)
208
	}
209

210
211
212
	if embedding != nil {
		visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
Michael Yang's avatar
Michael Yang committed
213
214
	}

215
216
217
218
219
220
221
	return except
}

func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))

222
	except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
223

Patrick Devine's avatar
Patrick Devine committed
224
225
226
227
	for i, layer := range m.Layers {
		// gemma alternates between the sliding window (local) and causal (global)
		// kv cache every 6 layers
		cacheType := cacheTypeSWA
Patrick Devine's avatar
Patrick Devine committed
228
		if (i+1)%gemmaGlobalCacheCount == 0 {
Patrick Devine's avatar
Patrick Devine committed
229
230
231
232
233
			cacheType = cacheTypeCausal
		}
		cache.SetLayer(i)
		wc := cache.(*kvcache.WrapperCache)
		wc.SetLayerType(cacheType)
Jesse Gross's avatar
Jesse Gross committed
234

235
236
237
238
		if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
			causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
		}

Jesse Gross's avatar
Jesse Gross committed
239
240
241
242
243
244
		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = outputs
		}

		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
Patrick Devine's avatar
Patrick Devine committed
245
246
247
248
249
250
251
252
	}

	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	hiddenState = m.Output.Forward(ctx, hiddenState)

	// final logit softcap
	hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextOptions.finalLogitSoftcap))
	hiddenState = hiddenState.Tanh(ctx)
Jesse Gross's avatar
Jesse Gross committed
253
	return hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
Patrick Devine's avatar
Patrick Devine committed
254
}