package mllama

import (
	"math"
	"slices"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

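// TextSelfAttention holds the projection weights for a causal self-attention
// block, plus the per-frequency RoPE factors loaded from the GGUF tensor
// rope_freqs.weight.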
type TextSelfAttention struct {
	Query       *nn.Linear `gguf:"attn_q"`
	Key         *nn.Linear `gguf:"attn_k"`
	Value       *nn.Linear `gguf:"attn_v"`
	Output      *nn.Linear `gguf:"attn_output"`
	RopeFactors ml.Tensor  `gguf:"rope_freqs.weight"`
}

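// Forward projects the hidden state to queries, keys, and values, applies
// rotary position embeddings to the queries and keys, and runs scaled
// dot-product attention through the key/value cache before the output
// projection.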
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	headDim := opts.hiddenSize / opts.numHeads
	ropeType := uint32(0)

	query := sa.Query.Forward(ctx, hiddenState)
	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)

	key := sa.Key.Forward(ctx, hiddenState)
	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)

	value := sa.Value.Forward(ctx, hiddenState)
	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
	attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)

	return sa.Output.Forward(ctx, attention)
}

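// Shift re-applies rotary position embeddings to cached keys after their
// positions are shifted. Keys from layers without RoPE (the cross-attention
// layers) are returned unchanged.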
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	// This is only called for layers held in the cache, which are just the self-attention layers
	if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
	}

	return key, nil
}

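// TextMLP is the gated feed-forward block: SiLU(gate(x)) * up(x), followed by
// the down projection (SwiGLU).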
type TextMLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

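// TextSelfAttentionDecoderLayer is a pre-norm decoder block: RMSNorm followed
// by self-attention, then RMSNorm followed by the MLP, each with a residual
// connection.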
type TextSelfAttentionDecoderLayer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention *TextSelfAttention

	MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
	MLP     *TextMLP
}

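// Forward runs the self-attention block and the MLP with residual
// connections. When outputs is non-nil (the final layer), the hidden state is
// pruned to only the positions that need logits.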
func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, outputs, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
	residual := hiddenState

	hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
	return hiddenState.Add(ctx, residual)
}

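// TextCrossAttention attends from text queries to the projected vision
// (image) states, with RMSNorm applied to the queries and keys.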
type TextCrossAttention struct {
	QueryNorm *nn.RMSNorm `gguf:"cross_attn_q_norm"`
	Query     *nn.Linear  `gguf:"cross_attn_q_proj"`
	KeyNorm   *nn.RMSNorm `gguf:"cross_attn_k_norm"`
	Key       *nn.Linear  `gguf:"cross_attn_k_proj"`
	Value     *nn.Linear  `gguf:"cross_attn_v_proj"`
	Output    *nn.Linear  `gguf:"cross_attn_o_proj"`
}

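// Forward computes cross-attention between the text hidden state (queries)
// and crossAttentionStates (keys/values). When new vision states are
// provided, they are projected and stored in the encoder cache; otherwise the
// previously cached keys and values are reused.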
func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	headDim := opts.hiddenSize / opts.numHeads

	query := ca.Query.Forward(ctx, hiddenState)
	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
	query = ca.QueryNorm.Forward(ctx, query, opts.eps)

	var key, value ml.Tensor
	if crossAttentionStates != nil {
		numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)

		key = ca.Key.Forward(ctx, crossAttentionStates)
		key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
		key = ca.KeyNorm.Forward(ctx, key, opts.eps)

		value = ca.Value.Forward(ctx, crossAttentionStates)
		value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)

		cache.Put(ctx, key, value)
	}

	key, value, _ = cache.Get(ctx)

	scaleFactor := 1.0 / math.Sqrt(float64(headDim))

	query = query.Permute(ctx, 0, 2, 1, 3)
	key = key.Permute(ctx, 0, 2, 1, 3)
	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

	kq := key.MulmatFullPrec(ctx, query)

	kq = kq.Scale(ctx, scaleFactor)
	kq = kq.Softmax(ctx)

	kqv := value.Mulmat(ctx, kq)
	attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)

	return ca.Output.Forward(ctx, attention)
}

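// TextCrossAttentionDecoderLayer is a decoder block whose attention and MLP
// outputs are scaled by learned tanh gates before being added back to the
// residual stream.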
type TextCrossAttentionDecoderLayer struct {
	AttentionNorm  *nn.RMSNorm `gguf:"attn_norm"`
	CrossAttention *TextCrossAttention
	AttentionGate  ml.Tensor `gguf:"cross_attn_attn_gate"`

	MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
	MLP     *TextMLP
	MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
}

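// Forward applies gated cross-attention followed by the gated MLP, each with
// a residual connection.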
func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
	residual := hiddenState

	hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = d.CrossAttention.Forward(ctx, hiddenState, crossAttentionStates, cache, opts)
	hiddenState = hiddenState.Mul(ctx, d.AttentionGate.Tanh(ctx))
	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
	hiddenState = hiddenState.Mul(ctx, d.MLPGate.Tanh(ctx))
	return hiddenState.Add(ctx, residual)
}

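// TextDecoderLayer is implemented by both the self-attention and
// cross-attention decoder blocks so they can be interleaved in a single layer
// stack.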
type TextDecoderLayer interface {
	Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
}

type TextDecoder struct {
	Layers []TextDecoderLayer
}

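// Forward runs each decoder layer in turn, telling the wrapper cache which
// layer and layer type is active. Cross-attention layers are skipped when no
// vision states are provided and nothing has been cached for the encoder.
// Only the final layer receives outputs so it can prune to the positions that
// need logits.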
func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
	for i, layer := range d.Layers {
		layerType := selfAttentionLayer
		if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
			layerType = crossAttentionLayer
		}

		cache.SetLayer(i)
		cache.SetLayerType(layerType)

		if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
			var lastLayerOutputs ml.Tensor
			if i == len(d.Layers)-1 {
				lastLayerOutputs = outputs
			}

			hiddenState = layer.Forward(ctx, hiddenState, positionIDs, lastLayerOutputs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
		}
	}

	return hiddenState
}

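// TextModelOptions holds the text-decoder hyperparameters read from the model
// config in newTextModel.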
type TextModelOptions struct {
	hiddenSize, numHeads, numKVHeads int
	eps, ropeBase, ropeScale         float32
	ropeDim                          uint32

	crossAttentionLayers []uint32
}

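// TextModel is the mllama text decoder: token embedding, a stack of
// self-attention and cross-attention blocks, and the final norm and output
// projection.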
type TextModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Transformer    *TextDecoder  `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output"`

	*TextModelOptions
}

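// Forward embeds the input tokens, runs them through the decoder stack, and
// returns the logits after the output norm and projection.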
func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
	hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
	hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return m.Output.Forward(ctx, hiddenState)
}

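// newTextModel builds the decoder from the model config, using a
// cross-attention block for every layer listed in
// attention.cross_attention_layers and a self-attention block otherwise.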
func newTextModel(c fs.Config) *TextModel {
	var decoderLayers []TextDecoderLayer
	for i := range c.Uint("block_count") {
		var textDecoderLayer TextDecoderLayer
		if slices.Contains(c.Uints("attention.cross_attention_layers"), i) {
			textDecoderLayer = &TextCrossAttentionDecoderLayer{}
		} else {
			textDecoderLayer = &TextSelfAttentionDecoderLayer{}
		}

		decoderLayers = append(decoderLayers, textDecoderLayer)
	}

	return &TextModel{
		Transformer: &TextDecoder{Layers: decoderLayers},
		TextModelOptions: &TextModelOptions{
			hiddenSize:           int(c.Uint("embedding_length")),
			numHeads:             int(c.Uint("attention.head_count")),
			numKVHeads:           int(c.Uint("attention.head_count_kv")),
			eps:                  c.Float("attention.layer_norm_rms_epsilon"),
			ropeBase:             c.Float("rope.freq_base"),
			ropeScale:            c.Float("rope.freq_scale", 1),
			ropeDim:              c.Uint("rope.dimension_count"),
			crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
		},
	}
}