model.go 5.75 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
3
package llama

import (
4
	"cmp"
Michael Yang's avatar
Michael Yang committed
5
6
	"math"

7
	"github.com/ollama/ollama/fs"
Jesse Gross's avatar
Jesse Gross committed
8
	"github.com/ollama/ollama/kvcache"
Michael Yang's avatar
Michael Yang committed
9
10
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
11
12
	"github.com/ollama/ollama/ml/nn/fast"
	"github.com/ollama/ollama/ml/nn/rope"
Michael Yang's avatar
Michael Yang committed
13
	"github.com/ollama/ollama/model"
14
	"github.com/ollama/ollama/model/input"
Michael Yang's avatar
Michael Yang committed
15
16
17
)

// Options holds the hyperparameters read from GGUF metadata in New and
// consulted on every forward pass and cache shift.
type Options struct {
	// hiddenSize is the model width (embedding_length).
	hiddenSize int

	// numHeads is the number of query heads (attention.head_count).
	numHeads int

	// numKVHeads is the number of key/value heads (attention.head_count_kv).
	numKVHeads int

	// headDim is the per-head width (attention.key_length). Zero means
	// "derive it as hiddenSize/numHeads".
	headDim int

	// ropeDim is the rotary dimension passed to RoPE (rope.dimension_count).
	// Zero means "use the head width".
	ropeDim int

	// eps is the RMSNorm epsilon (attention.layer_norm_rms_epsilon).
	eps float32

	// ropeBase is the RoPE frequency base (rope.freq_base).
	ropeBase float32

	// ropeScale is the RoPE scaling factor (rope.scaling.factor); its
	// reciprocal is what gets handed to RoPE.
	ropeScale float32
}

type Model struct {
	model.Base
25
	model.TextProcessor
Michael Yang's avatar
Michael Yang committed
26
27
28
29
30
31

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []Layer       `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

32
	Options
Michael Yang's avatar
Michael Yang committed
33
34
}

35
func New(c fs.Config) (model.Model, error) {
36
37
38
	if c.Uint("expert_count") > 0 {
		// TODO: support mixtures of experts
		return nil, model.ErrUnsupportedModel
39
	}
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64

	var processor model.TextProcessor
	vocabulary := model.Vocabulary{
		Values: c.Strings("tokenizer.ggml.tokens"),
		Scores: c.Floats("tokenizer.ggml.scores"),
		Types:  c.Ints("tokenizer.ggml.token_type"),
		Merges: c.Strings("tokenizer.ggml.merges"),
		AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
		BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
		AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
		EOS: append(
			[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
			c.Ints("tokenizer.ggml.eos_token_ids")...,
		),
	}
	switch c.String("tokenizer.ggml.model") {
	case "gpt2":
		processor = model.NewBytePairEncoding(
			`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
			&vocabulary,
		)
	case "llama":
		processor = model.NewSentencePiece(&vocabulary)
	default:
		return nil, model.ErrUnsupportedTokenizer
65
	}
66

Jesse Gross's avatar
Jesse Gross committed
67
	m := Model{
68
69
70
		TextProcessor: processor,
		Layers:        make([]Layer, c.Uint("block_count")),
		Options: Options{
71
72
73
			hiddenSize: int(c.Uint("embedding_length")),
			numHeads:   int(c.Uint("attention.head_count")),
			numKVHeads: int(c.Uint("attention.head_count_kv")),
74
			headDim:    int(c.Uint("attention.key_length")),
75
			ropeDim:    int(c.Uint("rope.dimension_count")),
Michael Yang's avatar
Michael Yang committed
76
			eps:        c.Float("attention.layer_norm_rms_epsilon"),
77
78
			ropeBase:   c.Float("rope.freq_base", 1e5),
			ropeScale:  c.Float("rope.scaling.factor", 1),
Michael Yang's avatar
Michael Yang committed
79
		},
Jesse Gross's avatar
Jesse Gross committed
80
81
82
83
84
	}

	m.Cache = kvcache.NewCausalCache(m.Shift)

	return &m, nil
Michael Yang's avatar
Michael Yang committed
85
86
87
}

type SelfAttention struct {
88
89
90
91
92
	Query       *nn.Linear `gguf:"attn_q"`
	Key         *nn.Linear `gguf:"attn_k"`
	Value       *nn.Linear `gguf:"attn_v"`
	Output      *nn.Linear `gguf:"attn_output"`
	RopeFactors ml.Tensor  `gguf:"rope_freqs.weight"`
Michael Yang's avatar
Michael Yang committed
93
94
}

Michael Yang's avatar
Michael Yang committed
95
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
96
	batchSize := hiddenState.Dim(1)
97
	headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
Michael Yang's avatar
Michael Yang committed
98
	ropeDim := cmp.Or(opts.ropeDim, headDim)
Michael Yang's avatar
Michael Yang committed
99

Michael Yang's avatar
Michael Yang committed
100
101
	query := sa.Query.Forward(ctx, hiddenState)
	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
Michael Yang's avatar
Michael Yang committed
102

Michael Yang's avatar
Michael Yang committed
103
104
	key := sa.Key.Forward(ctx, hiddenState)
	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
Michael Yang's avatar
Michael Yang committed
105

Michael Yang's avatar
Michael Yang committed
106
107
	value := sa.Value.Forward(ctx, hiddenState)
	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
Michael Yang's avatar
Michael Yang committed
108

109
110
	query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
	key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
Michael Yang's avatar
Michael Yang committed
111

Michael Yang's avatar
Michael Yang committed
112
113
114
	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
	attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
	return sa.Output.Forward(ctx, attention)
Michael Yang's avatar
Michael Yang committed
115
116
}

Jesse Gross's avatar
Jesse Gross committed
117
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
Michael Yang's avatar
Michael Yang committed
118
	ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
119
	return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
Jesse Gross's avatar
Jesse Gross committed
120
121
}

Michael Yang's avatar
Michael Yang committed
122
123
124
125
126
127
128
// MLP holds the feed-forward weights of one layer: the up, down, and gate
// projections, loaded from ffn_up, ffn_down, and ffn_gate.
type MLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
129
	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
Michael Yang's avatar
Michael Yang committed
130
131
132
133
134
135
136
137
138
139
	return mlp.Down.Forward(ctx, hiddenState)
}

// Layer is one transformer block. AttentionNorm and MLPNorm are the RMSNorms
// applied before the attention and feed-forward sub-layers respectively.
type Layer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention *SelfAttention
	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
	MLP           *MLP
}

Michael Yang's avatar
Michael Yang committed
140
func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
141
142
143
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
Michael Yang's avatar
Michael Yang committed
144
	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, opts)
145
146
147
148
149
150
151
152

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

Michael Yang's avatar
Michael Yang committed
153
154
155
156
157
158
159
160
	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	return hiddenState.Add(ctx, residual)
}

Jesse Gross's avatar
Jesse Gross committed
161
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
162
	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
Michael Yang's avatar
Michael Yang committed
163

164
	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
Michael Yang's avatar
Michael Yang committed
165
166

	for i, layer := range m.Layers {
Jesse Gross's avatar
Jesse Gross committed
167
		m.Cache.SetLayer(i)
Michael Yang's avatar
Michael Yang committed
168

Michael Yang's avatar
Michael Yang committed
169
		var outputs ml.Tensor
170
		if i == len(m.Layers)-1 {
171
			outputs = batch.Outputs
172
		}
Michael Yang's avatar
Michael Yang committed
173

174
		hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options)
Michael Yang's avatar
Michael Yang committed
175
176
	}

177
178
	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return m.Output.Forward(ctx, hiddenState), nil
Michael Yang's avatar
Michael Yang committed
179
180
181
182
183
}

// init registers New with the model registry under the name "llama".
func init() {
	model.Register("llama", New)
}