//go:build mlx

package gemma3

import (
	"math"

	"github.com/ollama/ollama/x/imagegen/mlx"
	"github.com/ollama/ollama/x/imagegen/nn"
)

// VisionConfig holds configuration for the SigLIP vision tower
type VisionConfig struct {
	HiddenSize        int32 `json:"hidden_size"`
	ImageSize         int32 `json:"image_size"`
	IntermediateSize  int32 `json:"intermediate_size"`
	NumAttentionHeads int32 `json:"num_attention_heads"`
	NumHiddenLayers   int32 `json:"num_hidden_layers"`
	PatchSize         int32 `json:"patch_size"`
}
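
// Illustrative values only; the real ones are unmarshaled from the model's
// config JSON. A SigLIP-So400m-sized tower, for example, looks roughly like:
//
//	&VisionConfig{
//		HiddenSize:        1152,
//		ImageSize:         896,
//		IntermediateSize:  4304,
//		NumAttentionHeads: 16,
//		NumHiddenLayers:   27,
//		PatchSize:         14,
//	}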

// VisionTower is the SigLIP vision encoder
type VisionTower struct {
	Embeddings    *VisionEmbeddings     `weight:"vision_model.embeddings"`
	Encoder       []*VisionEncoderLayer `weight:"vision_model.encoder.layers"`
	PostLayerNorm *nn.LayerNorm         `weight:"vision_model.post_layernorm"`
	Config        *VisionConfig
}

// VisionEmbeddings handles patch and position embeddings
type VisionEmbeddings struct {
	// PatchWeight is stored as [O, C, kH, kW] (PyTorch layout); Forward
	// transposes it to [O, kH, kW, C] for MLX's Conv2d.
	PatchWeight *mlx.Array    `weight:"patch_embedding.weight"`
	PatchBias   *mlx.Array    `weight:"patch_embedding.bias"`
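	// PosEmbed has one row per patch: (image_size/patch_size)^2 entries,
	// e.g. (896/14)^2 = 4096 for an 896x896 input with 14x14 patches.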
	PosEmbed    *nn.Embedding `weight:"position_embedding"`
}

// VisionEncoderLayer is a single transformer encoder layer
type VisionEncoderLayer struct {
	LayerNorm1 *nn.LayerNorm    `weight:"layer_norm1"`
	Attention  *VisionAttention `weight:"self_attn"`
	LayerNorm2 *nn.LayerNorm    `weight:"layer_norm2"`
	MLP        *VisionMLP       `weight:"mlp"`
}

// VisionAttention implements multi-head self-attention
type VisionAttention struct {
	QProj   *nn.Linear `weight:"q_proj"`
	KProj   *nn.Linear `weight:"k_proj"`
	VProj   *nn.Linear `weight:"v_proj"`
	OutProj *nn.Linear `weight:"out_proj"`
}

// VisionMLP is the feed-forward network
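// FC1 expands hidden_size to intermediate_size; FC2 projects back down.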
type VisionMLP struct {
	FC1 *nn.Linear `weight:"fc1"`
	FC2 *nn.Linear `weight:"fc2"`
}

// Forward runs the vision tower on preprocessed images
// Input: [B, H, W, C] normalized image tensor (NHWC layout for MLX)
// Output: [B, num_patches, hidden_size]
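// For example, an 896x896 input with patch_size=14 gives a 64x64 patch grid,
// so the output would be [B, 4096, hidden_size].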
func (v *VisionTower) Forward(x *mlx.Array) *mlx.Array {
	// Patch embedding conv: input [B, H, W, C], weight [O, kH, kW, C] -> [B, grid, grid, O]
	// Weight comes as [O, C, kH, kW] from PyTorch, transpose to [O, kH, kW, C]
	weight := mlx.Transpose(v.Embeddings.PatchWeight, 0, 2, 3, 1)
	h := mlx.Conv2d(x, weight, v.Config.PatchSize, 0) // stride=patch_size, no padding

	// Add bias: [O] -> [1, 1, 1, O] for broadcasting
	bias := mlx.Reshape(v.Embeddings.PatchBias, 1, 1, 1, v.Embeddings.PatchBias.Shape()[0])
	h = mlx.Add(h, bias)

	// h is [B, grid, grid, hidden], flatten to [B, num_patches, hidden]
	B := h.Shape()[0]
	gridH, gridW := h.Shape()[1], h.Shape()[2]
	hidden := h.Shape()[3]
	numPatches := gridH * gridW
	h = mlx.Reshape(h, B, numPatches, hidden)

	// Add position embeddings
	posIds := mlx.ArangeInt(0, numPatches, 1, mlx.DtypeInt32)
	posEmbed := v.Embeddings.PosEmbed.Forward(posIds)
	h = mlx.Add(h, posEmbed)

	// Encoder layers
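	// All layers share the same attention scale, 1/sqrt(head_dim).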
	headDim := float32(v.Config.HiddenSize / v.Config.NumAttentionHeads)
	scale := float32(1.0 / math.Sqrt(float64(headDim)))
	for _, layer := range v.Encoder {
		h = layer.Forward(h, v.Config, scale)
	}

	// Final layer norm
	h = v.PostLayerNorm.Forward(h)

	return h
}

// Forward runs a vision encoder layer
func (l *VisionEncoderLayer) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
	// Pre-norm attention
	h := l.LayerNorm1.Forward(x)
	h = l.Attention.Forward(h, cfg, scale)
	x = mlx.Add(x, h)

	// Pre-norm MLP
	h = l.LayerNorm2.Forward(x)
	h = l.MLP.Forward(h)
	return mlx.Add(x, h)
}

// Forward runs multi-head self-attention
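// head_dim is hidden_size/num_heads (e.g. 1152/16 = 72 for a So400m-sized
// tower); the caller is expected to pass scale = 1/sqrt(head_dim).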
func (a *VisionAttention) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
	B, L := x.Shape()[0], x.Shape()[1]
	headDim := cfg.HiddenSize / cfg.NumAttentionHeads

	q := a.QProj.Forward(x)
	k := a.KProj.Forward(x)
	v := a.VProj.Forward(x)

	// Reshape to [B, num_heads, L, head_dim]
	q = mlx.Transpose(mlx.Reshape(q, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
	k = mlx.Transpose(mlx.Reshape(k, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
	v = mlx.Transpose(mlx.Reshape(v, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)

	// Scaled dot-product attention (no causal mask for vision)
	out := mlx.ScaledDotProductAttention(q, k, v, scale, false)

	// Reshape back: [B, num_heads, L, head_dim] -> [B, L, hidden]
	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.HiddenSize)

	return a.OutProj.Forward(out)
}

// Forward runs the MLP with GELU activation
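// (SigLIP reference implementations use the tanh-approximated GELU;
// mlx.GELU is assumed to match that here.)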
func (m *VisionMLP) Forward(x *mlx.Array) *mlx.Array {
	h := mlx.GELU(m.FC1.Forward(x))
	return m.FC2.Forward(h)
}