transformer.go 21.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
//go:build mlx

// Package zimage implements the Z-Image diffusion transformer model.
package zimage

import (
	"encoding/json"
	"fmt"
	"math"
	"os"
	"path/filepath"

	"github.com/ollama/ollama/x/imagegen/cache"
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
	"github.com/ollama/ollama/x/imagegen/safetensors"
)

// TransformerConfig holds the Z-Image transformer hyperparameters decoded from
// the model's config.json.
//
// Every field with a JSON name is read straight from the file. PatchSize is
// excluded from JSON decoding (tag "-"); loadTransformerConfig fills it from
// the first entry of AllPatchSize and leaves it at 0 when that array is empty.
type TransformerConfig struct {
	Dim            int32   `json:"dim"`
	NHeads         int32   `json:"n_heads"`
	NKVHeads       int32   `json:"n_kv_heads"`
	NLayers        int32   `json:"n_layers"`
	NRefinerLayers int32   `json:"n_refiner_layers"`
	InChannels     int32   `json:"in_channels"`
	PatchSize      int32   `json:"-"` // Derived from AllPatchSize[0]; never read from JSON
	CapFeatDim     int32   `json:"cap_feat_dim"`
	NormEps        float32 `json:"norm_eps"`
	RopeTheta      float32 `json:"rope_theta"`
	TScale         float32 `json:"t_scale"`
	QKNorm         bool    `json:"qk_norm"`
	AxesDims       []int32 `json:"axes_dims"`
	AxesLens       []int32 `json:"axes_lens"`
	AllPatchSize   []int32 `json:"all_patch_size"` // JSON array; its first element becomes PatchSize
}

// TimestepEmbedder turns timesteps into embedding vectors: a sinusoidal
// encoding of width FreqEmbedSize is fed through Linear1 -> SiLU -> Linear2
// (see Forward).
//
// FreqEmbedSize carries no weight tag; Transformer.Load assigns it (256) after
// the weights have been loaded.
type TimestepEmbedder struct {
	Linear1       *nn.Linear `weight:"mlp.0"`
	Linear2       *nn.Linear `weight:"mlp.2"`
	FreqEmbedSize int32      // width of the sinusoidal encoding; set to 256 by Transformer.Load
}

// Forward embeds the timesteps t (indexed as [B]).
//
// It builds a sinusoidal encoding [cos(t*f), sin(t*f)] of width FreqEmbedSize,
// where f[k] = exp(-ln(10000) * k / (FreqEmbedSize/2)), and then applies the
// MLP Linear1 -> SiLU -> Linear2. The width of the returned embedding is
// whatever Linear2 produces.
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
	halfDim := te.FreqEmbedSize / 2

	// Geometric frequency ladder, computed in float64 and stored as float32.
	table := make([]float32, halfDim)
	for k := range table {
		table[k] = float32(math.Exp(-math.Log(10000.0) * float64(k) / float64(halfDim)))
	}
	freqRow := mlx.NewArray(table, []int32{1, halfDim})

	// Outer product via broadcasting: [B, 1] * [1, halfDim] -> [B, halfDim].
	phase := mlx.Mul(mlx.ExpandDims(t, 1), freqRow)

	// Cosines first, then sines, joined along the feature axis.
	sinusoid := mlx.Concatenate([]*mlx.Array{mlx.Cos(phase), mlx.Sin(phase)}, 1)

	hidden := mlx.SiLU(te.Linear1.Forward(sinusoid))
	return te.Linear2.Forward(hidden)
}

// XEmbedder projects patchified image latents with a single linear layer.
// NOTE(review): the weight key "2-1" presumably identifies the patch-size /
// frame-size variant of the embedder — confirm against the checkpoint layout.
type XEmbedder struct {
	Linear *nn.Linear `weight:"2-1"`
}

// Forward applies the embedder's linear projection to the patch sequence x
// and returns the result. The feature width of x must match what Linear
// expects (presumably in_channels * patch_size^2 — confirm against the caller).
func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
	embedded := xe.Linear.Forward(x)
	return embedded
}

// CapEmbedder normalises caption features with an RMSNorm and projects them
// with a linear layer (see Forward).
//
// PadToken has no weight tag, so tag-driven loading does not populate it.
// NOTE(review): nothing in this file assigns or reads PadToken;
// Transformer.CapPadToken appears to be the root-level pad token — confirm
// whether this field is still needed.
type CapEmbedder struct {
	Norm     *nn.RMSNorm `weight:"0"`
	Linear   *nn.Linear  `weight:"1"`
	PadToken *mlx.Array  // not tag-loaded; see the note above
}

// Forward embeds caption features: RMSNorm with a fixed epsilon of 1e-6,
// followed by the linear projection. The input is expected as
// [B, L, cap_feat_dim]; the output keeps the leading axes and takes its
// feature width from Linear.
func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array {
	normed := ce.Norm.Forward(capFeats, 1e-6)
	projected := ce.Linear.Forward(normed)
	return projected
}

// FeedForward is a SwiGLU feed-forward network: W2(silu(W1(x)) * W3(x)).
//
// OutDim is not loaded from weights; initTransformerBlock sets it to the first
// dimension of W2's weight, and Forward uses it to restore the output shape.
type FeedForward struct {
	W1     *nn.Linear `weight:"w1"` // gate projection
	W2     *nn.Linear `weight:"w2"` // down projection
	W3     *nn.Linear `weight:"w3"` // up projection
	OutDim int32      // output feature width, taken from W2.Weight.Shape()[0]
}

// Forward evaluates the SwiGLU network on x, indexed as [B, L, D]:
// W2(silu(W1(x)) * W3(x)). The projections run on a flattened [B*L, D] view
// and the result is reshaped to [B, L, OutDim].
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
	dims := x.Shape()
	batch, seqLen, width := dims[0], dims[1], dims[2]

	// Collapse batch and sequence so each projection is one 2-D matmul.
	flat := mlx.Reshape(x, batch*seqLen, width)

	// Gate branch (SiLU-activated) multiplied element-wise by the up branch.
	gated := mlx.Mul(mlx.SiLU(ff.W1.Forward(flat)), ff.W3.Forward(flat))

	projected := ff.W2.Forward(gated)
	return mlx.Reshape(projected, batch, seqLen, ff.OutDim)
}

// Attention is a multi-head self-attention layer whose queries and keys are
// RMS-normalised per head (NormQ / NormK) before attention is computed.
//
// NHeads, HeadDim, Dim and Scale are not loaded from weights;
// initTransformerBlock derives them from the TransformerConfig
// (HeadDim = Dim / NHeads, Scale = 1 / sqrt(HeadDim)).
type Attention struct {
	ToQ   *nn.Linear `weight:"to_q"`
	ToK   *nn.Linear `weight:"to_k"`
	ToV   *nn.Linear `weight:"to_v"`
	ToOut *nn.Linear `weight:"to_out.0"`
	NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
	NormK *mlx.Array `weight:"norm_k.weight"`
	// Computed fields (filled in by initTransformerBlock)
	NHeads  int32
	HeadDim int32
	Dim     int32
	Scale   float32
}

// Forward runs self-attention over x, indexed as [B, L, D].
//
// Q, K and V are projected from a flattened view of x and split into
// [B, L, NHeads, HeadDim]. Q and K are RMS-normalised with NormQ / NormK
// (eps 1e-5) and, when both cos and sin are non-nil, rotated by applyRoPE3D.
// Attention is delegated to mlx.ScaledDotProductAttention with attn.Scale;
// the result is projected by ToOut and returned as [B, L, attn.Dim].
func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
	dims := x.Shape()
	batch, seqLen, width := dims[0], dims[1], dims[2]

	// The linear layers operate on a [B*L, D] view.
	flat := mlx.Reshape(x, batch*seqLen, width)
	q := attn.ToQ.Forward(flat)
	k := attn.ToK.Forward(flat)
	v := attn.ToV.Forward(flat)

	// splitHeads views a projection as [B, L, nheads, head_dim].
	splitHeads := func(a *mlx.Array) *mlx.Array {
		return mlx.Reshape(a, batch, seqLen, attn.NHeads, attn.HeadDim)
	}
	q, k, v = splitHeads(q), splitHeads(k), splitHeads(v)

	// Per-head normalisation of queries and keys only; values are untouched.
	q = mlx.RMSNorm(q, attn.NormQ, 1e-5)
	k = mlx.RMSNorm(k, attn.NormK, 1e-5)

	// Rotary embeddings are optional and applied to Q and K alike.
	if cos != nil && sin != nil {
		q = applyRoPE3D(q, cos, sin)
		k = applyRoPE3D(k, cos, sin)
	}

	// swapAxes exchanges the sequence and head axes:
	// [B, L, nheads, head_dim] <-> [B, nheads, L, head_dim].
	swapAxes := func(a *mlx.Array) *mlx.Array {
		return mlx.Transpose(a, 0, 2, 1, 3)
	}

	attended := mlx.ScaledDotProductAttention(swapAxes(q), swapAxes(k), swapAxes(v), attn.Scale, false)

	// Back to [B, L, nheads, head_dim], merge the heads, then project out.
	merged := mlx.Reshape(swapAxes(attended), batch*seqLen, attn.Dim)
	projected := attn.ToOut.Forward(merged)

	return mlx.Reshape(projected, batch, seqLen, attn.Dim)
}

// applyRoPE3D rotates adjacent feature pairs of x by the given angles.
//
// x is indexed as [B, L, nheads, head_dim]. cos and sin must broadcast against
// [B, L, nheads, head_dim/2] (the caller passes [B, L, 1, head_dim/2]).
// For every pair (x[2i], x[2i+1]) the output holds
//
//	out[2i]   = x[2i]*cos[i] - x[2i+1]*sin[i]
//	out[2i+1] = x[2i]*sin[i] + x[2i+1]*cos[i]
//
// so the result has the same shape and interleaved layout as x.
func applyRoPE3D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
	dims := x.Shape()
	batch, seqLen, heads, width := dims[0], dims[1], dims[2], dims[3]
	pairs := width / 2

	// Gather indices for the first and second member of every pair.
	firstIdx := make([]int32, pairs)
	secondIdx := make([]int32, pairs)
	for p := range firstIdx {
		firstIdx[p] = int32(2 * p)
		secondIdx[p] = int32(2*p + 1)
	}
	firstSel := mlx.NewArrayInt32(firstIdx, []int32{pairs})
	secondSel := mlx.NewArrayInt32(secondIdx, []int32{pairs})

	first := mlx.Take(x, firstSel, 3)   // [B, L, nheads, pairs]
	second := mlx.Take(x, secondSel, 3) // [B, L, nheads, pairs]

	// 2-D rotation of each (first, second) pair.
	rotFirst := mlx.Sub(mlx.Mul(first, cos), mlx.Mul(second, sin))
	rotSecond := mlx.Add(mlx.Mul(first, sin), mlx.Mul(second, cos))

	// Re-interleave: give each half a trailing unit axis, join along it to get
	// [B, L, nheads, pairs, 2], then flatten the last two axes.
	interleaved := mlx.Concatenate([]*mlx.Array{
		mlx.ExpandDims(rotFirst, 4),
		mlx.ExpandDims(rotSecond, 4),
	}, 4)
	return mlx.Reshape(interleaved, batch, seqLen, heads, width)
}

// TransformerBlock is one attention + feed-forward block. Each sub-layer is
// wrapped by a pre-norm (…Norm1) and a post-norm (…Norm2) RMSNorm.
//
// AdaLN is optional: when its weight is present (and Forward receives a
// conditioning tensor) the block scales and gates both sub-layers with values
// derived from that conditioning; otherwise it runs as a plain residual block.
//
// HasModulation and Dim are not loaded from weights; initTransformerBlock sets
// them (HasModulation = AdaLN != nil, Dim = cfg.Dim).
type TransformerBlock struct {
	Attention      *Attention   `weight:"attention"`
	FeedForward    *FeedForward `weight:"feed_forward"`
	AttentionNorm1 *nn.RMSNorm  `weight:"attention_norm1"`
	AttentionNorm2 *nn.RMSNorm  `weight:"attention_norm2"`
	FFNNorm1       *nn.RMSNorm  `weight:"ffn_norm1"`
	FFNNorm2       *nn.RMSNorm  `weight:"ffn_norm2"`
	AdaLN          *nn.Linear   `weight:"adaLN_modulation.0,optional"` // only if modulation
	// Computed fields (filled in by initTransformerBlock)
	HasModulation bool
	Dim           int32
}

// Forward applies the block to x and returns the updated sequence.
//
// With modulation (the block has an AdaLN layer and adaln is non-nil), AdaLN
// maps adaln to four equal-width chunks per batch row, in order
// scale_msa, gate_msa, scale_mlp, gate_mlp, and each sub-layer computes
//
//	x += tanh(gate) * Norm2(sublayer(Norm1(x) * (1 + scale)))
//
// Without modulation each sub-layer is a plain residual:
//
//	x += Norm2(sublayer(Norm1(x)))
//
// cos and sin are forwarded to the attention layer; eps is the epsilon passed
// to every RMSNorm.
func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *mlx.Array, eps float32) *mlx.Array {
	if tb.AdaLN == nil || adaln == nil {
		// Unmodulated path (used by the context refiner).
		attnOut := tb.Attention.Forward(tb.AttentionNorm1.Forward(x, eps), cos, sin)
		x = mlx.Add(x, tb.AttentionNorm2.Forward(attnOut, eps))

		ffnOut := tb.FeedForward.Forward(tb.FFNNorm1.Forward(x, eps))
		return mlx.Add(x, tb.FFNNorm2.Forward(ffnOut, eps))
	}

	// One linear layer produces all four modulation vectors side by side.
	modulation := tb.AdaLN.Forward(adaln)
	modShape := modulation.Shape()
	rows := modShape[0]
	quarter := modShape[1] / 4

	// chunk returns the k-th quarter of the columns, shaped [B, 1, quarter]
	// so it broadcasts over the sequence axis.
	chunk := func(k int32) *mlx.Array {
		cols := mlx.Slice(modulation, []int32{0, quarter * k}, []int32{rows, quarter * (k + 1)})
		return mlx.ExpandDims(cols, 1)
	}
	scaleMSA := chunk(0)
	gateMSA := chunk(1)
	scaleMLP := chunk(2)
	gateMLP := chunk(3)

	// Attention sub-layer.
	attnIn := mlx.Mul(tb.AttentionNorm1.Forward(x, eps), mlx.AddScalar(scaleMSA, 1.0))
	attnOut := tb.AttentionNorm2.Forward(tb.Attention.Forward(attnIn, cos, sin), eps)
	x = mlx.Add(x, mlx.Mul(mlx.Tanh(gateMSA), attnOut))

	// Feed-forward sub-layer.
	ffnIn := mlx.Mul(tb.FFNNorm1.Forward(x, eps), mlx.AddScalar(scaleMLP, 1.0))
	ffnOut := tb.FFNNorm2.Forward(tb.FeedForward.Forward(ffnIn), eps)
	return mlx.Add(x, mlx.Mul(mlx.Tanh(gateMLP), ffnOut))
}

// FinalLayer is the output head: a conditioning-dependent scale applied to a
// parameter-free layer norm, followed by a linear projection (see Forward).
//
// OutDim is not loaded from weights; Transformer.Load sets it to the first
// dimension of Output's weight.
type FinalLayer struct {
	AdaLN  *nn.Linear `weight:"adaLN_modulation.1"` // conditioning -> per-feature scale
	Output *nn.Linear `weight:"linear"`             // features -> output channels
	OutDim int32      // output feature width, taken from Output.Weight.Shape()[0]
}

// Forward produces the output patches from x ([B, L, D]) and the conditioning
// c ([B, ...]):
//
//	scale = AdaLN(silu(c))                      // broadcast over the sequence axis
//	y     = layerNormNoAffine(x, 1e-6) * (1 + scale)
//	out   = Output(y)                           // returned as [B, L, OutDim]
func (fl *FinalLayer) Forward(x *mlx.Array, c *mlx.Array) *mlx.Array {
	// Per-sample scale, given a unit sequence axis for broadcasting.
	modScale := mlx.ExpandDims(fl.AdaLN.Forward(mlx.SiLU(c)), 1)

	normed := layerNormNoAffine(x, 1e-6)
	scaled := mlx.Mul(normed, mlx.AddScalar(modScale, 1.0))

	// The projection runs on a flattened [B*L, D] view.
	dims := scaled.Shape()
	batch, seqLen, width := dims[0], dims[1], dims[2]
	projected := fl.Output.Forward(mlx.Reshape(scaled, batch*seqLen, width))

	return mlx.Reshape(projected, batch, seqLen, fl.OutDim)
}

// layerNormNoAffine normalises x over its last axis without any learnable
// scale or bias: (x - mean) / sqrt(var + eps), where var is the mean of the
// squared deviations (population variance).
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
	axis := x.Ndim() - 1

	centered := mlx.Sub(x, mlx.Mean(x, axis, true))
	meanSq := mlx.Mean(mlx.Square(centered), axis, true)
	stddev := mlx.Sqrt(mlx.AddScalar(meanSq, eps))

	return mlx.Div(centered, stddev)
}

// Transformer is the full Z-Image DiT model.
//
// Data flow (see Forward): image patches go through XEmbed and the
// NoiseRefiners, caption features through CapEmbed and the ContextRefiners;
// both sequences are concatenated, processed by Layers, and the image part is
// passed to FinalLayer. TEmbed supplies the timestep conditioning.
//
// The tagged fields are populated by safetensors.LoadModule in Load. The
// embedded *TransformerConfig is set by Load from config.json.
type Transformer struct {
	TEmbed          *TimestepEmbedder   `weight:"t_embedder"`
	XEmbed          *XEmbedder          `weight:"all_x_embedder"`
	CapEmbed        *CapEmbedder        `weight:"cap_embedder"`
	NoiseRefiners   []*TransformerBlock `weight:"noise_refiner"`
	ContextRefiners []*TransformerBlock `weight:"context_refiner"`
	Layers          []*TransformerBlock `weight:"layers"`
	FinalLayer      *FinalLayer         `weight:"all_final_layer.2-1"`
	XPadToken       *mlx.Array          `weight:"x_pad_token"`
	CapPadToken     *mlx.Array          `weight:"cap_pad_token"`
	*TransformerConfig
}

// Load reads config.json and the model weights found under path and populates
// the transformer in place.
//
// Steps: parse and sanity-check the config, size the block slices from it,
// load all weights as bfloat16, bind them to the struct fields through their
// `weight` tags, and finally fill in the fields that are derived rather than
// loaded (embedding width, output widths, attention geometry, norm epsilons).
// Progress is printed to stdout.
//
// The weight container is released when Load returns, on success and on
// failure alike. On error the receiver may be partially populated and should
// not be used.
func (m *Transformer) Load(path string) error {
	fmt.Println("Loading Z-Image transformer...")

	// Load config
	cfg, err := loadTransformerConfig(filepath.Join(path, "config.json"))
	if err != nil {
		return fmt.Errorf("config: %w", err)
	}
	// The slices below are sized from the layer counts and initTransformerBlock
	// divides Dim by NHeads; reject values that would panic there so a broken
	// config surfaces as an error.
	if cfg.NHeads <= 0 {
		return fmt.Errorf("config: n_heads must be positive, got %d", cfg.NHeads)
	}
	if cfg.NLayers < 0 || cfg.NRefinerLayers < 0 {
		return fmt.Errorf("config: layer counts must be non-negative, got n_layers=%d n_refiner_layers=%d",
			cfg.NLayers, cfg.NRefinerLayers)
	}
	m.TransformerConfig = cfg

	// Pre-allocate slices so the tag-driven loader knows how many blocks to fill
	m.NoiseRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
	m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
	m.Layers = make([]*TransformerBlock, cfg.NLayers)

	// Load weights
	weights, err := safetensors.LoadModelWeights(path)
	if err != nil {
		return fmt.Errorf("weights: %w", err)
	}
	// Release the container on every exit path, not only on success.
	defer weights.ReleaseAll()

	fmt.Print("  Loading weights as bf16... ")
	if err := weights.Load(mlx.DtypeBFloat16); err != nil {
		return fmt.Errorf("load weights: %w", err)
	}
	fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))

	fmt.Print("  Loading weights via struct tags... ")
	if err := safetensors.LoadModule(m, weights, ""); err != nil {
		return fmt.Errorf("load module: %w", err)
	}
	fmt.Println("✓")

	// Initialize computed fields
	m.TEmbed.FreqEmbedSize = 256
	m.FinalLayer.OutDim = m.FinalLayer.Output.Weight.Shape()[0]
	m.CapEmbed.Norm.Eps = 1e-6

	for _, block := range m.NoiseRefiners {
		initTransformerBlock(block, cfg)
	}
	for _, block := range m.ContextRefiners {
		initTransformerBlock(block, cfg)
	}
	for _, block := range m.Layers {
		initTransformerBlock(block, cfg)
	}

	return nil
}

// loadTransformerConfig reads the JSON file at path and decodes it into a
// TransformerConfig. PatchSize, which is not a JSON field, is taken from the
// first element of all_patch_size when that array is non-empty and is left at
// zero otherwise. Read and parse failures are returned wrapped.
func loadTransformerConfig(path string) (*TransformerConfig, error) {
	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return nil, fmt.Errorf("read config: %w", readErr)
	}

	cfg := new(TransformerConfig)
	if parseErr := json.Unmarshal(raw, cfg); parseErr != nil {
		return nil, fmt.Errorf("parse config: %w", parseErr)
	}

	// Only the first advertised patch size is used.
	if sizes := cfg.AllPatchSize; len(sizes) != 0 {
		cfg.PatchSize = sizes[0]
	}
	return cfg, nil
}

// initTransformerBlock fills in the fields of block that are derived from the
// config or from loaded weights rather than loaded directly:
//
//   - block.Dim and block.HasModulation (true when an AdaLN layer was loaded)
//   - attention geometry: NHeads, HeadDim = Dim/NHeads, Dim, Scale = 1/sqrt(HeadDim)
//   - the feed-forward output width, read from W2's weight shape
//   - Eps on all four RMSNorm layers, from cfg.NormEps
//
// cfg.NHeads must be non-zero, and the block's sub-modules must already be loaded.
func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
	block.Dim, block.HasModulation = cfg.Dim, block.AdaLN != nil

	// Attention geometry.
	a := block.Attention
	a.NHeads = cfg.NHeads
	a.HeadDim = cfg.Dim / cfg.NHeads
	a.Dim = cfg.Dim
	a.Scale = float32(1.0 / math.Sqrt(float64(a.HeadDim)))

	// Feed-forward output width comes from the down projection's weight.
	ff := block.FeedForward
	ff.OutDim = ff.W2.Weight.Shape()[0]

	// All four norms share the configured epsilon.
	normEps := cfg.NormEps
	block.AttentionNorm1.Eps = normEps
	block.AttentionNorm2.Eps = normEps
	block.FFNNorm1.Eps = normEps
	block.FFNNorm2.Eps = normEps
}

// RoPECache holds the rotary-embedding cos/sin tables built by
// PrepareRoPECache for one image size and caption length.
//
// UnifiedCos/UnifiedSin cover the concatenated sequence (image positions
// first, then caption positions). ImgCos/ImgSin are the first ImgLen positions
// of the unified tables and CapCos/CapSin the following CapLen positions.
type RoPECache struct {
	ImgCos     *mlx.Array
	ImgSin     *mlx.Array
	CapCos     *mlx.Array
	CapSin     *mlx.Array
	UnifiedCos *mlx.Array
	UnifiedSin *mlx.Array
	ImgLen     int32 // number of image tokens (hTok * wTok)
	CapLen     int32 // number of caption tokens
}

// PrepareRoPECache precomputes the rotary-embedding tables for an image of
// hTok x wTok tokens and a caption of capLen tokens.
//
// hTok and wTok are token counts per spatial dimension (latentH/patchSize,
// latentW/patchSize). Positions are 3-component coordinates:
//
//   - image tokens:   (capLen+1, row, col) for row in [0,hTok), col in [0,wTok)
//   - caption tokens: (1+i, 0, 0)          for i in [0,capLen)
//
// Positions are converted to bfloat16 before the angles are computed. The
// tables are computed once over the concatenation (image first, then caption)
// and the image-only and caption-only tables are slices of that result. The
// slice width is read from the computed table, so it follows whatever
// AxesDims the model is configured with.
func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache {
	imgLen := hTok * wTok

	// Image positions: grid over (1, H, W) starting at (capLen+1, 0, 0)
	imgPos := mlx.ToBFloat16(createCoordinateGrid(1, hTok, wTok, capLen+1, 0, 0))
	// Caption positions: grid over (capLen, 1, 1) starting at (1, 0, 0)
	capPos := mlx.ToBFloat16(createCoordinateGrid(capLen, 1, 1, 1, 0, 0))

	// Compute RoPE once from the unified position list.
	unifiedPos := mlx.Concatenate([]*mlx.Array{imgPos, capPos}, 1)
	unifiedCos, unifiedSin := prepareRoPE3D(unifiedPos, m.TransformerConfig.AxesDims)

	// Width of the angle table (sum of axesDims[i]/2). Taken from the result
	// rather than hard-coded so non-default axes_dims slice correctly.
	ropeDim := unifiedCos.Shape()[3]
	seqEnd := imgLen + capLen

	imgStart, imgStop := []int32{0, 0, 0, 0}, []int32{1, imgLen, 1, ropeDim}
	capStart, capStop := []int32{0, imgLen, 0, 0}, []int32{1, seqEnd, 1, ropeDim}

	return &RoPECache{
		ImgCos:     mlx.Slice(unifiedCos, imgStart, imgStop),
		ImgSin:     mlx.Slice(unifiedSin, imgStart, imgStop),
		CapCos:     mlx.Slice(unifiedCos, capStart, capStop),
		CapSin:     mlx.Slice(unifiedSin, capStart, capStop),
		UnifiedCos: unifiedCos,
		UnifiedSin: unifiedSin,
		ImgLen:     imgLen,
		CapLen:     capLen,
	}
}

// Forward runs the full transformer for one denoising step.
//
// x holds the patchified image latents, t the timesteps, capFeats the caption
// features, and rope the tables from PrepareRoPECache. The pipeline is:
//
//  1. timestep embedding of t * TScale
//  2. image patches: XEmbed, then the noise refiners (timestep-modulated)
//  3. caption: CapEmbed, then the context refiners (no modulation)
//  4. concatenate [image, caption] and run the main layers (modulated)
//  5. keep the first rope.ImgLen tokens and apply the final layer
func (m *Transformer) Forward(x *mlx.Array, t *mlx.Array, capFeats *mlx.Array, rope *RoPECache) *mlx.Array {
	imgLen := rope.ImgLen

	// Conditioning vector shared by every modulated block and the final layer.
	temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale))

	imgTokens := m.XEmbed.Forward(x)
	capTokens := m.CapEmbed.Forward(capFeats)

	eps := m.NormEps

	// Image stream: refined with timestep modulation and image-only RoPE.
	for i := range m.NoiseRefiners {
		imgTokens = m.NoiseRefiners[i].Forward(imgTokens, temb, rope.ImgCos, rope.ImgSin, eps)
	}

	// Caption stream: refined without modulation, caption-only RoPE.
	for i := range m.ContextRefiners {
		capTokens = m.ContextRefiners[i].Forward(capTokens, nil, rope.CapCos, rope.CapSin, eps)
	}

	// Joint sequence, image tokens first.
	joint := mlx.Concatenate([]*mlx.Array{imgTokens, capTokens}, 1)
	for i := range m.Layers {
		joint = m.Layers[i].Forward(joint, temb, rope.UnifiedCos, rope.UnifiedSin, eps)
	}

	// Drop the caption tokens again before the output head.
	jointShape := joint.Shape()
	imgOnly := mlx.Slice(joint, []int32{0, 0, 0}, []int32{jointShape[0], imgLen, jointShape[2]})

	return m.FinalLayer.Forward(imgOnly, temb)
}

// ForwardWithCache runs the transformer like Forward but reuses work across
// denoising steps through stepCache.
//
// Two things are cached:
//
//   - The refined caption embedding. It is computed the first time
//     stepCache.GetConstant returns nil and stored with SetConstant; later
//     calls reuse it, so capFeats is only read on that first call. It depends
//     neither on the timestep nor on the latents.
//   - The outputs of the first stepCache.NumLayers() main layers. Whether a
//     step refreshes them is decided by stepCache.ShouldRefresh(step,
//     cacheInterval). On a refresh step those layers are computed and stored;
//     on other steps a stored output, if present, replaces the computation.
//
// The noise refiners and all deeper layers are always computed. When stepCache
// is nil there is nothing to cache into and the call is equivalent to Forward.
func (m *Transformer) ForwardWithCache(
	x *mlx.Array,
	t *mlx.Array,
	capFeats *mlx.Array,
	rope *RoPECache,
	stepCache *cache.StepCache,
	step int,
	cacheInterval int,
) *mlx.Array {
	// Without a cache, fall back to the plain path instead of dereferencing nil.
	if stepCache == nil {
		return m.Forward(x, t, capFeats, rope)
	}

	imgLen := rope.ImgLen
	cacheLayers := stepCache.NumLayers()
	eps := m.NormEps

	// Timestep embedding (conditioning for modulated blocks and the final layer)
	temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale))

	// Embed image patches -> [B, L_img, dim]
	x = m.XEmbed.Forward(x)

	// Context refiners: compute once, then reuse the stored result.
	capEmb := stepCache.GetConstant()
	if capEmb == nil {
		capEmb = m.CapEmbed.Forward(capFeats)
		for _, refiner := range m.ContextRefiners {
			capEmb = refiner.Forward(capEmb, nil, rope.CapCos, rope.CapSin, eps)
		}
		stepCache.SetConstant(capEmb)
	}

	// Noise refiners: always compute (they depend on x, which changes each step)
	for _, refiner := range m.NoiseRefiners {
		x = refiner.Forward(x, temb, rope.ImgCos, rope.ImgSin, eps)
	}

	// Concatenate image and caption for joint attention
	unified := mlx.Concatenate([]*mlx.Array{x, capEmb}, 1)

	// Determine if this is a cache refresh step
	refreshCache := stepCache.ShouldRefresh(step, cacheInterval)

	// Main transformer layers with caching
	for i, layer := range m.Layers {
		cacheable := i < cacheLayers

		// Reuse a stored output for a cacheable layer on non-refresh steps.
		if cacheable && !refreshCache {
			if cached := stepCache.Get(i); cached != nil {
				unified = cached
				continue
			}
		}

		unified = layer.Forward(unified, temb, rope.UnifiedCos, rope.UnifiedSin, eps)

		// Store cacheable layer outputs on refresh steps.
		if cacheable && refreshCache {
			stepCache.Set(i, unified)
		}
	}

	// Extract image tokens only
	unifiedShape := unified.Shape()
	imgOut := mlx.Slice(unified, []int32{0, 0, 0}, []int32{unifiedShape[0], imgLen, unifiedShape[2]})

	// Final layer
	return m.FinalLayer.Forward(imgOut, temb)
}

// createCoordinateGrid enumerates every point of a d0 x d1 x d2 integer grid
// whose origin is (s0, s1, s2) and returns the points as a float32 array of
// shape [1, d0*d1*d2, 3]. Points are listed with the last axis varying
// fastest and the first axis slowest.
func createCoordinateGrid(d0, d1, d2, s0, s1, s2 int32) *mlx.Array {
	count := d0 * d1 * d2
	flat := make([]float32, count*3)

	// off is the write position of the current point's first component.
	off := 0
	for a := int32(0); a < d0; a++ {
		c0 := float32(s0 + a)
		for b := int32(0); b < d1; b++ {
			c1 := float32(s1 + b)
			for c := int32(0); c < d2; c++ {
				flat[off], flat[off+1], flat[off+2] = c0, c1, float32(s2+c)
				off += 3
			}
		}
	}

	return mlx.NewArray(flat, []int32{1, count, 3})
}

// prepareRoPE3D computes the cos/sin tables for 3-axis rotary embeddings.
//
// positions is indexed as [B, L, 3]; coordinate k of each position is paired
// with axesDims[k] (e.g. [32, 48, 48]). For axis k the frequencies are
// 256^(-i / (axesDims[k]/2)) for i in [0, axesDims[k]/2), and the angle is
// position * frequency. The per-axis angles are concatenated along the last
// axis, so both results have shape [B, L, 1, sum(axesDims[k]/2)].
//
// axesDims must have at least three entries; only the first three are used.
// NOTE(review): the base 256 is fixed here and TransformerConfig.RopeTheta is
// not consulted — confirm they are meant to agree.
func prepareRoPE3D(positions *mlx.Array, axesDims []int32) (*mlx.Array, *mlx.Array) {
	const numAxes = 3
	ropeBase := float32(256.0)

	// One frequency row per axis, shaped [1, 1, 1, axesDims[k]/2].
	axisFreqs := make([]*mlx.Array, numAxes)
	for k := range axisFreqs {
		n := axesDims[k] / 2
		row := make([]float32, n)
		for i := range row {
			row[i] = float32(math.Exp(-math.Log(float64(ropeBase)) * float64(i) / float64(n)))
		}
		axisFreqs[k] = mlx.NewArray(row, []int32{1, 1, 1, n})
	}

	dims := positions.Shape()
	batch, seqLen := dims[0], dims[1]

	// coordinate extracts positions[:, :, k] as [B, L, 1, 1] so that it
	// broadcasts against a frequency row.
	coordinate := func(k int32) *mlx.Array {
		col := mlx.Slice(positions, []int32{0, 0, k}, []int32{batch, seqLen, k + 1})
		return mlx.ExpandDims(col, 3)
	}
	coord0 := coordinate(0)
	coord1 := coordinate(1)
	coord2 := coordinate(2)

	// Angles per axis, then joined along the feature axis.
	angles := mlx.Concatenate([]*mlx.Array{
		mlx.Mul(coord0, axisFreqs[0]),
		mlx.Mul(coord1, axisFreqs[1]),
		mlx.Mul(coord2, axisFreqs[2]),
	}, 3)

	return mlx.Cos(angles), mlx.Sin(angles)
}

// PatchifyLatents converts latents [B, C, H, W] into a patch sequence
// [B, (H/patchSize)*(W/patchSize), patchSize*patchSize*C].
//
// Patches are listed row-major over the patch grid. Inside a patch the values
// are ordered (row-in-patch, col-in-patch, channel), i.e. the channel varies
// fastest. For B == 1 and patchSize == 2 this is element-for-element the
// layout of the reference
//
//	x.reshape(C, 1, 1, H_tok, 2, W_tok, 2).transpose(1,2,3,5,4,6,0).reshape(1,-1,C*4)
//
// and the same layout is produced independently for every batch entry.
//
// latents must be 4-D, patchSize must be non-zero, and H and W must be
// multiples of patchSize.
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
	shape := latents.Shape()
	B, C, H, W := shape[0], shape[1], shape[2], shape[3]

	pH := H / patchSize // H_tok
	pW := W / patchSize // W_tok

	// Split both spatial axes into (token, within-patch):
	// [B, C, H, W] -> [B, C, pH, p, pW, p]
	x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)

	// Bring the token grid forward and the channel last:
	// [B, C, pH, p, pW, p] -> [B, pH, pW, p, p, C]
	x = mlx.Transpose(x, 0, 2, 4, 3, 5, 1)

	// Flatten the grid into a sequence and each patch into a feature vector.
	return mlx.Reshape(x, B, pH*pW, patchSize*patchSize*C)
}

// UnpatchifyLatents is the inverse of PatchifyLatents: it converts a patch
// sequence [B, (H/patchSize)*(W/patchSize), patchSize*patchSize*C] back into
// latents [B, C, H, W].
//
// Each patch vector is read as (row-in-patch, col-in-patch, channel) with the
// channel varying fastest. For B == 1 and patchSize == 2 this matches the
// reference
//
//	out.reshape(1,1,H_tok,W_tok,2,2,C).transpose(6,0,1,2,4,3,5).reshape(1,C,H,W)
//
// The batch size is taken from patches when it has the documented 3-D shape;
// any other input is treated as a single image, as before. patchSize must be
// non-zero and H and W must be multiples of it.
func UnpatchifyLatents(patches *mlx.Array, patchSize, H, W, C int32) *mlx.Array {
	pH := H / patchSize
	pW := W / patchSize

	// Only trust the leading dimension as a batch size when the remaining
	// dimensions are exactly the expected [L, C*p*p].
	B := int32(1)
	if shape := patches.Shape(); len(shape) == 3 &&
		shape[1] == pH*pW && shape[2] == patchSize*patchSize*C {
		B = shape[0]
	}

	// [B, L, p*p*C] -> [B, pH, pW, p, p, C]
	x := mlx.Reshape(patches, B, pH, pW, patchSize, patchSize, C)

	// Channel to the front, each within-patch axis next to its token axis:
	// [B, pH, pW, p, p, C] -> [B, C, pH, p, pW, p]
	x = mlx.Transpose(x, 0, 5, 1, 3, 2, 4)

	// Merge (token, within-patch) back into full spatial axes.
	return mlx.Reshape(x, B, C, H, W)
}