package mistral3

import (
	"cmp"
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/rope"
	"github.com/ollama/ollama/model/input"
)

type TextOptions struct {
	hiddenSize, numHeads, numKVHeads int
	headDim, ropeDim                 int
	eps, ropeBase, ropeScale         float32
	ropeOrigPosEmbeddings            int
	ropeScalingBeta                  float32
	ropeType                         string
	ropeExtrapolation                float32
	ropeBetaFast                     float32
	ropeBetaSlow                     float32
	ropeMscale                       float32
	ropeMscaleAllDim                 float32
}

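// applyRotaryPositionEmbeddings applies RoPE to states at the given positions.
// The frequency scale passed to nn.RoPE is 1/ropeScale; for the "yarn" scaling
// type it also forwards the YaRN parameters (original context length,
// extrapolation factor, attention factor, beta fast/slow).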
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
	var ropeOpts []func(*rope.Options)
	if o.ropeType == "yarn" {
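		// getMscale is the YaRN attention-scaling term: 1 for scale <= 1,
		// otherwise 0.1*mscale*ln(scale) + 1. When both mscale and
		// mscale_all_dim are configured, the attention factor below is their
		// ratio; otherwise mscale defaults to 1.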
		getMscale := func(scale, mscale float64) float64 {
			if scale <= 1.0 {
				return 1.0
			}
			return 0.1*mscale*math.Log(scale) + 1.0
		}

		var attnFactor float32
		if o.ropeMscale != 0 && o.ropeMscaleAllDim != 0 {
			attnFactor = float32(getMscale(float64(o.ropeScale), float64(o.ropeMscale)) / getMscale(float64(o.ropeScale), float64(o.ropeMscaleAllDim)))
		} else {
			attnFactor = float32(getMscale(float64(o.ropeScale), 1))
		}

		ropeOpts = append(ropeOpts,
			rope.WithOriginalContextLength(o.ropeOrigPosEmbeddings),
			rope.WithExtrapolationFactor(o.ropeExtrapolation),
			rope.WithAttentionFactor(attnFactor),
			rope.WithBetaFast(o.ropeBetaFast),
			rope.WithBetaSlow(o.ropeBetaSlow),
		)
	}

	return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, ropeOpts...)
}

type TextModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []Layer       `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*TextOptions
}

type SelfAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}

func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)

	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
	q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs)

	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
	k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs)

	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

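	// When a RoPE original context length is configured, scale the query by
	// the per-position factor passed in (see getScale below).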
	if opts.ropeOrigPosEmbeddings > 0 {
		q = q.Mul(ctx, positionsScale)
	}

	kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
	kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
	return sa.Output.Forward(ctx, kqv)
}

func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
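	// Re-apply rotary embeddings to the cached keys using the shift offsets so
	// the cache stays consistent after a context shift.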
	return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}

type MLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
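	// Gated feed-forward: the SiLU-activated gate projection is combined with
	// the up projection before the down projection.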
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

type Layer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention *SelfAttention
	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
	MLP           *MLP
}

func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, positionsScale, cache, opts)

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	return hiddenState.Add(ctx, residual)
}

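// Forward embeds the batch tokens, splices any image features into the hidden
// state, runs the transformer layers, and projects the final hidden state to
// logits.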
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, positionsScale, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
	hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)

	// image embeddings
	for _, image := range batch.Multimodal {
		imageFeature := image.Multimodal[0].Tensor
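		// Write the image features into the hidden state starting at the
		// image's token index.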
		ctx.Forward(imageFeature.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), imageFeature.Dim(0)*imageFeature.Dim(1))))
	}

	for i, layer := range m.Layers {
		cache.SetLayer(i)

		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = outputs
		}

		hiddenState = layer.Forward(ctx, hiddenState, positions, positionsScale, lastLayerOutputs, cache, m.TextOptions)
	}

	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return m.Output.Forward(ctx, hiddenState)
}

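// getScale returns a per-position scale of
// 1 + ropeScalingBeta*ln(1 + floor(pos/ropeOrigPosEmbeddings)). It is 1 for
// positions below the original context length and grows logarithmically
// beyond it, and is multiplied into the query in SelfAttention.Forward.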
func (m *TextModel) getScale(ctx ml.Context, positions []int32) ml.Tensor {
	posScale := make([]float32, len(positions))
	for n, pos := range positions {
		interval := math.Floor(float64(pos) / float64(m.ropeOrigPosEmbeddings))
		posScale[n] = float32(1.0 + float64(m.ropeScalingBeta)*math.Log(1.0+interval))
	}
	return ctx.Input().FromFloats(posScale, 1, 1, len(posScale))
}

func newTextModel(c fs.Config) *TextModel {
	return &TextModel{
		Layers: make([]Layer, c.Uint("block_count")),
		TextOptions: &TextOptions{
			hiddenSize:            int(c.Uint("embedding_length")),
			numHeads:              int(c.Uint("attention.head_count")),
			numKVHeads:            int(c.Uint("attention.head_count_kv")),
			headDim:               int(c.Uint("attention.key_length")),
			ropeDim:               int(c.Uint("rope.dimension_count")),
			eps:                   c.Float("attention.layer_norm_rms_epsilon"),
			ropeBase:              c.Float("rope.freq_base"),
			ropeScale:             c.Float("rope.scaling.factor", 1.0),
			ropeOrigPosEmbeddings: int(c.Uint("rope.scaling.original_context_length")),
			ropeScalingBeta:       c.Float("rope.scaling_beta", 0.1),
			ropeBetaFast:          c.Float("rope.scaling.beta_fast", 32.0),
			ropeBetaSlow:          c.Float("rope.scaling.beta_slow", 1.0),
			ropeType:              c.String("rope.scaling.type"),
			ropeMscale:            c.Float("rope.scaling.mscale"),
			ropeMscaleAllDim:      c.Float("rope.scaling.mscale_all_dim"),
			ropeExtrapolation:     c.Float("rope.scaling.extrapolation_factor", 1),
		},
	}
}