model.go 5.43 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
3
package llama

import (
4
	"fmt"
Michael Yang's avatar
Michael Yang committed
5
	"math"
6
	"strings"
Michael Yang's avatar
Michael Yang committed
7

Jesse Gross's avatar
Jesse Gross committed
8
	"github.com/ollama/ollama/kvcache"
Michael Yang's avatar
Michael Yang committed
9
10
11
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
12
	"github.com/ollama/ollama/model/input"
Michael Yang's avatar
Michael Yang committed
13
14
15
)

// Options holds the hyperparameters that New reads from the model
// configuration and that the Forward methods in this package consume.
type Options struct {
	// hiddenSize is populated from the "embedding_length" config key.
	hiddenSize int
	// numHeads is populated from "attention.head_count".
	numHeads int
	// numKVHeads is populated from "attention.head_count_kv".
	numKVHeads int

	// eps is populated from "attention.layer_norm_rms_epsilon" and is passed
	// to every RMSNorm Forward call.
	eps float32
	// ropeBase is populated from "rope.freq_base".
	ropeBase float32
	// ropeScale is populated from "rope.freq_scale" (default 1).
	ropeScale float32

	// ropeDim is populated from "rope.dimension_count".
	ropeDim uint32
}

// Model is the model type registered under the name "llama". It embeds
// model.Base and a byte-pair-encoding tokenizer, and holds the weights that
// are bound by their gguf tags.
type Model struct {
	model.Base
	model.BytePairEncoding

	// TokenEmbedding is applied to the input token IDs at the start of Forward.
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	// Layers are applied in order by Forward; New sizes the slice from
	// the "block_count" config key.
	Layers         []Layer       `gguf:"blk"`
	// OutputNorm is applied to the final hidden state before Output.
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	// Output produces the value returned by Forward.
	// NOTE(review): the "alt:token_embd" tag presumably makes the loader fall
	// back to the token embedding weights when no "output" tensor exists —
	// confirm against the gguf tag handling.
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	// Options carries the hyperparameters read in New; its fields are
	// promoted onto Model (e.g. m.eps, m.ropeDim).
	*Options
}

func New(c ml.Config) (model.Model, error) {
34
35
36
37
	if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
		return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
	}

Jesse Gross's avatar
Jesse Gross committed
38
	m := Model{
Michael Yang's avatar
Michael Yang committed
39
40
41
42
43
44
		BytePairEncoding: model.NewBytePairEncoding(
			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Types:  c.Uints("tokenizer.ggml.token_type"),
				Merges: c.Strings("tokenizer.ggml.merges"),
45
				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
46
				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
47
				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
48
				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
Michael Yang's avatar
Michael Yang committed
49
50
51
52
			},
		),
		Layers: make([]Layer, c.Uint("block_count")),
		Options: &Options{
53
54
55
			hiddenSize: int(c.Uint("embedding_length")),
			numHeads:   int(c.Uint("attention.head_count")),
			numKVHeads: int(c.Uint("attention.head_count_kv")),
Michael Yang's avatar
Michael Yang committed
56
57
58
59
60
			eps:        c.Float("attention.layer_norm_rms_epsilon"),
			ropeBase:   c.Float("rope.freq_base"),
			ropeScale:  c.Float("rope.freq_scale", 1),
			ropeDim:    c.Uint("rope.dimension_count"),
		},
Jesse Gross's avatar
Jesse Gross committed
61
62
63
64
65
	}

	m.Cache = kvcache.NewCausalCache(m.Shift)

	return &m, nil
Michael Yang's avatar
Michael Yang committed
66
67
68
}

type SelfAttention struct {
69
70
71
72
73
	Query       *nn.Linear `gguf:"attn_q"`
	Key         *nn.Linear `gguf:"attn_k"`
	Value       *nn.Linear `gguf:"attn_v"`
	Output      *nn.Linear `gguf:"attn_output"`
	RopeFactors ml.Tensor  `gguf:"rope_freqs.weight"`
Michael Yang's avatar
Michael Yang committed
74
75
}

Jesse Gross's avatar
Jesse Gross committed
76
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
77
78
	batchSize := hiddenState.Dim(1)
	headDim := opts.hiddenSize / opts.numHeads
Patrick Devine's avatar
Patrick Devine committed
79
	ropeType := uint32(0)
Michael Yang's avatar
Michael Yang committed
80
81
82

	q := sa.Query.Forward(ctx, hiddenState)
	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
Patrick Devine's avatar
Patrick Devine committed
83
	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
Michael Yang's avatar
Michael Yang committed
84
85
86

	k := sa.Key.Forward(ctx, hiddenState)
	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
Patrick Devine's avatar
Patrick Devine committed
87
	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
Michael Yang's avatar
Michael Yang committed
88
89
90
91

	v := sa.Value.Forward(ctx, hiddenState)
	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

92
	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
93
	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
Michael Yang's avatar
Michael Yang committed
94
95
96
97
98
	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)

	return sa.Output.Forward(ctx, kqv)
}

Jesse Gross's avatar
Jesse Gross committed
99
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
Patrick Devine's avatar
Patrick Devine committed
100
	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
Jesse Gross's avatar
Jesse Gross committed
101
102
}

Michael Yang's avatar
Michael Yang committed
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
// MLP holds the three linear projections of a layer's feed-forward block,
// bound by their gguf tags.
type MLP struct {
	// Up and Gate are both applied to the block input; Down is applied to
	// their combined result (see MLP.Forward).
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}

// Forward applies the gated feed-forward block to hiddenState:
// Down(SILU(Gate(x)) * Up(x)). opts is accepted for signature symmetry with
// the other Forward methods and is not read here.
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
	gate := mlp.Gate.Forward(ctx, hiddenState)
	activated := gate.SILU(ctx)

	up := mlp.Up.Forward(ctx, hiddenState)
	gated := activated.Mul(ctx, up)

	return mlp.Down.Forward(ctx, gated)
}

// Layer is one transformer block: a normalized self-attention step followed
// by a normalized MLP step, each wrapped in a residual connection by
// Layer.Forward.
type Layer struct {
	// AttentionNorm is applied to the layer input before SelfAttention.
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention *SelfAttention
	// MLPNorm is applied to the attention residual sum before MLP.
	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
	MLP           *MLP
}

121
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
Michael Yang's avatar
Michael Yang committed
122
123
124
125
	residual := hiddenState

	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
126
127
128
129
130
131
132
133

	// In the final layer (outputs != nil), optimize by pruning to just the token positions
	// we need logits for.
	if outputs != nil {
		hiddenState = hiddenState.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

Michael Yang's avatar
Michael Yang committed
134
135
136
137
138
139
140
141
	hiddenState = hiddenState.Add(ctx, residual)
	residual = hiddenState

	hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
	hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
	return hiddenState.Add(ctx, residual)
}

142
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
143
	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
Michael Yang's avatar
Michael Yang committed
144
145
146
147
	if err != nil {
		return nil, err
	}

148
	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
Michael Yang's avatar
Michael Yang committed
149
150
151
152
	if err != nil {
		return nil, err
	}

153
	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
154
155
156
157
	if err != nil {
		return nil, err
	}

Michael Yang's avatar
Michael Yang committed
158
159
160
	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)

	for i, layer := range m.Layers {
Jesse Gross's avatar
Jesse Gross committed
161
		m.Cache.SetLayer(i)
Michael Yang's avatar
Michael Yang committed
162

163
164
165
166
		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = outputs
		}
Michael Yang's avatar
Michael Yang committed
167

168
		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
Michael Yang's avatar
Michael Yang committed
169
170
	}

171
172
	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
	return m.Output.Forward(ctx, hiddenState), nil
Michael Yang's avatar
Michael Yang committed
173
174
175
176
177
}

// init registers New as the constructor for the "llama" model name.
func init() {
	const name = "llama"
	model.Register(name, New)
}