package gemma3

import (
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model/input"
)

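// TextConfig holds the text-decoder hyperparameters read from the GGUF
// metadata (see newTextModel).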
type TextConfig struct {
	hiddenSize, numHeads, numKVHeads int
	attnKeyLen, attnValLen           int
	eps, ropeScale                   float32
	ropeLocalBase, ropeGlobalBase    float32
	largeModelScaling                bool
}

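// TextModel is the Gemma 3 text decoder: token embedding, a stack of
// transformer layers, and a final RMS norm plus output projection. The alt
// tag lets Output fall back to the tied token embedding weights when the
// model file carries no separate output tensor.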
type TextModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []TextLayer   `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*TextConfig
}

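// Every sixth layer is a global-attention layer that uses the full causal
// cache and the global RoPE base; the remaining layers use sliding-window
// attention with the local base. A block count of 62 identifies the 27B
// variant, which scales attention queries differently (largeModelScaling).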
const (
	gemmaGlobalCacheCount = 6
	gemma27BLayerCount    = 62
)

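// Layer types for the wrapped KV cache: sliding-window attention (SWA) for
// local layers, full causal attention for global layers.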
const (
	cacheTypeSWA = iota
	cacheTypeCausal
)

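// newTextModel builds the text decoder from the GGUF config, filling in
// Gemma defaults for any missing keys.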
func newTextModel(c fs.Config) *TextModel {
	numBlocks := int(c.Uint("block_count"))

	m := TextModel{
		Layers: make([]TextLayer, numBlocks),
		TextConfig: &TextConfig{
			hiddenSize:     int(c.Uint("embedding_length")),
			numHeads:       int(c.Uint("attention.head_count")),
			numKVHeads:     int(c.Uint("attention.head_count_kv")),
			attnKeyLen:     int(c.Uint("attention.key_length", 256)),
			attnValLen:     int(c.Uint("attention.value_length", 256)),
			eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-06),
			ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),
			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
			ropeScale:      c.Float("rope.freq_scale", 1.0),
		},
	}

	if numBlocks == gemma27BLayerCount {
		m.largeModelScaling = true
	}

	return &m
}

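// TextSelfAttention implements grouped-query attention with RMS normalization
// applied to the query and key heads before RoPE.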
type TextSelfAttention struct {
	Query     *nn.Linear  `gguf:"attn_q"`
	QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
	Key       *nn.Linear  `gguf:"attn_k"`
	KeyNorm   *nn.RMSNorm `gguf:"attn_k_norm"`
	Value     *nn.Linear  `gguf:"attn_v"`
	Output    *nn.Linear  `gguf:"attn_output"`
}

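// Forward computes self-attention for one layer. The RoPE base depends on
// whether the layer is local (sliding window) or global, and the query is
// pre-scaled here, so the shared attention kernel is called with a scale
// factor of 1.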
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	ropeType := uint32(2)

	ropeBase := opts.ropeLocalBase
	if (layer+1)%gemmaGlobalCacheCount == 0 {
		ropeBase = opts.ropeGlobalBase
	}

	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)

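	// Queries are normally scaled by 1/sqrt(attnKeyLen); the 27B model
	// instead uses 1/sqrt(hiddenSize/numHeads).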
	if opts.largeModelScaling {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
	} else {
		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
	}

	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)

	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)

	scaleFactor := 1.0
	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
	kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)

	return sa.Output.Forward(ctx, kqv)
}

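// Shift re-applies RoPE to cached keys when the cache shifts positions, using
// the same local/global base selection as the forward pass.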
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	ropeBase := m.TextConfig.ropeLocalBase
	if (layer+1)%gemmaGlobalCacheCount == 0 {
		ropeBase = m.TextConfig.ropeGlobalBase
	}

	return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil
}

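// TextMLP is the gated feed-forward block: down(GELU(gate(x)) * up(x)).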
type TextMLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

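// TextLayer is a single transformer block. Gemma 3 applies RMS norms both
// before and after each of the attention and MLP sub-blocks.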
type TextLayer struct {
	AttentionNorm     *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention     *TextSelfAttention
	PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
	MLPNorm           *nn.RMSNorm `gguf:"ffn_norm"`
	MLP               *TextMLP
	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
}

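// Forward applies the attention and MLP sub-blocks, each wrapped in a
// residual connection.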
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
	return hiddenState.Add(ctx, residual)
}

func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))

	// set image embeddings
	var except []int
	for _, image := range batch.Multimodal {
		visionOutputs := image.Multimodal.(ml.Tensor)
		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))

		for i := range visionOutputs.Dim(1) {
			except = append(except, image.Index+i)
		}
	}

	for i, layer := range m.Layers {
		// gemma alternates between the sliding window (local) and causal (global)
		// kv cache every 6 layers
		cacheType := cacheTypeSWA
		if (i+1)%gemmaGlobalCacheCount == 0 {
			cacheType = cacheTypeCausal
		}
		cache.SetLayer(i)
		wc := cache.(*kvcache.WrapperCache)
		wc.SetLayerType(cacheType)

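		// Positions collected in except (the image tokens) are exempt from
		// causal masking so they can attend to each other bidirectionally.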
		if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
			causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
		}

		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = outputs
		}

		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
	}

	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return m.Output.Forward(ctx, hiddenState)
}