package mistral3

import (
	"cmp"
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/rope"
	"github.com/ollama/ollama/model/input"
)

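// TextOptions holds the text decoder hyperparameters read from the GGUF
// metadata, including the RoPE (rotary position embedding) scaling settings.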
type TextOptions struct {
	hiddenSize, numHeads, numKVHeads int
	headDim, ropeDim                 int
	eps, ropeBase, ropeScale         float32
	ropeOrigPosEmbeddings            int
	ropeScalingBeta                  float32
	ropeType                         string
	ropeExtrapolation                float32
	ropeBetaFast                     float32
	ropeBetaSlow                     float32
	ropeMscale                       float32
	ropeMscaleAllDim                 float32
}

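// applyRotaryPositionEmbeddings applies rotary position embeddings to states
// at the given positions. When YaRN scaling is configured ("yarn"), the
// original context length, extrapolation factor, and beta parameters are
// passed through to the RoPE implementation.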
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
	var ropeOpts []func(*rope.Options)
	if o.ropeType == "yarn" {
		if o.ropeMscale != 0 && o.ropeMscaleAllDim != 0 {
			ropeOpts = append(ropeOpts, rope.WithAttentionFactor(1.0/float32(0.1*math.Log(float64(o.ropeScale))+1.0)))
		}

		ropeOpts = append(ropeOpts,
			rope.WithOriginalContextLength(o.ropeOrigPosEmbeddings),
			rope.WithExtrapolationFactor(o.ropeExtrapolation),
			rope.WithBetaFast(o.ropeBetaFast),
			rope.WithBetaSlow(o.ropeBetaSlow),
		)
	}

	return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, ropeOpts...)
}

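// TextModel is the text decoder: token embeddings, a stack of transformer
// layers, and a final RMS norm plus output projection (falling back to the
// token embedding weights when no separate output tensor is present).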
type TextModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []Layer       `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*TextOptions
}

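// SelfAttention holds the query, key, value, and output projections for one
// attention block.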
type SelfAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}

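// Forward computes grouped-query attention over hiddenState. Queries and keys
// are rotated with RoPE; when an original context length is configured, the
// queries are additionally multiplied by the per-position scale tensor.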
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)

	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
	q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs)

	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
	k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs)

	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

	if opts.ropeOrigPosEmbeddings > 0 {
		q = q.Mul(ctx, positionsScale)
	}

	kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
	kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
	return sa.Output.Forward(ctx, kqv)
}

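// Shift re-applies rotary position embeddings to a cached key tensor using
// the given shift positions; the KV cache calls this when its contents slide.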
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}

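// MLP is the gated feed-forward block with up, gate, and down projections.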
type MLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

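// Forward applies the SiLU-gated feed-forward transformation followed by the
// down projection.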
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

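// Layer is a single transformer block: pre-norm self-attention followed by a
// pre-norm feed-forward network, each with a residual connection.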
type Layer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention *SelfAttention
	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
	MLP           *MLP
}

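// Forward runs self-attention and the feed-forward network with
// pre-normalization and residual connections.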
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, positionsScale, cache, opts)

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	return hiddenState.Add(ctx, residual)
}

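// Forward embeds the input tokens, copies any image embeddings into their
// placeholder positions, runs every transformer layer, and returns logits
// after the final norm and output projection.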
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, positionsScale, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
	hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)

	// Copy image embeddings from the batch into the hidden state at their
	// placeholder token positions.
	for _, image := range batch.Multimodal {
		imageFeature := image.Multimodal[0].Tensor
		ctx.Forward(imageFeature.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), imageFeature.Dim(0)*imageFeature.Dim(1))))
	}

	for i, layer := range m.Layers {
		cache.SetLayer(i)

		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = outputs
		}

		hiddenState = layer.Forward(ctx, hiddenState, positions, positionsScale, lastLayerOutputs, cache, m.TextOptions)
	}

	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return m.Output.Forward(ctx, hiddenState)
}

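// getScale builds a per-position scale tensor for long-context attention:
// each position is scaled by 1 + beta*log(1 + floor(pos/originalContextLength)),
// so the scale grows logarithmically past the original context length.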
func (m *TextModel) getScale(ctx ml.Context, positions []int32) ml.Tensor {
	posScale := make([]float32, len(positions))
	for n, pos := range positions {
		interval := math.Floor(float64(pos) / float64(m.ropeOrigPosEmbeddings))
		posScale[n] = float32(1.0 + float64(m.ropeScalingBeta)*math.Log(1.0+interval))
	}
	return ctx.Input().FromFloats(posScale, 1, 1, len(posScale))
}

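// newTextModel constructs the text decoder from GGUF metadata, using defaults
// for optional RoPE scaling parameters that are not present.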
func newTextModel(c fs.Config) *TextModel {
	return &TextModel{
		Layers: make([]Layer, c.Uint("block_count")),
		TextOptions: &TextOptions{
			hiddenSize:            int(c.Uint("embedding_length")),
			numHeads:              int(c.Uint("attention.head_count")),
			numKVHeads:            int(c.Uint("attention.head_count_kv")),
			headDim:               int(c.Uint("attention.key_length")),
			ropeDim:               int(c.Uint("rope.dimension_count")),
			eps:                   c.Float("attention.layer_norm_rms_epsilon"),
			ropeBase:              c.Float("rope.freq_base"),
			ropeScale:             c.Float("rope.scaling.factor", 1.0),
			ropeOrigPosEmbeddings: int(c.Uint("rope.scaling.original_context_length")),
			ropeScalingBeta:       c.Float("rope.scaling_beta", 0.1),
			ropeBetaFast:          c.Float("rope.scaling.beta_fast", 32.0),
			ropeBetaSlow:          c.Float("rope.scaling.beta_slow", 1.0),
			ropeType:              c.String("rope.scaling.type"),
			ropeMscale:            c.Float("rope.scaling.mscale"),
			ropeMscaleAllDim:      c.Float("rope.scaling.mscale_all_dim"),
			ropeExtrapolation:     c.Float("rope.scaling.extrapolation_factor", 1),
		},
	}
}