package mistral3

import (
	"cmp"
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/fast"
	"github.com/ollama/ollama/model/input"
)

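// TextOptions holds the text decoder hyperparameters read from GGUF
// metadata (see newTextModel).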
type TextOptions struct {
	hiddenSize, numHeads, numKVHeads int
	headDim, ropeDim                 int
	eps, ropeBase, ropeScale         float32
	ropeOrigPosEmbeddings            int
	ropeScalingBeta                  float32
}

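// TextModel is the Mistral text decoder. The gguf struct tags map each
// module to its tensor name in the GGUF file; Output falls back to the
// token embedding weights when no separate output tensor is present.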
type TextModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []Layer       `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*TextOptions
}

type SelfAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}

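// Forward computes grouped-query self-attention. Queries and keys are
// reshaped per head and rotated with RoPE; when the model defines an
// original context length (ropeOrigPosEmbeddings > 0), queries are also
// multiplied by the per-position scale from getScale.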
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)

	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
	q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)

	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
	k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)

	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

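	// Long-context correction: scale the query states by the per-position
	// factor computed in getScale (1 for positions within the original
	// context length).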
	if opts.ropeOrigPosEmbeddings > 0 {
		q = q.Mul(ctx, positionsScale)
	}

	kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
	kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
	return sa.Output.Forward(ctx, kqv)
}

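// Shift re-applies RoPE to cached keys with the given position delta so
// entries in the KV cache can be moved to new positions.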
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale), nil
}

type MLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

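// Forward applies the gated feed-forward network (SwiGLU-style): the gate
// projection is passed through SiLU together with the up projection, and
// the result is down-projected.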
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
	return mlp.Down.Forward(ctx, hiddenState)
}

type Layer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention *SelfAttention
	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
	MLP           *MLP
}

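// Forward runs one transformer block: pre-norm self-attention and a
// pre-norm MLP, each wrapped in a residual connection.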
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, positionsScale, cache, opts)

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	return hiddenState.Add(ctx, residual)
}

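// Forward embeds the batch tokens, copies any image features over their
// corresponding token positions, runs every transformer layer against the
// KV cache, and projects the normalized final hidden state to vocabulary
// logits.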
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, positionsScale, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
	hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)

	// Splice each image's embedding tensor into the hidden state at its token offset.
	for _, image := range batch.Multimodal {
		imageFeature := image.Multimodal[0].Tensor
		ctx.Forward(imageFeature.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), imageFeature.Dim(0)*imageFeature.Dim(1))))
	}

	for i, layer := range m.Layers {
		cache.SetLayer(i)

		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = outputs
		}

		hiddenState = layer.Forward(ctx, hiddenState, positions, positionsScale, lastLayerOutputs, cache, m.TextOptions)
	}

	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return m.Output.Forward(ctx, hiddenState)
}

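// getScale returns a per-position query scale of
// 1 + ropeScalingBeta*ln(1 + floor(pos/ropeOrigPosEmbeddings)), which is 1
// inside the original context length and grows logarithmically beyond it.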
func (m *TextModel) getScale(ctx ml.Context, positions []int32) ml.Tensor {
	posScale := make([]float32, len(positions))
	for n, pos := range positions {
		interval := math.Floor(float64(pos) / float64(m.ropeOrigPosEmbeddings))
		posScale[n] = float32(1.0 + float64(m.ropeScalingBeta)*math.Log(1.0+interval))
	}
	return ctx.Input().FromFloats(posScale, 1, 1, len(posScale))
}

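// newTextModel builds the text decoder and its options from GGUF metadata.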
func newTextModel(c fs.Config) *TextModel {
	return &TextModel{
		Layers: make([]Layer, c.Uint("block_count")),
		TextOptions: &TextOptions{
			hiddenSize:            int(c.Uint("embedding_length")),
			numHeads:              int(c.Uint("attention.head_count")),
			numKVHeads:            int(c.Uint("attention.head_count_kv")),
			headDim:               int(c.Uint("attention.key_length")),
			ropeDim:               int(c.Uint("rope.dimension_count")),
			eps:                   c.Float("attention.layer_norm_rms_epsilon"),
			ropeBase:              c.Float("rope.freq_base"),
			ropeScale:             c.Float("rope.scaling.factor", 1),
			ropeOrigPosEmbeddings: int(c.Uint("rope.scaling.original_context_length")),
			ropeScalingBeta:       c.Float("rope.scaling_beta"),
		},
	}
}