model.go 5.33 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
3
4
5
package llama

import (
	"math"

Jesse Gross's avatar
Jesse Gross committed
6
	"github.com/ollama/ollama/kvcache"
Michael Yang's avatar
Michael Yang committed
7
8
9
10
11
12
13
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
)

type Options struct {
	RopeFactors                      ml.Tensor `gguf:"rope_freqs.weight"`
14
	hiddenSize, numHeads, numKVHeads int
Michael Yang's avatar
Michael Yang committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
	eps, ropeBase, ropeScale         float32
	ropeDim                          uint32
}

// Model is the llama-architecture language model: a token embedding, a stack
// of transformer layers, a final RMSNorm, and an output projection.
//
// NOTE(review): the exported tensor fields appear to be populated by the model
// loader from the GGUF tensor names in their `gguf` tags; field names, order,
// tags and embeddings should be kept as-is.
type Model struct {
	// Base supplies shared model state; New assigns its Cache.
	model.Base
	// BytePairEncoding is the tokenizer built by New from tokenizer.ggml.* keys.
	model.BytePairEncoding

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []Layer       `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	// The "alt:token_embd" tag presumably lets the output projection fall back
	// to the token embedding tensor when no separate "output" tensor exists —
	// TODO confirm against the loader.
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	// Options holds the hyperparameters shared by all layers.
	*Options
}

func New(c ml.Config) (model.Model, error) {
Jesse Gross's avatar
Jesse Gross committed
32
	m := Model{
Michael Yang's avatar
Michael Yang committed
33
34
35
36
37
38
		BytePairEncoding: model.NewBytePairEncoding(
			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Types:  c.Uints("tokenizer.ggml.token_type"),
				Merges: c.Strings("tokenizer.ggml.merges"),
39
40
				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
Michael Yang's avatar
Michael Yang committed
41
42
43
44
			},
		),
		Layers: make([]Layer, c.Uint("block_count")),
		Options: &Options{
45
46
47
			hiddenSize: int(c.Uint("embedding_length")),
			numHeads:   int(c.Uint("attention.head_count")),
			numKVHeads: int(c.Uint("attention.head_count_kv")),
Michael Yang's avatar
Michael Yang committed
48
49
50
51
52
			eps:        c.Float("attention.layer_norm_rms_epsilon"),
			ropeBase:   c.Float("rope.freq_base"),
			ropeScale:  c.Float("rope.freq_scale", 1),
			ropeDim:    c.Uint("rope.dimension_count"),
		},
Jesse Gross's avatar
Jesse Gross committed
53
54
55
56
57
	}

	m.Cache = kvcache.NewCausalCache(m.Shift)

	return &m, nil
Michael Yang's avatar
Michael Yang committed
58
59
60
61
62
63
64
65
66
}

// SelfAttention holds the per-layer attention projections. Query, Key and
// Value project the normalized hidden state; Output projects the merged
// attention result back to the hidden size.
//
// NOTE(review): fields appear to be bound from GGUF tensors via their `gguf`
// tags (relative to the enclosing "blk" layer) — keep names and tags as-is.
type SelfAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}

Jesse Gross's avatar
Jesse Gross committed
67
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
68
69
70
71
72
73
74
75
76
77
78
79
80
81
	batchSize := hiddenState.Dim(1)
	headDim := opts.hiddenSize / opts.numHeads

	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
	q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)

	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
	k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)

	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

Jesse Gross's avatar
Jesse Gross committed
82
83
	cache.Put(ctx, k, v)
	k, v, mask := cache.Get(ctx)
Michael Yang's avatar
Michael Yang committed
84
85
86
87
88

	q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

89
	kq := k.MulmatFullPrec(ctx, q)
Michael Yang's avatar
Michael Yang committed
90
	kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
Jesse Gross's avatar
Jesse Gross committed
91
	kq = kq.Add(ctx, mask)
Michael Yang's avatar
Michael Yang committed
92
93
94
95
96
97
98
99
100
	kq = kq.Softmax(ctx)

	kqv := v.Mulmat(ctx, kq)
	kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)

	return sa.Output.Forward(ctx, kqv)
}

Jesse Gross's avatar
Jesse Gross committed
101
102
103
104
// Shift applies RoPE to the cached keys in key using the positions in shift
// and the model's RoPE settings. It is the shift callback given to the causal
// cache. The layer index is ignored because every layer uses the same RoPE
// parameters. The returned error is always nil.
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	o := m.Options
	rotated := key.RoPE(ctx, shift, o.RopeFactors, o.ropeDim, o.ropeBase, o.ropeScale)
	return rotated, nil
}

Michael Yang's avatar
Michael Yang committed
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
// MLP is the per-layer gated feed-forward network: Down(SiLU(Gate(x)) * Up(x)).
//
// NOTE(review): fields appear to be bound from GGUF tensors via their `gguf`
// tags (relative to the enclosing "blk" layer) — keep names and tags as-is.
type MLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

// Forward computes Down(SiLU(Gate(hiddenState)) * Up(hiddenState)).
// opts is not used by this sub-layer.
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
	gated := mlp.Gate.Forward(ctx, hiddenState).SILU(ctx)
	up := mlp.Up.Forward(ctx, hiddenState)
	return mlp.Down.Forward(ctx, gated.Mul(ctx, up))
}

// Layer is one transformer block: an RMSNorm before self-attention and an
// RMSNorm before the MLP, each sub-layer wrapped in a residual connection
// (see Layer.Forward).
//
// NOTE(review): the tagged norms appear to be bound from GGUF tensors via their
// `gguf` tags, and the untagged sub-structs presumably contribute their own
// tagged fields — keep the declaration as-is.
type Layer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention *SelfAttention
	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
	MLP           *MLP
}

123
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
124
125
126
127
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
128
129
130
131
132
133
134
135

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

Michael Yang's avatar
Michael Yang committed
136
137
138
139
140
141
142
143
144
	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	return hiddenState.Add(ctx, residual)
}

func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
Jesse Gross's avatar
Jesse Gross committed
145
	inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
Michael Yang's avatar
Michael Yang committed
146
147
148
149
	if err != nil {
		return nil, err
	}

Jesse Gross's avatar
Jesse Gross committed
150
	positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
Michael Yang's avatar
Michael Yang committed
151
152
153
154
	if err != nil {
		return nil, err
	}

155
156
157
158
159
	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
	if err != nil {
		return nil, err
	}

Michael Yang's avatar
Michael Yang committed
160
161
162
	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)

	for i, layer := range m.Layers {
Jesse Gross's avatar
Jesse Gross committed
163
		m.Cache.SetLayer(i)
Michael Yang's avatar
Michael Yang committed
164

165
166
167
168
		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = outputs
		}
Michael Yang's avatar
Michael Yang committed
169

170
		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
Michael Yang's avatar
Michael Yang committed
171
172
	}

173
174
	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return m.Output.Forward(ctx, hiddenState), nil
Michael Yang's avatar
Michael Yang committed
175
176
177
178
179
}

// init registers New as the constructor for the "llama" architecture.
func init() {
	const architecture = "llama"
	model.Register(architecture, New)
}