model_text.go 9.06 KB
Newer Older
Michael Yang's avatar
llama4  
Michael Yang committed
1
2
3
4
5
6
7
8
9
10
package llama4

import (
	"cmp"
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
11
12
	"github.com/ollama/ollama/ml/nn/fast"
	"github.com/ollama/ollama/ml/nn/rope"
Michael Yang's avatar
llama4  
Michael Yang committed
13
14
15
16
17
18
19
20
21
22
23
	"github.com/ollama/ollama/model/input"
)

// TextAttention holds the weights of one self-attention block: the query,
// key, value and output projections, plus the RopeFactors tensor that Forward
// hands to the rotary embedding through rope.WithFactors.
type TextAttention struct {
	Query       *nn.Linear `gguf:"attn_q"`
	Key         *nn.Linear `gguf:"attn_k"`
	Value       *nn.Linear `gguf:"attn_v"`
	Output      *nn.Linear `gguf:"attn_output"`
	RopeFactors ml.Tensor  `gguf:"rope_factors"`
}

Michael Yang's avatar
Michael Yang committed
24
func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attentionScales ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
Michael Yang's avatar
llama4  
Michael Yang committed
25
26
27
28
29
30
31
32
33
34
35
	batchSize, headDim := hiddenStates.Dim(1), cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)

	query := sa.Query.Forward(ctx, hiddenStates)
	key := sa.Key.Forward(ctx, hiddenStates)
	value := sa.Value.Forward(ctx, hiddenStates)

	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

	if useRope {
36
37
		query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
		key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
Michael Yang's avatar
Michael Yang committed
38
	}
Michael Yang's avatar
llama4  
Michael Yang committed
39

Michael Yang's avatar
Michael Yang committed
40
41
42
43
44
45
46
	if opts.useQKNorm {
		query = query.RMSNorm(ctx, nil, opts.eps)
		key = key.RMSNorm(ctx, nil, opts.eps)
	}

	if attentionScales != nil && !useRope {
		query = query.Mul(ctx, attentionScales)
Michael Yang's avatar
llama4  
Michael Yang committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
	}

	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
	return sa.Output.Forward(ctx, attention)
}

// TextMLP holds the gate, up and down projections of a dense feed-forward
// block.
type TextMLP struct {
	Gate *nn.Linear `gguf:"ffn_gate"`
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}

// Forward applies the gated feed-forward block: the SILU of the gate
// projection is multiplied with the up projection, and the product is passed
// through the down projection. opts is not used.
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
	gated := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx)
	projected := mlp.Up.Forward(ctx, hiddenStates)
	return mlp.Down.Forward(ctx, gated.Mul(ctx, projected))
}

// TextExperts holds the gate, up and down weights of the routed experts as
// raw tensors; Forward applies them with MulmatID using the selected expert
// indices.
type TextExperts struct {
	Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
	Up   ml.Tensor `gguf:"ffn_up_exps.weight"`
	Down ml.Tensor `gguf:"ffn_down_exps.weight"`
}

func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
	experts := routerLogits.TopK(ctx, opts.numExpertsUsed)
	scores := routerLogits.Sigmoid(ctx).Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, experts)

	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
	hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
	hiddenStates = hiddenStates.Mul(ctx, scores)

	upStates := e.Up.MulmatID(ctx, hiddenStates, experts)
	gateStates := e.Gate.MulmatID(ctx, hiddenStates, experts)
	downStates := e.Down.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)

	nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
	for i := 1; i < opts.numExpertsUsed; i++ {
85
		nextStates = nextStates.Add(ctx, downStates.View(ctx, i*downStates.Stride(1), hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)))
Michael Yang's avatar
llama4  
Michael Yang committed
86
87
88
89
90
	}

	return nextStates
}

Michael Yang's avatar
Michael Yang committed
91
// TextSharedExpert is TextMLP with different tensor names
Michael Yang's avatar
llama4  
Michael Yang committed
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
type TextSharedExpert struct {
	Gate *nn.Linear `gguf:"ffn_gate_shexp"`
	Up   *nn.Linear `gguf:"ffn_up_shexp"`
	Down *nn.Linear `gguf:"ffn_down_shexp"`
}

// Forward applies the shared expert: SILU(gate(x)) multiplied with up(x),
// then the down projection. opts is not used.
func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
	activated := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx)
	lifted := mlp.Up.Forward(ctx, hiddenStates)
	return mlp.Down.Forward(ctx, activated.Mul(ctx, lifted))
}

// TextMOE is a mixture-of-experts feed-forward block: a Router that produces
// the expert logits, the routed Experts, and a SharedExpert whose output is
// added to the routed result.
type TextMOE struct {
	Router       *nn.Linear `gguf:"ffn_gate_inp"`
	Experts      *TextExperts
	SharedExpert *TextSharedExpert
}

// Forward flattens hiddenStates to two dimensions, computes the router
// logits, and returns the shared expert output plus the routed experts
// output.
func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
	// Merge dimensions 1 and 2 so the router and experts see a 2D tensor.
	flat := hiddenStates.Reshape(ctx, hiddenStates.Dim(0), hiddenStates.Dim(1)*hiddenStates.Dim(2))

	logits := moe.Router.Forward(ctx, flat)
	shared := moe.SharedExpert.Forward(ctx, flat, opts)
	routed := moe.Experts.Forward(ctx, flat, logits, opts)
	return shared.Add(ctx, routed)
}

// TextFeedForward is the feed-forward stage of a TextLayer. newTextModel
// assigns either a *TextMLP or a *TextMOE to it.
type TextFeedForward interface {
	// Forward transforms hiddenStates using the given options.
	Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor
}

// TextLayer is one decoder layer: a norm followed by self-attention, then a
// norm followed by a feed-forward block (dense or mixture-of-experts).
type TextLayer struct {
	AttentionNorm *nn.LayerNorm `gguf:"attn_norm"`
	Attention     *TextAttention

	FFNNorm     *nn.LayerNorm `gguf:"ffn_norm"`
	FeedForward TextFeedForward
}

Michael Yang's avatar
Michael Yang committed
131
func (d *TextLayer) Forward(ctx ml.Context, hiddenStates, positions, attentionScales, outputs ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
Michael Yang's avatar
llama4  
Michael Yang committed
132
133
134
135
	residual := hiddenStates

	// self attention
	hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
Michael Yang's avatar
Michael Yang committed
136
	hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, attentionScales, cache, useRope, opts)
Michael Yang's avatar
llama4  
Michael Yang committed
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159

	if outputs != nil {
		hiddenStates = hiddenStates.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenStates = hiddenStates.Add(ctx, residual)
	residual = hiddenStates

	hiddenStates = d.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = d.FeedForward.Forward(ctx, hiddenStates, opts)

	return residual.Add(ctx, hiddenStates)
}

// TextOptions carries the hyperparameters shared by every layer of the text
// model. newTextModel fills it from the model configuration.
type TextOptions struct {
	// hiddenSize is read from "embedding_length".
	hiddenSize int

	// numHeads and numKVHeads are the query and key/value head counts.
	// headDim is the per-head width; when it is zero the attention falls
	// back to hiddenSize/numHeads.
	numHeads   int
	numKVHeads int
	headDim    int

	// numExperts is the expert count used to reshape the router scores;
	// numExpertsUsed is how many top router entries are kept.
	numExperts     int
	numExpertsUsed int

	// ropeDim, ropeBase and ropeScale are passed to the rotary embedding.
	ropeDim   int
	ropeBase  float32
	ropeScale float32

	// eps is the epsilon passed to the normalization calls.
	eps float32

	// interleaveLayerStep picks the mixture-of-experts layers: zero-based
	// layer i uses experts when (i+1) is a multiple of it.
	interleaveLayerStep int

	// noRopeInterval picks the layers without rotary embeddings: zero-based
	// layer i skips rope when (i+1) is a multiple of it.
	noRopeInterval int

	// useQKNorm enables RMS normalization of query and key.
	useQKNorm bool

	// attentionTemperatureTuning enables a per-position query scale on the
	// layers without rope, computed as
	// log(floor((position+1)/attentionFloorScale + 1))*attentionScale + 1.
	attentionTemperatureTuning bool
	attentionScale             float64
	attentionFloorScale        float64
}

// TextModel is the text decoder: a token embedding, a stack of layers, a
// final norm and an output projection. It embeds *TextOptions, so the option
// fields are reachable directly on the model.
//
// NOTE(review): the "alt:token_embd" tag on Output presumably lets it fall
// back to the token embedding tensor — confirm against the gguf loader.
type TextModel struct {
	Layers []TextLayer `gguf:"blk"`

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	OutputNorm     *nn.LayerNorm `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*TextOptions
}

func newTextModel(c fs.Config) *TextModel {
	layers := make([]TextLayer, c.Uint("block_count"))
	interleaveLayerStep := c.Uint("interleave_moe_layer_step", 1)
	for i := range layers {
		if (i+1)%int(interleaveLayerStep) == 0 {
			layers[i] = TextLayer{FeedForward: &TextMOE{}}
		} else {
			layers[i] = TextLayer{FeedForward: &TextMLP{}}
		}
	}

	return &TextModel{
		Layers: layers,
		TextOptions: &TextOptions{
Michael Yang's avatar
Michael Yang committed
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
			hiddenSize:                 int(c.Uint("embedding_length")),
			numHeads:                   int(c.Uint("attention.head_count")),
			numKVHeads:                 int(c.Uint("attention.head_count_kv")),
			headDim:                    int(c.Uint("attention.head_dim", 128)),
			numExperts:                 int(c.Uint("expert_count")),
			numExpertsUsed:             int(c.Uint("expert_used_count")),
			ropeDim:                    int(c.Uint("rope.dimension_count")),
			ropeBase:                   c.Float("rope.freq_base"),
			ropeScale:                  c.Float("rope.freq_scale", 1),
			eps:                        c.Float("attention.layer_norm_rms_epsilon"),
			interleaveLayerStep:        int(c.Uint("interleave_moe_layer_step", 1)),
			noRopeInterval:             int(c.Uint("no_rope_interval", 4)),
			useQKNorm:                  c.Bool("use_qk_norm", true),
			attentionTemperatureTuning: c.Bool("attention.temperature_tuning", true),
			attentionScale:             float64(c.Float("attention.scale", 0.1)),
			attentionFloorScale:        float64(c.Float("attention.floor_scale", 8192)),
Michael Yang's avatar
llama4  
Michael Yang committed
207
208
209
210
211
		},
	}
}

func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
212
213
214
	hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)

	for _, mi := range batch.Multimodal {
215
		img := mi.Multimodal[0].Tensor
Michael Yang's avatar
Michael Yang committed
216
217
		ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
	}
Michael Yang's avatar
llama4  
Michael Yang committed
218

Michael Yang's avatar
Michael Yang committed
219
220
221
222
223
224
225
	var attentionScales ml.Tensor
	if m.attentionTemperatureTuning {
		scales := make([]float32, len(batch.Positions))
		for i, p := range batch.Positions {
			scales[i] = float32(math.Log(math.Floor(((float64(p)+1.0)/float64(m.attentionFloorScale))+1.0))*m.attentionScale + 1.0)
		}

226
		attentionScales = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales))
Michael Yang's avatar
Michael Yang committed
227
228
	}

Michael Yang's avatar
llama4  
Michael Yang committed
229
230
231
232
	for i, layer := range m.Layers {
		cache.SetLayer(i)
		wc := cache.(*kvcache.WrapperCache)
		wc.SetLayerType(1)
Michael Yang's avatar
Michael Yang committed
233
		useChunkedAttention := (i+1)%m.noRopeInterval != 0
Michael Yang's avatar
llama4  
Michael Yang committed
234
235
236
237
238
239
240
241
242
		if useChunkedAttention {
			wc.SetLayerType(0)
		}

		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = outputs
		}

Michael Yang's avatar
Michael Yang committed
243
		hiddenStates = layer.Forward(ctx, hiddenStates, positions, attentionScales, lastLayerOutputs, cache, useChunkedAttention, m.TextOptions)
Michael Yang's avatar
llama4  
Michael Yang committed
244
245
246
247
248
249
250
	}

	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
	return m.Output.Forward(ctx, hiddenStates)
}

func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
251
	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
Michael Yang's avatar
llama4  
Michael Yang committed
252
}