model_text.go
package mllama

import (
	"math"
	"slices"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

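// TextSelfAttention is the causal self-attention block of a text decoder
// layer. The gguf struct tags name the tensors these weights are loaded
// from; RopeFactors holds the RoPE frequency scaling factors.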
type TextSelfAttention struct {
	Query       *nn.Linear `gguf:"attn_q"`
	Key         *nn.Linear `gguf:"attn_k"`
	Value       *nn.Linear `gguf:"attn_v"`
	Output      *nn.Linear `gguf:"attn_output"`
	RopeFactors ml.Tensor  `gguf:"rope_freqs.weight"`
}

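// Forward projects hiddenState to queries, keys, and values, applies RoPE to
// the queries and keys at the supplied positions, and computes scaled
// dot-product attention through the KV cache.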
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	headDim := opts.hiddenSize / opts.numHeads
	ropeType := uint32(0)

	query := sa.Query.Forward(ctx, hiddenState)
	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)

	key := sa.Key.Forward(ctx, hiddenState)
	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)

	value := sa.Value.Forward(ctx, hiddenState)
	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
	attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)

	return sa.Output.Forward(ctx, attention)
}

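// Shift re-applies RoPE to cached keys, using the shift tensor as the
// positions, so the KV cache can move entries to new positions.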
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	// This will only get called for layers in the cache, which are just the self-attention layers
	if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
	}

	return key, nil
}

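// TextMLP is the gated feed-forward block shared by both decoder layer types.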
type TextMLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

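// Forward computes the SwiGLU feed-forward: Down(SILU(Gate(x)) * Up(x)).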
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

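// TextSelfAttentionDecoderLayer is a standard pre-norm transformer block:
// RMSNorm, self-attention, and a residual add, followed by RMSNorm, MLP, and
// another residual add.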
type TextSelfAttentionDecoderLayer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention *TextSelfAttention

	MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
	MLP     *TextMLP
}

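// Forward applies the attention and feed-forward sub-blocks with
// pre-normalization and residual connections.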
func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, outputs, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
	residual := hiddenState

	hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
	return hiddenState.Add(ctx, residual)
}

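// TextCrossAttention lets text tokens attend to the vision encoder output:
// queries come from the text hidden state, while keys and values are
// projected from the image features.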
type TextCrossAttention struct {
	QueryNorm *nn.RMSNorm `gguf:"cross_attn_q_norm"`
	Query     *nn.Linear  `gguf:"cross_attn_q_proj"`
	KeyNorm   *nn.RMSNorm `gguf:"cross_attn_k_norm"`
	Key       *nn.Linear  `gguf:"cross_attn_k_proj"`
	Value     *nn.Linear  `gguf:"cross_attn_v_proj"`
	Output    *nn.Linear  `gguf:"cross_attn_o_proj"`
}

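// Forward computes cross-attention against the vision encoder output. When
// crossAttentionStates is non-nil (a new image was provided), fresh key and
// value projections are computed and stored in the cache; otherwise the
// projections cached by an earlier call are reused.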
func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	headDim := opts.hiddenSize / opts.numHeads

	query := ca.Query.Forward(ctx, hiddenState)
	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
	query = ca.QueryNorm.Forward(ctx, query, opts.eps)

	var key, value ml.Tensor
	if crossAttentionStates != nil {
		numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)

		key = ca.Key.Forward(ctx, crossAttentionStates)
		key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
		key = ca.KeyNorm.Forward(ctx, key, opts.eps)

		value = ca.Value.Forward(ctx, crossAttentionStates)
		value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)

		cache.Put(ctx, key, value)
	}

	key, value, _ = cache.Get(ctx)

	scaleFactor := 1.0 / math.Sqrt(float64(headDim))

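	// Unlike the self-attention path, attention is computed by hand here
	// rather than through nn.Attention; the key-query product uses
	// MulmatFullPrec to keep full precision.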
	query = query.Permute(ctx, 0, 2, 1, 3)
	key = key.Permute(ctx, 0, 2, 1, 3)
	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

	kq := key.MulmatFullPrec(ctx, query)

	kq = kq.Scale(ctx, scaleFactor)
	kq = kq.Softmax(ctx)

	kqv := value.Mulmat(ctx, kq)
	attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)

	return ca.Output.Forward(ctx, attention)
}

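// TextCrossAttentionDecoderLayer is a decoder block whose attention is
// cross-attention over image features. Learned tanh gates scale how much the
// cross-attention and MLP outputs contribute to the residual stream.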
type TextCrossAttentionDecoderLayer struct {
	AttentionNorm  *nn.RMSNorm `gguf:"attn_norm"`
	CrossAttention *TextCrossAttention
	AttentionGate  ml.Tensor `gguf:"cross_attn_attn_gate"`

	MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
	MLP     *TextMLP
	MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
}

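// Forward ignores the position, output, and mask arguments; cross-attention
// conditions on the image features rather than on token order.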
func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
	residual := hiddenState

	hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = d.CrossAttention.Forward(ctx, hiddenState, crossAttentionStates, cache, opts)
	hiddenState = hiddenState.Mul(ctx, d.AttentionGate.Tanh(ctx))
	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
	hiddenState = hiddenState.Mul(ctx, d.MLPGate.Tanh(ctx))
	return hiddenState.Add(ctx, residual)
}

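// TextDecoderLayer is implemented by both self-attention and cross-attention
// layers so the decoder can run a mixed stack uniformly.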
type TextDecoderLayer interface {
	Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
}

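// TextDecoder is the ordered stack of decoder layers.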
type TextDecoder struct {
	Layers []TextDecoderLayer
}

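// Forward runs each layer in turn, pointing the cache at the correct layer
// and mode first. Cross-attention layers are skipped when no image is
// present and nothing has been cached for the encoder.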
func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
	for i, layer := range d.Layers {
		layerType := selfAttentionLayer
		if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
			layerType = crossAttentionLayer
		}

		cache.SetLayer(i)
		cache.SetLayerType(layerType)

		if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
			var lastLayerOutputs ml.Tensor
			if i == len(d.Layers)-1 {
				lastLayerOutputs = outputs
			}

			hiddenState = layer.Forward(ctx, hiddenState, positionIDs, lastLayerOutputs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
		}
	}

	return hiddenState
}

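// TextModelOptions holds hyperparameters read from the GGUF metadata when
// the model is loaded.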
type TextModelOptions struct {
	hiddenSize, numHeads, numKVHeads int
	eps, ropeBase, ropeScale         float32
	ropeDim                          uint32

	crossAttentionLayers []uint32
}

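// TextModel is the text decoder of the mllama architecture: a token
// embedding, a mixed stack of self- and cross-attention layers, and an
// output head.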
type TextModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Transformer    *TextDecoder  `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output"`

	*TextModelOptions
}

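// Forward embeds the input tokens, runs them through the decoder stack, and
// projects the normalized hidden state to vocabulary logits.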
func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
	hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
	hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return m.Output.Forward(ctx, hiddenState)
}

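// newTextModel builds the layer stack from the model config, using a
// cross-attention layer at each block index listed in
// attention.cross_attention_layers and a self-attention layer elsewhere.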
func newTextModel(c ml.Config) *TextModel {
	var decoderLayers []TextDecoderLayer
	for i := range c.Uint("block_count") {
		var textDecoderLayer TextDecoderLayer
		if slices.Contains(c.Uints("attention.cross_attention_layers"), i) {
			textDecoderLayer = &TextCrossAttentionDecoderLayer{}
		} else {
			textDecoderLayer = &TextSelfAttentionDecoderLayer{}
		}

		decoderLayers = append(decoderLayers, textDecoderLayer)
	}

	return &TextModel{
		Transformer: &TextDecoder{Layers: decoderLayers},
		TextModelOptions: &TextModelOptions{
			hiddenSize:           int(c.Uint("embedding_length")),
			numHeads:             int(c.Uint("attention.head_count")),
			numKVHeads:           int(c.Uint("attention.head_count_kv")),
			eps:                  c.Float("attention.layer_norm_rms_epsilon"),
			ropeBase:             c.Float("rope.freq_base"),
			ropeScale:            c.Float("rope.freq_scale", 1),
			ropeDim:              c.Uint("rope.dimension_count"),
			crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
		},
	}
}