model.go 5.43 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
3
package llama

import (
4
	"cmp"
Michael Yang's avatar
Michael Yang committed
5
6
	"math"

7
	"github.com/ollama/ollama/fs"
Jesse Gross's avatar
Jesse Gross committed
8
	"github.com/ollama/ollama/kvcache"
Michael Yang's avatar
Michael Yang committed
9
10
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
11
12
	"github.com/ollama/ollama/ml/nn/fast"
	"github.com/ollama/ollama/ml/nn/rope"
Michael Yang's avatar
Michael Yang committed
13
	"github.com/ollama/ollama/model"
14
	"github.com/ollama/ollama/model/input"
Michael Yang's avatar
Michael Yang committed
15
16
17
)

// Options holds the hyperparameters that New reads from the model
// configuration and that are shared by every layer.
type Options struct {
	// hiddenSize is read from "embedding_length"; numHeads and numKVHeads
	// from "attention.head_count" and "attention.head_count_kv".
	hiddenSize, numHeads, numKVHeads int

	// headDim ("attention.key_length") and ropeDim ("rope.dimension_count")
	// may be zero. SelfAttention.Forward then falls back to
	// hiddenSize/numHeads for headDim and to headDim for ropeDim.
	headDim, ropeDim int

	// eps ("attention.layer_norm_rms_epsilon") is passed to the RMS norms;
	// ropeBase ("rope.freq_base") and ropeScale ("rope.freq_scale") are
	// passed to RoPE.
	eps, ropeBase, ropeScale float32
}

// Model is the llama text model. It embeds model.Base and the byte-pair
// tokenizer, and carries the weights bound through the gguf struct tags.
type Model struct {
	model.Base
	model.BytePairEncoding

	// Weights used by Forward: token embedding, the stack of layers, the final
	// RMS norm and the output projection.
	// NOTE(review): the "alt:token_embd" tag on Output presumably lets it
	// reuse the token embedding tensor when no "output" tensor exists —
	// confirm against the gguf loader.
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []Layer       `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	// Hyperparameters, embedded so they are reachable as m.eps, m.ropeDim, etc.
	*Options
}

35
func New(c fs.Config) (model.Model, error) {
Jesse Gross's avatar
Jesse Gross committed
36
	m := Model{
Michael Yang's avatar
Michael Yang committed
37
38
39
40
		BytePairEncoding: model.NewBytePairEncoding(
			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
Michael Yang's avatar
Michael Yang committed
41
				Types:  c.Ints("tokenizer.ggml.token_type"),
Michael Yang's avatar
Michael Yang committed
42
				Merges: c.Strings("tokenizer.ggml.merges"),
43
				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
44
				BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
45
				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
46
47
48
49
				EOS: append(
					[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
					c.Ints("tokenizer.ggml.eos_token_ids")...,
				),
Michael Yang's avatar
Michael Yang committed
50
51
52
53
			},
		),
		Layers: make([]Layer, c.Uint("block_count")),
		Options: &Options{
54
55
56
			hiddenSize: int(c.Uint("embedding_length")),
			numHeads:   int(c.Uint("attention.head_count")),
			numKVHeads: int(c.Uint("attention.head_count_kv")),
57
			headDim:    int(c.Uint("attention.key_length")),
58
			ropeDim:    int(c.Uint("rope.dimension_count")),
Michael Yang's avatar
Michael Yang committed
59
60
61
62
			eps:        c.Float("attention.layer_norm_rms_epsilon"),
			ropeBase:   c.Float("rope.freq_base"),
			ropeScale:  c.Float("rope.freq_scale", 1),
		},
Jesse Gross's avatar
Jesse Gross committed
63
64
65
66
67
	}

	m.Cache = kvcache.NewCausalCache(m.Shift)

	return &m, nil
Michael Yang's avatar
Michael Yang committed
68
69
70
}

type SelfAttention struct {
71
72
73
74
75
	Query       *nn.Linear `gguf:"attn_q"`
	Key         *nn.Linear `gguf:"attn_k"`
	Value       *nn.Linear `gguf:"attn_v"`
	Output      *nn.Linear `gguf:"attn_output"`
	RopeFactors ml.Tensor  `gguf:"rope_freqs.weight"`
Michael Yang's avatar
Michael Yang committed
76
77
}

Michael Yang's avatar
Michael Yang committed
78
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
79
	batchSize := hiddenState.Dim(1)
80
	headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
Michael Yang's avatar
Michael Yang committed
81
	ropeDim := cmp.Or(opts.ropeDim, headDim)
Michael Yang's avatar
Michael Yang committed
82

Michael Yang's avatar
Michael Yang committed
83
84
	query := sa.Query.Forward(ctx, hiddenState)
	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
Michael Yang's avatar
Michael Yang committed
85

Michael Yang's avatar
Michael Yang committed
86
87
	key := sa.Key.Forward(ctx, hiddenState)
	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
Michael Yang's avatar
Michael Yang committed
88

Michael Yang's avatar
Michael Yang committed
89
90
	value := sa.Value.Forward(ctx, hiddenState)
	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
Michael Yang's avatar
Michael Yang committed
91

Michael Yang's avatar
Michael Yang committed
92
93
	query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
	key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
Michael Yang's avatar
Michael Yang committed
94

Michael Yang's avatar
Michael Yang committed
95
96
97
	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
	attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
	return sa.Output.Forward(ctx, attention)
Michael Yang's avatar
Michael Yang committed
98
99
}

Jesse Gross's avatar
Jesse Gross committed
100
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
Michael Yang's avatar
Michael Yang committed
101
102
	ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
	return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
Jesse Gross's avatar
Jesse Gross committed
103
104
}

Michael Yang's avatar
Michael Yang committed
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
// MLP holds the three linear projections of a layer's feed-forward block,
// bound to tensors through the gguf struct tags.
type MLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

// Forward computes Down(SILU(Gate(x)) * Up(x)) for x = hiddenState.
// opts is accepted for signature parity with the other Forward methods and is
// not read here.
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
	gated := mlp.Gate.Forward(ctx, hiddenState).SILU(ctx)
	up := mlp.Up.Forward(ctx, hiddenState)
	return mlp.Down.Forward(ctx, gated.Mul(ctx, up))
}

// Layer is one transformer block: an RMS norm followed by self-attention,
// then a second RMS norm followed by the MLP (see Layer.Forward).
type Layer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention *SelfAttention
	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
	MLP           *MLP
}

Michael Yang's avatar
Michael Yang committed
123
func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
124
125
126
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
Michael Yang's avatar
Michael Yang committed
127
	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, opts)
128
129
130
131
132
133
134
135

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

Michael Yang's avatar
Michael Yang committed
136
137
138
139
140
141
142
143
	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	return hiddenState.Add(ctx, residual)
}

Jesse Gross's avatar
Jesse Gross committed
144
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
145
	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
Michael Yang's avatar
Michael Yang committed
146

147
	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
Michael Yang's avatar
Michael Yang committed
148
149

	for i, layer := range m.Layers {
Jesse Gross's avatar
Jesse Gross committed
150
		m.Cache.SetLayer(i)
Michael Yang's avatar
Michael Yang committed
151

Michael Yang's avatar
Michael Yang committed
152
		var outputs ml.Tensor
153
		if i == len(m.Layers)-1 {
154
			outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
155
		}
Michael Yang's avatar
Michael Yang committed
156

Michael Yang's avatar
Michael Yang committed
157
		hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
Michael Yang's avatar
Michael Yang committed
158
159
	}

160
161
	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return m.Output.Forward(ctx, hiddenState), nil
Michael Yang's avatar
Michael Yang committed
162
163
164
165
166
}

// init registers New with the model package under the name "llama".
func init() {
	model.Register("llama", New)
}