projector.go 1.73 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
//go:build mlx

package gemma3

import (
	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
)

// MultiModalProjector projects vision features to text embedding space
type MultiModalProjector struct {
	// mm_input_projection_weight: [vision_hidden, text_hidden]
	InputProjection *mlx.Array  `weight:"mm_input_projection_weight"`
	SoftEmbNorm     *nn.RMSNorm `weight:"mm_soft_emb_norm"`

	// Precomputed (1 + weight) for Gemma-style RMSNorm
	SoftEmbNormScaled *mlx.Array `weight:"-"`
}

// Forward projects vision features to text space
// Input: [B, num_patches, vision_hidden] (e.g., [1, 4096, 1152])
// Output: [B, num_image_tokens, text_hidden] (e.g., [1, 256, 2560])
func (p *MultiModalProjector) Forward(visionFeatures *mlx.Array, eps float32) *mlx.Array {
	// Average pool 4x4: [B, 4096, 1152] -> [B, 256, 1152]
	// 4096 patches = 64x64 grid, pool to 16x16 = 256 tokens
	B := visionFeatures.Shape()[0]
	visionHidden := visionFeatures.Shape()[2]

	// Reshape to [B, 64, 64, hidden]
	gridSize := int32(64) // sqrt(4096)
	pooledSize := int32(16) // 64/4
	h := mlx.Reshape(visionFeatures, B, gridSize, gridSize, visionHidden)

	// Reshape to [B, 16, 4, 16, 4, hidden] for 4x4 pooling
	h = mlx.Reshape(h, B, pooledSize, 4, pooledSize, 4, visionHidden)

	// Average over pooling dimensions (axes 2 and 4)
	h = mlx.Mean(h, 4, false)
	h = mlx.Mean(h, 2, false)

	// h is now [B, 16, 16, hidden], reshape to [B, 256, hidden]
	numTokens := pooledSize * pooledSize
	h = mlx.Reshape(h, B, numTokens, visionHidden)

	// Apply Gemma-style RMS norm (use precomputed 1 + weight)
	h = mlx.RMSNorm(h, p.SoftEmbNormScaled, eps)

	// Project to text space: [B, 256, vision_hidden] @ [vision_hidden, text_hidden]
	return mlx.Linear(h, p.InputProjection)
}