"vscode:/vscode.git/clone" did not exist on "019851d099543852130e560c6160c572d5f70b09"
model_text.go 7.68 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
//go:build mlx

package gemma3

import (
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/x/kvcache"
	"github.com/ollama/ollama/x/ml"
	"github.com/ollama/ollama/x/ml/nn"
	"github.com/ollama/ollama/x/model/input"
)

type TextConfig struct {
	hiddenSize, numHeads, numKVHeads int
	attnKeyLen                       int
	eps, ropeScale                   float32
	ropeLocalBase, ropeGlobalBase    float32
	largeModelScaling                bool
}

type TextModel struct {
	TokenEmbedding *nn.Embedding `gguf:"embed_tokens"`
	Layers         []TextLayer   `gguf:"layers"`
	OutputNorm     *nn.RMSNorm   `gguf:"norm"`
	Output         *nn.Linear    `gguf:"embed_tokens"`

	*TextConfig
}

const (
	gemmaGlobalCacheCount = 6
	gemma27BLayerCount    = 62
)

// const (
// 	cacheTypeSWA = iota
// 	cacheTypeCausal
// )

func newTextModel(c fs.Config) *TextModel {
	numBlocks := int(c.Uint("block_count"))

	m := TextModel{
		Layers: make([]TextLayer, numBlocks),
		TextConfig: &TextConfig{
			hiddenSize:     int(c.Uint("embedding_length")),                    // 2560 -- config.json: text_config.hidden_size
			numHeads:       int(c.Uint("attention.head_count")),                // 8 -- hard coded in python implementation for the model, 4 in some places, then overridden as 8
			numKVHeads:     int(c.Uint("attention.head_count_kv")),             // 4 -- same as above
			attnKeyLen:     int(c.Uint("attention.key_length", 256)),           //256 -- rope settings, hardcoded in model definition python
			eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-06), // 1e-06 - hardcoded in model definition python
			ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),           // 10000 - hardcoded in python
			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),        // 1e+06 - hardcoded in python
			ropeScale:      1,                                                  // 1 - default is 1, implied in python code
			// vocabSize:      vocabSize,                                          // 262144
			// attnValLen:     int(c.Uint("attention.value_length", 256)),         //256
			// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
			//       (8 instead of 1)
			// ropeScale:      c.Float("rope.scaling.factor", 1.0),
		},
	}
	if numBlocks == gemma27BLayerCount {
		m.largeModelScaling = true
	}

	return &m
}

type TextSelfAttention struct {
	Query     *nn.Linear  `gguf:"q_proj"`
	QueryNorm *nn.RMSNorm `gguf:"q_norm"`
	Key       *nn.Linear  `gguf:"k_proj"`
	KeyNorm   *nn.RMSNorm `gguf:"k_norm"`
	Value     *nn.Linear  `gguf:"v_proj"`
	Output    *nn.Linear  `gguf:"o_proj"`
}

func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	B := hiddenState.Dim(0)
	L := hiddenState.Dim(1)
	ropeBase := opts.ropeLocalBase
	if (layer+1)%gemmaGlobalCacheCount == 0 {
		ropeBase = opts.ropeGlobalBase
	}

	q := sa.Query.Forward(ctx, hiddenState)
	k := sa.Key.Forward(ctx, hiddenState)
	v := sa.Value.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, B, L, opts.numHeads, -1).Transpose(ctx, 0, 2, 1, 3)
	k = k.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3)
	v = v.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
	traditional := false
	q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
	k = k.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))

	// TODO - this is wrong somehow so commenting out
	// if opts.largeModelScaling {
	// 	q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
	// } else {
	// 	q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
	// }

	scaleFactor := math.Pow(256, -0.5)

	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
	kqv = kqv.Transpose(ctx, 0, 2, 1, 3).Reshape(ctx, B, L, -1)
	return sa.Output.Forward(ctx, kqv)
}

func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	// ropeBase := m.TextConfig.ropeLocalBase
	// if (layer+1)%gemmaGlobalCacheCount == 0 {
	// 	ropeBase = m.TextConfig.ropeGlobalBase
	// }
	// 	q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
	panic("not yet implemented")
	// return key.RoPE(ctx, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
}

type TextMLP struct {
	Up   *nn.Linear `gguf:"up_proj"`
	Down *nn.Linear `gguf:"down_proj"`
	Gate *nn.Linear `gguf:"gate_proj"`
}

func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

type TextLayer struct {
	AttentionNorm     *nn.RMSNorm        `gguf:"input_layernorm"`
	SelfAttention     *TextSelfAttention `gguf:"self_attn"`
	PostAttentionNorm *nn.RMSNorm        `gguf:"post_attention_layernorm"`
	MLPNorm           *nn.RMSNorm        `gguf:"pre_feedforward_layernorm"`
	MLP               *TextMLP           `gguf:"mlp"`
	PostMLPNorm       *nn.RMSNorm        `gguf:"post_feedforward_layernorm"`
}

func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, outputs ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
	residual := hiddenState
	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, offset, cache, opts)
	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.TakeAxes(ctx, outputs, 1)
		residual = residual.TakeAxes(ctx, outputs, 1)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState
	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts) // TODO this is where it goes bad most likely...
	hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
	return hiddenState.Add(ctx, residual)
}

func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))

	// set image embeddings
	// var except []int
	// for _, image := range batch.Multimodal {
	// 	visionOutputs := image.Multimodal[0].Tensor
	// 	ctx.Forward(visionOutputs.Copy(ctx, hiddenState.AsStrided(ctx,
	// 		[]int{visionOutputs.Dim(0) * visionOutputs.Dim(1)},
	// 		[]int{image.Index * hiddenState.Stride(1)}, 0)))

	// 	for i := range visionOutputs.Dim(1) {
	// 		except = append(except, image.Index+i)
	// 	}
	// }

	for i, layer := range m.Layers {
		// gemma alternates between the sliding window (local) and causal (global)
		// kv cache every 6 layers
		if cache != nil {
			// cacheType := cacheTypeSWA
			// if (i+1)%gemmaGlobalCacheCount == 0 {
			// 	cacheType = cacheTypeCausal
			// }
			cache.SetLayer(i)

			// TODO this needs to come back
			// wc := cache.(*kvcache.WrapperCache)
			// wc.SetLayerType(cacheType)

			// if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
			// 	causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
			// }
		}

		var offset int
		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			offset = batch.Offset
			lastLayerOutputs = batch.Outputs
		}

		hiddenState = layer.Forward(ctx, i, hiddenState, lastLayerOutputs, offset, cache, m.TextConfig)
	}
	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return hiddenState
}