model_text.go
package gemma3

import (
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/fast"
	"github.com/ollama/ollama/ml/nn/rope"
	"github.com/ollama/ollama/model/input"
)

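// TextConfig holds the text-decoder hyperparameters loaded from GGUF metadata.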
type TextConfig struct {
	hiddenSize, numHeads, numKVHeads int
	attnKeyLen, attnValLen           int
	eps, ropeScale                   float32
	ropeLocalBase, ropeGlobalBase    float32
	largeModelScaling                bool
}

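// TextModel is the Gemma 3 text decoder. Weights are bound to fields via their
// gguf tags; per the "alt:token_embd" tag, Output falls back to the token
// embedding weights when no separate output projection tensor is present.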
type TextModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []TextLayer   `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*TextConfig
}

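// Gemma 3 interleaves local and global attention: every sixth layer
// (gemmaGlobalCacheCount) attends globally, the rest use a sliding window.
// The 27B variant is identified by its layer count and uses a different
// query scale (see largeModelScaling).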
const (
	gemmaGlobalCacheCount = 6
	gemma27BLayerCount    = 62
)

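// Layer types for the wrapped KV cache: sliding-window attention (SWA) for
// local layers, standard causal attention for global layers.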
const (
	cacheTypeSWA = iota
	cacheTypeCausal
)

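// newTextModel builds a TextModel from GGUF metadata, falling back to the
// Gemma 3 defaults when a key is missing.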
func newTextModel(c fs.Config) *TextModel {
	numBlocks := int(c.Uint("block_count"))

	m := TextModel{
		Layers: make([]TextLayer, numBlocks),
		TextConfig: &TextConfig{
			hiddenSize:     int(c.Uint("embedding_length")),
			numHeads:       int(c.Uint("attention.head_count")),
			numKVHeads:     int(c.Uint("attention.head_count_kv")),
			attnKeyLen:     int(c.Uint("attention.key_length", 256)),
			attnValLen:     int(c.Uint("attention.value_length", 256)),
			eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-06),
			ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),
			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
			ropeScale:      c.Float("rope.freq_scale", 1.0),
		},
	}

	if numBlocks == gemma27BLayerCount {
		m.largeModelScaling = true
	}

	return &m
}

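// TextSelfAttention implements grouped-query attention with RMS normalization
// applied to queries and keys before RoPE (QK-norm).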
type TextSelfAttention struct {
	Query     *nn.Linear  `gguf:"attn_q"`
	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
	Key       *nn.Linear  `gguf:"attn_k"`
	KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
	Value     *nn.Linear  `gguf:"attn_v"`
	Output    *nn.Linear  `gguf:"attn_output"`
}

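// Forward computes self-attention for one layer. Global layers rotate
// positions with ropeGlobalBase; sliding-window layers use ropeLocalBase.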
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	batchSize := hiddenState.Dim(1)

	ropeBase := opts.ropeLocalBase
	if (layer+1)%gemmaGlobalCacheCount == 0 {
		ropeBase = opts.ropeGlobalBase
	}

	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
	q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())

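	// Queries are pre-scaled here instead of inside the attention kernel:
	// the 27B model divides by sqrt(hiddenSize/numHeads), smaller variants
	// by sqrt(attnKeyLen).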
	if opts.largeModelScaling {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
	} else {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
	}

	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
	k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())

	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

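	// scaleFactor is 1 because q was already scaled above.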
	scaleFactor := 1.0
	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
	kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)

	return sa.Output.Forward(ctx, kqv)
}

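// Shift re-applies RoPE to cached keys when the KV cache shifts positions,
// using the same per-layer rope base as Forward.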
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	ropeBase := m.TextConfig.ropeLocalBase
	if (layer+1)%gemmaGlobalCacheCount == 0 {
		ropeBase = m.TextConfig.ropeGlobalBase
	}

	return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
}

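// TextMLP is the gated feed-forward block: the gate projection is activated
// with GELU and combined with the up projection, then projected back down.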
type TextMLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

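// TextLayer is a single transformer block. Gemma 3 normalizes on both sides
// of each sub-block, hence the pre- and post-norms for attention and the MLP.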
type TextLayer struct {
	AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention     *TextSelfAttention
	PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
	MLPNorm           *nn.RMSNorm `gguf:"ffn_norm"`
	MLP               *TextMLP
	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
}

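// Forward applies one transformer block. outputs is non-nil only on the final
// layer, where rows are pruned to the positions that need logits.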
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
	return hiddenState.Add(ctx, residual)
}

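// Forward embeds the batch, splices any image embeddings into the token
// stream, runs every layer against the appropriate (local or global) KV
// cache, and returns the normalized final hidden states.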
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))

	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
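	// Gemma scales token embeddings by sqrt(hiddenSize).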
	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))

	// set image embeddings, recording their token positions so they can be
	// excluded from causal masking below (image tokens attend bidirectionally)
	var except []int
	for _, image := range batch.Multimodal {
		visionOutputs := image.Multimodal[0].Tensor
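		// Overwrite the placeholder token embeddings at the image's position
		// with the projected vision embeddings.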
		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))

		for i := range visionOutputs.Dim(1) {
			except = append(except, image.Index+i)
		}
	}

	for i, layer := range m.Layers {
		// gemma alternates between the sliding window (local) and causal (global)
		// kv cache: every sixth layer is global, the rest use the sliding window
		cacheType := cacheTypeSWA
		if (i+1)%gemmaGlobalCacheCount == 0 {
			cacheType = cacheTypeCausal
		}
		cache.SetLayer(i)
		wc := cache.(*kvcache.WrapperCache)
		wc.SetLayerType(cacheType)

		if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
			causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
		}

		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = batch.Outputs
Jesse Gross's avatar
Jesse Gross committed
197
198
		}

		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
	}

	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return hiddenState
}