package gemma3

import (
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/fast"
	"github.com/ollama/ollama/ml/nn/rope"
	"github.com/ollama/ollama/model/input"
)

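// TextConfig holds the text-decoder hyperparameters read from the GGUF metadata.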
type TextConfig struct {
	hiddenSize, numHeads, numKVHeads int
	attnKeyLen, attnValLen           int
	eps, ropeScale                   float32
	ropeLocalBase, ropeGlobalBase    float32
	largeModelScaling                bool
}

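// TextModel is the gemma3 text decoder; when the checkpoint has no separate
// output tensor, the output projection falls back to the tied token embedding.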
type TextModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []TextLayer   `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*TextConfig
}

const (
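	// every sixth layer attends globally; the rest use sliding-window attention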
	gemmaGlobalCacheCount = 6
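	// the 27B model is identified by its block count and scales queries differently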
	gemma27BLayerCount    = 62
)

const (
	cacheTypeSWA = iota // sliding-window (local) attention layers
	cacheTypeCausal     // full causal (global) attention layers
)

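// newTextModel builds the text decoder from GGUF metadata, falling back to
// gemma3 defaults for any fields the file does not set.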
func newTextModel(c fs.Config) *TextModel {
	numBlocks := int(c.Uint("block_count"))

	m := TextModel{
		Layers: make([]TextLayer, numBlocks),
		TextConfig: &TextConfig{
			hiddenSize:     int(c.Uint("embedding_length")),
			numHeads:       int(c.Uint("attention.head_count")),
			numKVHeads:     int(c.Uint("attention.head_count_kv")),
			attnKeyLen:     int(c.Uint("attention.key_length", 256)),
			attnValLen:     int(c.Uint("attention.value_length", 256)),
			eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-06),
			ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),
			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
			ropeScale:      c.Float("rope.freq_scale", 1.0),
		},
	}

	if numBlocks == gemma27BLayerCount {
		m.largeModelScaling = true
	}

	return &m
}

type TextSelfAttention struct {
	Query     *nn.Linear  `gguf:"attn_q"`
	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
	Key       *nn.Linear  `gguf:"attn_k"`
	KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
	Value     *nn.Linear  `gguf:"attn_v"`
	Output    *nn.Linear  `gguf:"attn_output"`
}

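// Forward computes self-attention for one layer: project q/k/v, RMS-normalize
// and rotary-embed q and k with the layer's rope base (local or global),
// pre-scale the query, and attend through the kv cache.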
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	batchSize := hiddenState.Dim(1)

	ropeBase := opts.ropeLocalBase
	if (layer+1)%gemmaGlobalCacheCount == 0 {
		ropeBase = opts.ropeGlobalBase
	}

	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
	q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())

	// the 27B model scales queries by 1/sqrt(embeddingDim/numHeads) instead of 1/sqrt(headDim)
	if opts.largeModelScaling {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
	} else {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
	}

	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
	k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())

	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

	// the query was pre-scaled above, so attention applies no additional scaling
	scaleFactor := 1.0
	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
	kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)

	return sa.Output.Forward(ctx, kqv)
}

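// Shift re-applies rotary embeddings to cached keys when the kv cache shifts
// positions, using the same per-layer rope base as the forward pass.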
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	ropeBase := m.TextConfig.ropeLocalBase
	if (layer+1)%gemmaGlobalCacheCount == 0 {
		ropeBase = m.TextConfig.ropeGlobalBase
	}

	return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
}

type TextMLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

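// Forward applies the gated feed-forward network: GELU(gate(x)) * up(x),
// followed by the down-projection.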
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

type TextLayer struct {
	AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention     *TextSelfAttention
	PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
	MLPNorm           *nn.RMSNorm `gguf:"ffn_norm"`
	MLP               *TextMLP
	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
}

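// Forward runs one transformer block: the attention and MLP sub-layers are
// each wrapped in pre- and post-RMSNorm with a residual connection.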
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
	return hiddenState.Add(ctx, residual)
}

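// Forward embeds the input tokens (scaled by the square root of the hidden
// size, per gemma), splices in any vision embeddings, runs every layer against
// the appropriate kv cache, and projects the final hidden state to logits.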
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))

	// copy each image's vision embeddings over its placeholder tokens and record
	// those positions so the cache can exempt them from causal masking
	var except []int
	for _, image := range batch.Multimodal {
		visionOutputs := image.Multimodal[0].Tensor
		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))

		for i := range visionOutputs.Dim(1) {
			except = append(except, image.Index+i)
		}
	}

	for i, layer := range m.Layers {
		// gemma3 uses the sliding window (local) kv cache for five of every six
		// layers and the causal (global) kv cache for each sixth layer
		cacheType := cacheTypeSWA
		if (i+1)%gemmaGlobalCacheCount == 0 {
			cacheType = cacheTypeCausal
		}
		cache.SetLayer(i)
		wc := cache.(*kvcache.WrapperCache)
		wc.SetLayerType(cacheType)

		if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
			causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
		}

		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = outputs
		}

		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
	}

	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return m.Output.Forward(ctx, hiddenState)
}