//go:build mlx

// Package nn provides neural network layer types.
package nn

import "github.com/ollama/ollama/x/imagegen/mlx"

// Layer is the interface for neural network layers with a Forward method.
type Layer interface {
	Forward(x *mlx.Array) *mlx.Array
}
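
// applyLayers is an illustrative helper, not part of the package API: it shows
// how the Layer interface composes by threading x through a sequence of layers
// (for example Linear or QuantizedLinear) in order.
func applyLayers(x *mlx.Array, layers ...Layer) *mlx.Array {
	for _, l := range layers {
		x = l.Forward(x)
	}
	return x
}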

// Linear applies an affine transformation: y = x @ W.T + b.
// Weight is stored as [out_features, in_features], matching the PyTorch/MLX convention.
type Linear struct {
	Weight *mlx.Array `weight:"weight"`          // [out_features, in_features]
	Bias   *mlx.Array `weight:"bias,optional"`   // [out_features] or nil
}

// NewLinear creates a linear layer.
// Weight should be [out_features, in_features].
func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {
	return &Linear{Weight: weight, Bias: bias}
}

// NewQuantizedLinear creates a quantized linear layer directly from bf16 weights.
// The weight is quantized immediately and evaluated to break lazy dependencies
// on the original bf16 array, which can then be freed.
func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
	qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
	// Eval immediately so bf16 weight can be freed
	mlx.Eval(qw, scales, qbiases)
	return &QuantizedLinear{
		Weight:    qw,
		Scales:    scales,
		QBiases:   qbiases,
		Bias:      bias,
		GroupSize: groupSize,
		Bits:      bits,
		Mode:      mode,
	}
}
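
// Usage sketch (illustrative only): loadWeight and the tensor name are
// hypothetical, and the group-size/bits/mode values are assumptions rather
// than package defaults.
//
//	w := loadWeight("ffn.up_proj.weight") // [out_features, in_features], bf16
//	up := NewQuantizedLinear(w, nil, 64, 4, "affine")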

// Forward applies the linear transformation: x @ W.T + bias
func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
	w := mlx.Transpose(l.Weight, 1, 0)
	if l.Bias != nil {
		return mlx.AddMM(l.Bias, x, w, 1.0, 1.0)
	}
	return mlx.Linear(x, w)
}

// ToQuantized converts this Linear to a QuantizedLinear. Unlike NewQuantizedLinear,
// the quantized arrays are not eagerly evaluated here.
func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear {
	qw, scales, qbiases := mlx.Quantize(l.Weight, groupSize, bits, mode)
	return &QuantizedLinear{
		Weight:    qw,
		Scales:    scales,
		QBiases:   qbiases,
		Bias:      l.Bias,
		GroupSize: groupSize,
		Bits:      bits,
		Mode:      mode,
	}
}
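
// quantizeIfLarge is an illustrative sketch, not part of the package API: it
// swaps a Linear for a QuantizedLinear only when the weight has enough rows to
// be worth quantizing. The threshold and quantization parameters are
// assumptions supplied by the caller.
func quantizeIfLarge(l *Linear, minRows int32, groupSize, bits int, mode string) Layer {
	if l.Weight.Shape()[0] < minRows {
		return l
	}
	return l.ToQuantized(groupSize, bits, mode)
}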

// QuantizedLinear applies an affine transformation using quantized weights.
// Equivalent to mlx.nn.QuantizedLinear.
type QuantizedLinear struct {
	Weight    *mlx.Array // Quantized weight data
	Scales    *mlx.Array // Scale factors for dequantization
	QBiases   *mlx.Array // Quantization biases (NOT layer bias)
	Bias      *mlx.Array // Layer bias [output_dims] or nil
	GroupSize int
	Bits      int
	Mode      string
}

// Forward applies the quantized linear transformation.
func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
	out := mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
	if ql.Bias != nil {
		out = mlx.Add(out, ql.Bias)
	}
	return out
}

// RMSNorm represents an RMS normalization layer.
type RMSNorm struct {
	Weight *mlx.Array `weight:"weight"`
	Eps    float32    // optional; used when Forward is called with eps == 0
}

// NewRMSNorm creates an RMSNorm layer (for models that do not use the weight loader).
func NewRMSNorm(weight *mlx.Array, eps float32) *RMSNorm {
	return &RMSNorm{Weight: weight, Eps: eps}
}

// Forward applies RMS normalization. If eps is 0, the stored Eps is used.
func (rn *RMSNorm) Forward(x *mlx.Array, eps float32) *mlx.Array {
	if eps == 0 {
		eps = rn.Eps
	}
	return mlx.RMSNorm(x, rn.Weight, eps)
}
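
// For reference, the standard RMSNorm formulation (which mlx.RMSNorm is assumed
// to implement) is:
//
//	y = x / sqrt(mean(x*x, axis=-1) + eps) * weight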

// Embedding represents an embedding layer.
type Embedding struct {
	Weight *mlx.Array `weight:"weight"`
}

// NewEmbedding creates an embedding layer.
func NewEmbedding(weight *mlx.Array) *Embedding {
	return &Embedding{Weight: weight}
}

// Forward looks up embeddings by indices.
func (e *Embedding) Forward(indices *mlx.Array) *mlx.Array {
	return mlx.Take(e.Weight, indices, 0)
}
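
// Shape sketch (illustrative, assuming mlx.Take follows the usual gather
// semantics along axis 0): with Weight [vocab_size, hidden_dim] and integer
// indices of shape [B, S], Forward returns embeddings of shape [B, S, hidden_dim].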

// RepeatKV repeats K/V tensors for grouped-query attention (GQA).
// x: [B, num_kv_heads, S, head_dim] -> [B, num_heads, S, head_dim],
// where num_heads = num_kv_heads * repeatFactor.
func RepeatKV(x *mlx.Array, repeatFactor int32) *mlx.Array {
	if repeatFactor == 1 {
		return x
	}
	shape := x.Shape()
	// [B, num_kv_heads, S, head_dim] -> [B, num_kv_heads, 1, S, head_dim]
	x = mlx.ExpandDims(x, 2)
	// Repeat along the new axis
	reps := []int32{1, 1, repeatFactor, 1, 1}
	x = mlx.Tile(x, reps)
	// Reshape: [B, num_kv_heads, repeat, S, head_dim] -> [B, num_kv_heads * repeat, S, head_dim]
	return mlx.Reshape(x, shape[0], shape[1]*repeatFactor, shape[2], shape[3])
}
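
// Usage sketch (illustrative shapes): with 8 query heads and 2 KV heads the
// repeat factor is 4, so keys and values are expanded to match the query heads:
//
//	k = RepeatKV(k, 4) // [B, 2, S, D] -> [B, 8, S, D]
//	v = RepeatKV(v, 4) // [B, 2, S, D] -> [B, 8, S, D]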

// ApplyCausalMask applies a causal (lower-triangular) mask to attention scores.
func ApplyCausalMask(scores *mlx.Array) *mlx.Array {
	// scores: [B, num_heads, S, S]
	shape := scores.Shape()
	seqLen := shape[2]

	// Create causal mask: 1 for positions to keep, 0 for positions to mask
	mask := mlx.Tri(seqLen, seqLen, 0)

	// Where mask is 0, set score to -inf
	negInf := mlx.NewScalarArray(float32(-1e9))

	// Broadcast mask to match scores shape
	mask = mlx.ExpandDims(mlx.ExpandDims(mask, 0), 0) // [1, 1, S, S]

	// Where the mask is nonzero, keep the score; otherwise use -inf
	return mlx.Where(mask, scores, negInf)
}

// ApplyCausalMaskWithOffset applies a causal mask for cached (KV cache) attention.
// scores: [B, num_heads, queryLen, keyLen] where keyLen = cacheLen + queryLen
// offset: the starting position of the new queries (i.e., the cache length)
func ApplyCausalMaskWithOffset(scores *mlx.Array, offset int32) *mlx.Array {
	if offset == 0 {
		return ApplyCausalMask(scores)
	}

	shape := scores.Shape()
	queryLen := shape[2]
	keyLen := shape[3]

	// For cached attention, new queries can attend to all cached keys plus
	// new keys up to and including their position.
	mask := mlx.Tri(queryLen, keyLen, int(offset))

	negInf := mlx.NewScalarArray(float32(-1e9))
	mask = mlx.ExpandDims(mlx.ExpandDims(mask, 0), 0) // [1, 1, queryLen, keyLen]

	return mlx.Where(mask, scores, negInf)
}
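
// Usage sketch (assumed calling pattern, not taken from a caller in this file):
// during prefill the cache is empty, so offset is 0 and the plain causal mask
// applies; during incremental decoding, offset is the number of positions
// already held in the KV cache:
//
//	scores = ApplyCausalMaskWithOffset(scores, cacheLen) // keyLen == cacheLen + queryLen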

// LayerNorm represents a standard layer normalization layer (with bias).
type LayerNorm struct {
	Weight *mlx.Array `weight:"weight"`
	Bias   *mlx.Array `weight:"bias"`
	Eps    float32
}

// Forward applies layer normalization: (x - mean) / sqrt(var + eps) * weight + bias
func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
	eps := ln.Eps
	if eps == 0 {
		eps = 1e-5
	}
	// Compute mean and variance along last dimension
	mean := mlx.Mean(x, -1, true)
	centered := mlx.Sub(x, mean)
	variance := mlx.Mean(mlx.Mul(centered, centered), -1, true)
	normalized := mlx.Mul(centered, mlx.RSqrt(mlx.AddScalar(variance, eps)))

	// Scale and shift
	out := mlx.Mul(normalized, ln.Weight)
	if ln.Bias != nil {
		out = mlx.Add(out, ln.Bias)
	}
	return out
}
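
// normedProjection is an illustrative sketch, not part of the package API: it
// applies LayerNorm followed by a Linear projection, the pre-norm pattern used
// in typical transformer blocks. Callers and shapes are assumptions.
func normedProjection(x *mlx.Array, norm *LayerNorm, proj *Linear) *mlx.Array {
	return proj.Forward(norm.Forward(x))
}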