model_text.go 4.43 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package deepseekocr

import (
	"math"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/fast"
	"github.com/ollama/ollama/ml/nn/rope"
)

// textModel is the language-model side of the deepseekocr model. It holds a
// token embedding table, the stack of transformer blocks, a final RMSNorm and
// an output projection. The gguf struct tags presumably name the tensors each
// field is loaded from (confirm against the gguf loader).
//
// Options carries the hyperparameters for the text stack; Shift reads its RoPE
// settings from here.
type textModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Blocks         []textBlock   `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output"`

	Options textOptions
}

// Shift re-applies rotary position embeddings to a cached key tensor using the
// position offsets in shift. It is presumably called by the KV cache when
// cached positions move; verify against the kvcache caller.
//
// The layer index is ignored. Every layer uses the same RoPE configuration
// from m.Options. Shift never returns a non-nil error.
func (m *textModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	rotated := m.Options.applyRotaryPositionalEmbedding(ctx, key, shift)
	return rotated, nil
}

// textOptions collects the hyperparameters that the text blocks, attention
// layers and mixture-of-experts layers read while building the graph.
type textOptions struct {
	hiddenSize     int // model width; divided by numHeads to get headDim
	numHeads       int // number of query heads
	numKVHeads     int // number of key/value heads
	numExperts     int // routed experts scored by each textMoe router
	numExpertsUsed int // experts selected per token (top-k)

	ropeBase  float32 // RoPE base frequency passed to fast.RoPE
	ropeScale float32 // RoPE scaling factor; fast.RoPE receives its reciprocal
	eps       float32 // RMSNorm epsilon
}

// headDim returns the width of one attention head: hiddenSize / numHeads,
// using integer division. It panics if numHeads is zero.
func (o textOptions) headDim() (dim int) {
	dim = o.hiddenSize / o.numHeads
	return dim
}

// applyRotaryPositionalEmbedding rotates t according to the positions in p
// using NeoX-style RoPE. It passes headDim() as the rotary dimension count,
// ropeBase as the base frequency, and 1/ropeScale as the frequency scale.
func (o textOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
	dims := o.headDim()
	freqScale := 1 / o.ropeScale
	return fast.RoPE(ctx, t, p, dims, o.ropeBase, freqScale, rope.WithTypeNeoX())
}

// textBlock is one transformer layer made of an attention sub-layer and a
// feed-forward sub-layer. Each sub-layer has its own RMSNorm applied to its
// input: AttentionNorm before Attention, and MLPNNorm before FeedForward.
//
// FeedForward is an interface value. It is presumably populated with either a
// *textMLP or a *textMoe, the two implementations in this file; confirm at the
// construction site.
type textBlock struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	Attention     *textAttention
	MLPNNorm      *nn.RMSNorm `gguf:"ffn_norm"`
	FeedForward   textFeedForward
}

// Forward runs one pre-norm transformer block:
//
//	h   = Attention(AttentionNorm(x)) + x
//	out = FeedForward(MLPNNorm(h)) + h
//
// If outputs is non-nil, only the rows it selects are kept after attention.
// Both the attention result and its skip connection are trimmed, so the
// feed-forward sub-layer processes only those rows.
func (m *textBlock) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts textOptions) ml.Tensor {
	// Attention sub-layer.
	skip := hiddenStates
	attn := m.Attention.Forward(ctx, m.AttentionNorm.Forward(ctx, hiddenStates, opts.eps), positions, cache, opts)
	if outputs != nil {
		attn = attn.Rows(ctx, outputs)
		skip = skip.Rows(ctx, outputs)
	}
	hiddenStates = attn.Add(ctx, skip)

	// Feed-forward sub-layer; hiddenStates now doubles as its skip connection.
	ffn := m.FeedForward.Forward(ctx, m.MLPNNorm.Forward(ctx, hiddenStates, opts.eps), opts)
	return ffn.Add(ctx, hiddenStates)
}

// textAttention holds the query, key, value and output projections of a
// self-attention layer. Its Forward method applies RoPE to the projected
// queries and keys. It hands the attention computation, along with the KV
// cache, to nn.Attention.
type textAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}

// Forward computes multi-head self-attention over hiddenStates.
//
// The steps are:
//   - Q, K and V are projected and split into heads of width opts.headDim():
//     opts.numHeads heads for Q and opts.numKVHeads heads for K and V.
//   - RoPE is applied to Q and K at the given positions.
//   - nn.Attention, which also receives the KV cache, combines them with
//     scale 1/sqrt(headDim).
//   - The per-head results are flattened back to one vector per position and
//     passed through the output projection.
func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts textOptions) ml.Tensor {
	headDim := opts.headDim()

	q := m.Query.Forward(ctx, hiddenStates).Reshape(ctx, headDim, opts.numHeads, -1)
	k := m.Key.Forward(ctx, hiddenStates).Reshape(ctx, headDim, opts.numKVHeads, -1)
	v := m.Value.Forward(ctx, hiddenStates).Reshape(ctx, headDim, opts.numKVHeads, -1)

	q = opts.applyRotaryPositionalEmbedding(ctx, q, positions)
	k = opts.applyRotaryPositionalEmbedding(ctx, k, positions)

	scale := 1 / math.Sqrt(float64(headDim))
	out := nn.Attention(ctx, q, k, v, scale, cache)
	out = out.Reshape(ctx, -1, out.Dim(2))
	return m.Output.Forward(ctx, out)
}

// textFeedForward is the feed-forward sub-layer of a textBlock. Both *textMLP
// (a gated MLP) and *textMoe (mixture of experts) in this file implement it.
type textFeedForward interface {
	Forward(ctx ml.Context, hiddenStates ml.Tensor, opts textOptions) ml.Tensor
}

// textMoe is a mixture-of-experts feed-forward layer. Its fields are:
//   - Router scores every routed expert for each token.
//   - Gate, Up and Down are batched linear layers for the routed experts. Each
//     token's experts are selected by the top-k indices passed to them in
//     Forward.
//   - SharedExperts is a textMLP applied to every token, and its output is
//     added to the routed result.
//
// The ",suf:_shexp" tag presumably tells the loader to read the shared MLP's
// tensors with an "_shexp" suffix (e.g. ffn_gate_shexp); confirm against the
// gguf loader.
type textMoe struct {
	Router        *nn.Linear      `gguf:"ffn_gate_inp"`
	Gate          *nn.LinearBatch `gguf:"ffn_gate_exps"`
	Up            *nn.LinearBatch `gguf:"ffn_up_exps"`
	Down          *nn.LinearBatch `gguf:"ffn_down_exps"`
	SharedExperts *textMLP        `gguf:",suf:_shexp"`
}

// Forward combines, for every token, the outputs of its opts.numExpertsUsed
// highest-scoring routed experts and then adds the shared-expert output.
//
// Routing works as follows:
//   - Router logits are softmaxed over all opts.numExperts experts.
//   - TopK picks the expert indices.
//   - The probabilities at those indices are used directly as mixing weights;
//     this code does not renormalize them.
//
// Each selected expert runs its SILU-gated Gate/Up projections followed by
// Down. The weighted per-expert results are summed, and SharedExperts (applied
// to every token) is added on top.
func (m *textMoe) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts textOptions) ml.Tensor {
	hidden, tokens := hiddenStates.Dim(0), hiddenStates.Dim(1)

	probs := m.Router.Forward(ctx, hiddenStates).Softmax(ctx)
	chosen := probs.TopK(ctx, opts.numExpertsUsed)
	mix := probs.Reshape(ctx, 1, opts.numExperts, tokens).Rows(ctx, chosen)

	// Insert a singleton axis so the batched expert projections can fan each
	// token out to one slot per chosen expert.
	x := hiddenStates.Reshape(ctx, hidden, 1, tokens)
	gated := m.Gate.Forward(ctx, x, chosen).SILU(ctx, m.Up.Forward(ctx, x, chosen))
	perExpert := m.Down.Forward(ctx, gated, chosen).Mul(ctx, mix)

	// Each expert slot along dim 1 of perExpert is a strided 2-D view.
	// Accumulate the slots into a single routed output.
	width, rowStride, rows := perExpert.Dim(0), perExpert.Stride(2), perExpert.Dim(2)
	routed := perExpert.View(ctx, 0, width, rowStride, rows)
	for s := 1; s < opts.numExpertsUsed; s++ {
		slot := perExpert.View(ctx, s*perExpert.Stride(1), width, rowStride, rows)
		routed = routed.Add(ctx, slot)
	}

	shared := m.SharedExperts.Forward(ctx, hiddenStates, opts)
	return routed.Add(ctx, shared)
}

// textMLP is a gated feed-forward network built from gate, up and down
// projections. It implements textFeedForward, presumably for blocks that do
// not use experts; confirm at the construction site. It is also the type of
// textMoe.SharedExperts.
type textMLP struct {
	Gate *nn.Linear `gguf:"ffn_gate"`
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}

// Forward computes Down(Gate(x).SILU(Up(x))). The two-operand SILU is
// presumably the fused SiLU(gate) * up; confirm against ml.Tensor.SILU.
// The textOptions argument is unused and exists only to satisfy
// textFeedForward.
func (m *textMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ textOptions) ml.Tensor {
	gate := m.Gate.Forward(ctx, hiddenStates)
	up := m.Up.Forward(ctx, hiddenStates)
	return m.Down.Forward(ctx, gate.SILU(ctx, up))
}