//go:build mlx

package zimage

import (
	"encoding/json"
	"fmt"
	"math"
	"os"
	"path/filepath"

	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
	"github.com/ollama/ollama/x/imagegen/safetensors"
	"github.com/ollama/ollama/x/imagegen/tokenizer"
)

// Qwen3Config holds Qwen3 text encoder configuration
type Qwen3Config struct {
	HiddenSize        int32   `json:"hidden_size"`
	NumHiddenLayers   int32   `json:"num_hidden_layers"`
	IntermediateSize  int32   `json:"intermediate_size"`
	NumAttentionHeads int32   `json:"num_attention_heads"`
	NumKeyValueHeads  int32   `json:"num_key_value_heads"`
	VocabSize         int32   `json:"vocab_size"`
	RMSNormEps        float32 `json:"rms_norm_eps"`
	RopeTheta         float32 `json:"rope_theta"`
	HeadDim           int32   `json:"head_dim"`
}
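
// An illustrative config.json for this struct. The values below are a sketch
// only (on the order of a small Qwen3 variant); the real file ships with the
// model weights:
//
//	{
//	  "hidden_size": 1024,
//	  "num_hidden_layers": 28,
//	  "intermediate_size": 3072,
//	  "num_attention_heads": 16,
//	  "num_key_value_heads": 8,
//	  "vocab_size": 151936,
//	  "rms_norm_eps": 1e-06,
//	  "rope_theta": 1000000,
//	  "head_dim": 128
//	}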

// loadQwen3Config loads text encoder config from a JSON file
func loadQwen3Config(path string) (*Qwen3Config, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("read config: %w", err)
	}
	var cfg Qwen3Config
	if err := json.Unmarshal(data, &cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}
	return &cfg, nil
}

// Qwen3Attention implements Qwen3 attention with QK norms
type Qwen3Attention struct {
	QProj *nn.Linear  `weight:"q_proj"`
	KProj *nn.Linear  `weight:"k_proj"`
	VProj *nn.Linear  `weight:"v_proj"`
	OProj *nn.Linear  `weight:"o_proj"`
	QNorm *nn.RMSNorm `weight:"q_norm"`
	KNorm *nn.RMSNorm `weight:"k_norm"`
	// Computed fields, populated in Load rather than from weights
	NHeads    int32
	NKVHeads  int32
	HeadDim   int32
	Scale     float32
	RopeTheta float32
}

// applyRoPEQwen3 applies rotary position embeddings to x of shape
// (B, L, H, D), using the split-half rotation the Qwen3 text encoder expects.
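//
// Concretely, with half = D/2, token position m, and i in [0, half), the code
// below computes:
//
//	freq_i = theta^(-i/half)
//	out[..., :half] = x1*cos(m*freq_i) - x2*sin(m*freq_i)
//	out[..., half:] = x1*sin(m*freq_i) + x2*cos(m*freq_i)
//
// where x1 and x2 are the first and second halves of the last axis.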
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]
	H := shape[2]
	D := shape[3]
	half := D / 2

	freqsArr := make([]float32, half)
	logTheta := float32(math.Log(float64(theta)))
	for i := int32(0); i < half; i++ {
		freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
	}
	freqs := mlx.NewArray(freqsArr, []int32{half})

	posArr := make([]float32, seqLen)
	for i := int32(0); i < seqLen; i++ {
		posArr[i] = float32(i)
	}
	pos := mlx.NewArray(posArr, []int32{seqLen})

	posExpanded := mlx.Reshape(pos, seqLen, 1)
	freqsExpanded := mlx.Reshape(freqs, 1, half)
	args := mlx.Mul(posExpanded, freqsExpanded)

	cosVals := mlx.Cos(args)
	sinVals := mlx.Sin(args)
	cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
	sinVals = mlx.Reshape(sinVals, seqLen, 1, half)

	x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
	x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})

	part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
	part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))

	return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
}

// Forward computes attention with causal masking
func (attn *Qwen3Attention) Forward(x *mlx.Array) *mlx.Array {
	shape := x.Shape()
	B := shape[0]
	L := shape[1]

	q := attn.QProj.Forward(x)
	k := attn.KProj.Forward(x)
	v := attn.VProj.Forward(x)

	q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
	k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
	v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)

	// Per-head Q/K RMSNorm; eps is hardcoded to 1e-6 here (Qwen3's usual
	// rms_norm_eps) rather than read from the config.
	q = attn.QNorm.Forward(q, 1e-6)
	k = attn.KNorm.Forward(k, 1e-6)

	q = applyRoPEQwen3(q, L, attn.RopeTheta)
	k = applyRoPEQwen3(k, L, attn.RopeTheta)

	q = mlx.Transpose(q, 0, 2, 1, 3)
	k = mlx.Transpose(k, 0, 2, 1, 3)
	v = mlx.Transpose(v, 0, 2, 1, 3)
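	// q, k, v are now laid out as (B, heads, L, headDim).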

	if attn.NKVHeads < attn.NHeads {
		repeats := attn.NHeads / attn.NKVHeads
		k = repeatKV(k, repeats)
		v = repeatKV(v, repeats)
	}

	out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, true)

	out = mlx.Transpose(out, 0, 2, 1, 3)
	out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)

	out = attn.OProj.Forward(out)

	return out
}

// repeatKV repeats key/value heads for GQA
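// It takes x of shape (B, nKVHeads, L, headDim) and returns
// (B, nKVHeads*repeats, L, headDim): a repeat axis is inserted after the head
// axis, tiled, then folded back into the head axis.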
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
	if repeats == 1 {
		return x
	}
	shape := x.Shape()
	x = mlx.ExpandDims(x, 2)
	x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
	return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}

// Qwen3MLP implements Qwen3 SwiGLU MLP
type Qwen3MLP struct {
	GateProj *nn.Linear `weight:"gate_proj"`
	UpProj   *nn.Linear `weight:"up_proj"`
	DownProj *nn.Linear `weight:"down_proj"`
}

// Forward applies the MLP
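// Concretely it computes down_proj(silu(gate_proj(x)) * up_proj(x)).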
func (m *Qwen3MLP) Forward(x *mlx.Array) *mlx.Array {
	gate := m.GateProj.Forward(x)
	gate = mlx.SiLU(gate)
	up := m.UpProj.Forward(x)
	h := mlx.Mul(gate, up)
	return m.DownProj.Forward(h)
}

// Qwen3Block represents a single Qwen3 transformer block
type Qwen3Block struct {
	Attention         *Qwen3Attention `weight:"self_attn"`
	MLP               *Qwen3MLP       `weight:"mlp"`
	InputLayerNorm    *nn.RMSNorm     `weight:"input_layernorm"`
	PostAttnLayerNorm *nn.RMSNorm     `weight:"post_attention_layernorm"`
}

// Forward applies the block: pre-norm attention and MLP, each with a
// residual connection.
func (qb *Qwen3Block) Forward(x *mlx.Array, eps float32) *mlx.Array {
	h := qb.InputLayerNorm.Forward(x, eps)
	attnOut := qb.Attention.Forward(h)
	x = mlx.Add(x, attnOut)

	h = qb.PostAttnLayerNorm.Forward(x, eps)
	mlpOut := qb.MLP.Forward(h)
	x = mlx.Add(x, mlpOut)

	return x
}

// Qwen3TextEncoder is the full Qwen3 encoder for Z-Image
type Qwen3TextEncoder struct {
	EmbedTokens *nn.Embedding   `weight:"model.embed_tokens"`
	Layers      []*Qwen3Block   `weight:"model.layers"`
	FinalNorm   *nn.RMSNorm     `weight:"model.norm"`
	*Qwen3Config
}
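
// With the tags above, weights are expected to resolve to safetensors keys
// like "model.layers.0.self_attn.q_proj.weight" (an assumption about how
// LoadModule joins nested tags, shown here for orientation).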

// Load loads the Qwen3 text encoder from a directory containing config.json
// and safetensors weights.
func (m *Qwen3TextEncoder) Load(path string) error {
	fmt.Println("Loading Qwen3 text encoder...")

	// Load config
	cfg, err := loadQwen3Config(filepath.Join(path, "config.json"))
	if err != nil {
		return fmt.Errorf("config: %w", err)
	}
	m.Qwen3Config = cfg

	// Pre-allocate layers slice
	m.Layers = make([]*Qwen3Block, cfg.NumHiddenLayers)

	// Load weights
	weights, err := safetensors.LoadModelWeights(path)
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}

	fmt.Print("  Loading weights via struct tags... ")
	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return fmt.Errorf("load module: %w", err)
	}
	fmt.Println("✓")

	// Initialize computed fields
	m.FinalNorm.Eps = cfg.RMSNormEps
	for _, block := range m.Layers {
		// Attention
		block.Attention.NHeads = cfg.NumAttentionHeads
		block.Attention.NKVHeads = cfg.NumKeyValueHeads
		block.Attention.HeadDim = cfg.HeadDim
		block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
		block.Attention.RopeTheta = cfg.RopeTheta
		block.Attention.QNorm.Eps = cfg.RMSNormEps
		block.Attention.KNorm.Eps = cfg.RMSNormEps
		// Block norms
		block.InputLayerNorm.Eps = cfg.RMSNormEps
		block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
	}

	weights.ReleaseAll()
	return nil
}

// Forward runs token IDs through the embedding, each transformer block, and
// the final RMS norm, returning the last hidden states.
func (te *Qwen3TextEncoder) Forward(tokens *mlx.Array) *mlx.Array {
	h := te.EmbedTokens.Forward(tokens)
	eps := te.RMSNormEps

	for _, layer := range te.Layers {
		h = layer.Forward(h, eps)
	}

	// Apply final RMS norm
	h = te.FinalNorm.Forward(h, eps)

	return h
}

// ApplyChatTemplate wraps the prompt in the Qwen3 chat template.
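// For example, ApplyChatTemplate("a cat") returns (with a trailing newline):
//
//	<|im_start|>user
//	a cat<|im_end|>
//	<|im_start|>assistant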
func ApplyChatTemplate(prompt string) string {
	return "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"
}

// EncodePrompt tokenizes a chat-formatted prompt, truncates or pads it to
// maxLen, and returns the encoder hidden states plus a (1, maxLen) mask with
// 1s over real tokens and 0s over padding.
func (te *Qwen3TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int) (*mlx.Array, *mlx.Array) {
	formattedPrompt := ApplyChatTemplate(prompt)

	tokens := tok.Encode(formattedPrompt, false)

	if len(tokens) > maxLen {
		tokens = tokens[:maxLen]
	}

	maskData := make([]float32, maxLen)
	for i := 0; i < len(tokens); i++ {
		maskData[i] = 1.0
	}

	// Get PAD token (different from EOS for Qwen3)
	padToken := tok.PAD()
	if padToken < 0 {
		padToken = tok.EOS() // fallback
	}

	paddedTokens := make([]int32, maxLen)
	copy(paddedTokens, tokens)
	for i := len(tokens); i < maxLen; i++ {
		paddedTokens[i] = padToken
	}

	tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
	maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})

	embeddings := te.Forward(tokensArr)

	return embeddings, maskArr
}
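
// Example usage (a sketch: the weights path is illustrative and tok is
// assumed to be an already-loaded *tokenizer.Tokenizer):
//
//	var enc Qwen3TextEncoder
//	if err := enc.Load("path/to/text_encoder"); err != nil {
//		log.Fatal(err)
//	}
//	emb, mask := enc.EncodePrompt(tok, "a watercolor fox", 512)
//	// emb: (1, 512, hidden_size) hidden states; mask: (1, 512) 1/0 mask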